用于预训练BERT的数据集

加载数据集可以用huggingface的,主要要改下面两个函数,python环境要加上huggingface-hub和datasets包,具体操作可以看huggingface下载和使用数据集的文档。
数据集地址: Salesforce/wikitext · Datasets at Hugging Face
使用文档: Using :hugs: Datasets · Hugging Face

为了做练习题1,我也把sentence_tokenizer参数化了。

from datasets import load_dataset


def load_data_wiki(batch_size, max_len, sentence_tokenizer=lambda para: para.strip().lower().split(' . ')):
    """加载WikiText-2数据集"""
    paragraphs = _read_wiki(sentence_tokenizer=sentence_tokenizer)
    train_set = _WikiTextDataset(paragraphs, max_len)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=0)
    return train_iter, train_set.vocab


def _read_wiki(dataset_type='train', sentence_tokenizer=lambda para: para.strip().lower().split(' . ')):
    """
    读取huggingface上的wikitext数据集
    :param dataset_type: 'train' or 'test' or 'validation'
    :return: 双重列表,段落然后句子
    """
    wiki_dataset = load_dataset('Salesforce/wikitext', 'wikitext-2-v1', cache_dir='../data')
    lines = wiki_dataset[dataset_type]['text']
    paragraphs = list(filter(lambda sens: len(sens) >= 2, map(sentence_tokenizer, lines)))  # 取一段里大于等于2句的,否则丢弃这段
    random.shuffle(paragraphs)
    return paragraphs

练习题1:

# * 练习: 使用nltk punkt语句次元分析器
import nltk

nltk.download('punkt')
nltk.download('punkt_tab')

sentences = 'This is great ! Why not ?'
print(nltk.tokenize.sent_tokenize(sentences))

train_iter, vocab = load_data_wiki(batch_size, max_len, nltk.tokenize.sent_tokenize)

len(vocab), vocab.to_tokens(8226)