You can load the dataset from Hugging Face; the main changes are the two functions below. Your Python environment needs the huggingface-hub and datasets packages (e.g. pip install huggingface-hub datasets); see the Hugging Face documentation on downloading and using datasets for details.
Dataset page: Salesforce/wikitext · Datasets at Hugging Face
Usage docs: Using Datasets · Hugging Face
To do exercise 1, I also made sentence_tokenizer a parameter.
import random

import torch
from datasets import load_dataset

def load_data_wiki(batch_size, max_len, sentence_tokenizer=lambda para: para.strip().lower().split(' . ')):
    """Load the WikiText-2 dataset and return a DataLoader plus the vocabulary."""
    paragraphs = _read_wiki(sentence_tokenizer=sentence_tokenizer)
    # _WikiTextDataset is the BERT pretraining dataset class from the book's original code
    train_set = _WikiTextDataset(paragraphs, max_len)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=0)
    return train_iter, train_set.vocab
def _read_wiki(dataset_type='train', sentence_tokenizer=lambda para: para.strip().lower().split(' . ')):
    """
    Read the wikitext dataset from Hugging Face.
    :param dataset_type: 'train', 'test', or 'validation'
    :param sentence_tokenizer: callable that splits a paragraph string into a list of sentences
    :return: a nested list: paragraphs, each a list of sentences
    """
    wiki_dataset = load_dataset('Salesforce/wikitext', 'wikitext-2-v1', cache_dir='../data')
    lines = wiki_dataset[dataset_type]['text']
    # Keep only paragraphs with at least two sentences (next-sentence prediction needs
    # sentence pairs); drop the rest
    paragraphs = list(filter(lambda sens: len(sens) >= 2, map(sentence_tokenizer, lines)))
    random.shuffle(paragraphs)
    return paragraphs
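For a quick sanity check, you can build the loader with the default tokenizer and inspect one batch. This is a minimal sketch assuming the book's _WikiTextDataset, whose batches hold the seven BERT pretraining tensors:

batch_size, max_len = 512, 64  # same values as in the book's original example
train_iter, vocab = load_data_wiki(batch_size, max_len)
for (tokens_X, segments_X, valid_lens_x, pred_positions_X,
     mlm_weights_X, mlm_Y, nsp_y) in train_iter:
    # token ids, segment ids, valid lengths, MLM positions/weights/labels, NSP labels
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape, nsp_y.shape)
    break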
Exercise 1:
# * Exercise: use the nltk punkt sentence tokenizer
import nltk

nltk.download('punkt')
nltk.download('punkt_tab')  # newer nltk versions also need the punkt_tab resource

sentences = 'This is great ! Why not ?'
print(nltk.tokenize.sent_tokenize(sentences))  # ['This is great !', 'Why not ?']

# batch_size, max_len as defined above
train_iter, vocab = load_data_wiki(batch_size, max_len, nltk.tokenize.sent_tokenize)
len(vocab), vocab.to_tokens(8226)  # vocabulary size and a sample token under the new tokenizer
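To see what the exercise gains, compare the default splitter with punkt on the same string. The naive split on ' . ' only recognizes period boundaries, so '!' and '?' leave everything in one "sentence", while punkt splits correctly:

default_tokenizer = lambda para: para.strip().lower().split(' . ')
print(default_tokenizer('This is great ! Why not ?'))
# ['this is great ! why not ?']  -- a single "sentence"
print(nltk.tokenize.sent_tokenize('This is great ! Why not ?'))
# ['This is great !', 'Why not ?']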