26 changes: 13 additions & 13 deletions test/data/test_builtin_datasets.py
@@ -57,12 +57,12 @@ def test_wikitext2(self):
self.assertEqual(tokens_ids, [2, 286, 503, 700])

# Add test for the subset of the standard datasets
- train_iter, valid_iter, test_iter = torchtext.experimental.datasets.raw.WikiText2(data_select=('train', 'valid', 'test'))
+ train_iter, valid_iter, test_iter = torchtext.experimental.datasets.raw.WikiText2(split=('train', 'valid', 'test'))
self._helper_test_func(len(train_iter), 36718, next(iter(train_iter)), ' \n')
self._helper_test_func(len(valid_iter), 3760, next(iter(valid_iter)), ' \n')
self._helper_test_func(len(test_iter), 4358, next(iter(test_iter)), ' \n')
del train_iter, valid_iter, test_iter
- train_dataset, test_dataset = WikiText2(data_select=('train', 'test'))
+ train_dataset, test_dataset = WikiText2(split=('train', 'test'))
train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
test_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, test_dataset)))
self._helper_test_func(len(train_data), 2049990, train_data[20:25],
@@ -105,14 +105,14 @@ def test_penntreebank(self):
self.assertEqual(tokens_ids, [2, 2550, 3344, 1125])

# Add test for the subset of the standard datasets
- train_dataset, test_dataset = PennTreebank(data_select=('train', 'test'))
+ train_dataset, test_dataset = PennTreebank(split=('train', 'test'))
train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
test_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, test_dataset)))
self._helper_test_func(len(train_data), 924412, train_data[20:25],
[9919, 9920, 9921, 9922, 9188])
self._helper_test_func(len(test_data), 82114, test_data[30:35],
[397, 93, 4, 16, 7])
- train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(data_select=('train', 'test'))
+ train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(split=('train', 'test'))
self._helper_test_func(len(train_iter), 42068, next(iter(train_iter))[:15], ' aer banknote b')
self._helper_test_func(len(test_iter), 3761, next(iter(test_iter))[:25], " no it was n't black mond")
del train_iter, test_iter
@@ -130,7 +130,7 @@ def test_text_classification(self):
[2351, 758, 96, 38581, 2351, 220, 5, 396, 3, 14786])

# Add test for the subset of the standard datasets
- train_dataset, = AG_NEWS(data_select=('train'))
+ train_dataset, = AG_NEWS(split=('train'))
self._helper_test_func(len(train_dataset), 120000, train_dataset[-1][1][:10],
[2155, 223, 2405, 30, 3010, 2204, 54, 3603, 4930, 2405])
train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS()
@@ -160,7 +160,7 @@ def test_imdb(self):
new_train_data, new_test_data = IMDB(vocab=new_vocab)

# Add test for the subset of the standard datasets
- train_dataset, = IMDB(data_select=('train'))
+ train_dataset, = IMDB(split=('train'))
self._helper_test_func(len(train_dataset), 25000, train_dataset[0][1][:10],
[13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92])
train_iter, test_iter = torchtext.experimental.datasets.raw.IMDB()
@@ -240,15 +240,15 @@ def test_multi30k(self):
[18, 24, 1168, 807, 16, 56, 83, 335, 1338])

# Add test for the subset of the standard datasets
- train_iter, valid_iter = torchtext.experimental.datasets.raw.Multi30k(data_select=('train', 'valid'))
+ train_iter, valid_iter = torchtext.experimental.datasets.raw.Multi30k(split=('train', 'valid'))
self._helper_test_func(len(train_iter), 29000, ' '.join(next(iter(train_iter))),
' '.join(['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n',
'Two young, White males are outside near many bushes.\n']))
self._helper_test_func(len(valid_iter), 1014, ' '.join(next(iter(valid_iter))),
' '.join(['Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen\n',
'A group of men are loading cotton onto a truck\n']))
del train_iter, valid_iter
- train_dataset, = Multi30k(data_select=('train'))
+ train_dataset, = Multi30k(split=('train'))

# This change is due to the BC breaking in spacy 3.0
self._helper_test_func(len(train_dataset), 29000, train_dataset[20],
@@ -311,11 +311,11 @@ def test_udpos_sequence_tagging(self):
self.assertEqual(tokens_ids, [1206, 8, 69, 60, 157, 452])

# Add test for the subset of the standard datasets
- train_dataset, = UDPOS(data_select=('train'))
+ train_dataset, = UDPOS(split=('train'))
self._helper_test_func(len(train_dataset), 12543, (train_dataset[0][0][:10], train_dataset[-1][2][:10]),
([262, 16, 5728, 45, 289, 701, 1160, 4436, 10660, 585],
[6, 20, 8, 10, 8, 8, 24, 13, 8, 15]))
- train_iter, valid_iter = torchtext.experimental.datasets.raw.UDPOS(data_select=('train', 'valid'))
+ train_iter, valid_iter = torchtext.experimental.datasets.raw.UDPOS(split=('train', 'valid'))
self._helper_test_func(len(train_iter), 12543, ' '.join(next(iter(train_iter))[0][:5]),
' '.join(['Al', '-', 'Zaman', ':', 'American']))
self._helper_test_func(len(valid_iter), 2002, ' '.join(next(iter(valid_iter))[0][:5]),
@@ -358,7 +358,7 @@ def test_conll_sequence_tagging(self):
self.assertEqual(tokens_ids, [970, 5, 135, 43, 214, 690])

# Add test for the subset of the standard datasets
- train_dataset, = CoNLL2000Chunking(data_select=('train'))
+ train_dataset, = CoNLL2000Chunking(split=('train'))
self._helper_test_func(len(train_dataset), 8936, (train_dataset[0][0][:10], train_dataset[0][1][:10],
train_dataset[0][2][:10], train_dataset[-1][0][:10],
train_dataset[-1][1][:10], train_dataset[-1][2][:10]),
@@ -393,7 +393,7 @@ def test_squad1(self):
new_train_data, new_test_data = SQuAD1(vocab=new_vocab)

# Add test for the subset of the standard datasets
- train_dataset, = SQuAD1(data_select=('train'))
+ train_dataset, = SQuAD1(split=('train'))
context, question, answers, ans_pos = train_dataset[100]
self._helper_test_func(len(train_dataset), 87599, (question[:5], ans_pos[0]),
([7, 24, 86, 52, 2], [72, 72]))
@@ -422,7 +422,7 @@ def test_squad2(self):
new_train_data, new_test_data = SQuAD2(vocab=new_vocab)

# Add test for the subset of the standard datasets
- train_dataset, = SQuAD2(data_select=('train'))
+ train_dataset, = SQuAD2(split=('train'))
context, question, answers, ans_pos = train_dataset[200]
self._helper_test_func(len(train_dataset), 130319, (question[:5], ans_pos[0]),
([84, 50, 1421, 12, 5439], [9, 9]))
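The tests in this file exercise only the renamed keyword; the call pattern is otherwise unchanged. A minimal usage sketch of the new spelling, assuming a torchtext build that includes these experimental datasets (returned sizes depend on the downloaded data):

>>> import torchtext
>>> from torchtext.experimental.datasets import WikiText2
>>> # raw iterators: select a subset of splits with the renamed keyword
>>> train_iter, test_iter = torchtext.experimental.datasets.raw.WikiText2(split=('train', 'test'))
>>> # tensor-backed datasets: a single string selects one split
>>> train_dataset, = WikiText2(split='train')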
4 changes: 2 additions & 2 deletions test/experimental/test_with_asset.py
@@ -62,10 +62,10 @@ def test_wikitext103(self):
self.assertEqual(tokens_ids, [2, 320, 437, 687])

# Add test for the subset of the standard datasets
- train_dataset, test_dataset = torchtext.experimental.datasets.raw.WikiText103(data_select=('train', 'test'))
+ train_dataset, test_dataset = torchtext.experimental.datasets.raw.WikiText103(split=('train', 'test'))
self._helper_test_func(len(train_dataset), 1801350, next(iter(train_dataset)), ' \n')
self._helper_test_func(len(test_dataset), 4358, next(iter(test_dataset)), ' \n')
- train_dataset, test_dataset = WikiText103(vocab=builtin_vocab, data_select=('train', 'test'))
+ train_dataset, test_dataset = WikiText103(vocab=builtin_vocab, split=('train', 'test'))
self._helper_test_func(len(train_dataset), 1801350, train_dataset[10][:5],
[2, 69, 12, 14, 265])
self._helper_test_func(len(test_dataset), 4358, test_dataset[28][:5],
52 changes: 26 additions & 26 deletions torchtext/experimental/datasets/language_modeling.py
@@ -58,19 +58,19 @@ def get_vocab(self):
return self.vocab


- def _setup_datasets(dataset_name, tokenizer, root, vocab, data_select, year, language):
+ def _setup_datasets(dataset_name, tokenizer, root, vocab, split, year, language):
if tokenizer is None:
tokenizer = get_tokenizer('basic_english')

- data_select = check_default_set(data_select, ('train', 'test', 'valid'))
+ split = check_default_set(split, ('train', 'test', 'valid'))

if vocab is None:
- if 'train' not in data_select:
+ if 'train' not in split:
raise TypeError("Must pass a vocab if train is not selected.")
if dataset_name == 'WMTNewsCrawl':
- raw_train, = raw.DATASETS[dataset_name](root=root, data_select=('train',), year=year, language=language)
+ raw_train, = raw.DATASETS[dataset_name](root=root, split=('train',), year=year, language=language)
else:
- raw_train, = raw.DATASETS[dataset_name](root=root, data_select=('train',))
+ raw_train, = raw.DATASETS[dataset_name](root=root, split=('train',))
logger_.info('Building Vocab based on train data')
vocab = build_vocab(raw_train, tokenizer)
logger_.info('Vocab has %d entries', len(vocab))
@@ -79,16 +79,16 @@ def text_transform(line):
return torch.tensor([vocab[token] for token in tokenizer(line)], dtype=torch.long)

if dataset_name == 'WMTNewsCrawl':
- raw_datasets = raw.DATASETS[dataset_name](root=root, data_select=data_select, year=year, language=language)
+ raw_datasets = raw.DATASETS[dataset_name](root=root, split=split, year=year, language=language)
else:
- raw_datasets = raw.DATASETS[dataset_name](root=root, data_select=data_select)
- raw_data = {name: list(map(text_transform, raw_dataset)) for name, raw_dataset in zip(data_select, raw_datasets)}
- logger_.info('Building datasets for {}'.format(data_select))
+ raw_datasets = raw.DATASETS[dataset_name](root=root, split=split)
+ raw_data = {name: list(map(text_transform, raw_dataset)) for name, raw_dataset in zip(split, raw_datasets)}
+ logger_.info('Building datasets for {}'.format(split))
return tuple(LanguageModelingDataset(raw_data[item], vocab, text_transform)
- for item in data_select)
+ for item in split)
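Concretely, the guard above means a split selection that omits 'train' is only valid when a vocab is supplied; a short sketch of the intended paths, assuming the default tokenizer (not part of this diff):

>>> from torchtext.experimental.datasets import WikiText2
>>> # without a vocab, omitting 'train' from split hits the TypeError raised above
>>> # WikiText2(split='valid')
>>> # supported path: build the vocab on 'train', then reuse it for another split
>>> train_dataset, = WikiText2(split='train')
>>> valid_dataset, = WikiText2(vocab=train_dataset.get_vocab(), split='valid')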


- def WikiText2(tokenizer=None, root='.data', vocab=None, data_select=('train', 'valid', 'test')):
+ def WikiText2(tokenizer=None, root='.data', vocab=None, split=('train', 'valid', 'test')):
""" Defines WikiText2 datasets.

Create language modeling dataset: WikiText2
@@ -102,7 +102,7 @@ def WikiText2(tokenizer=None, root='.data', vocab=None, data_select=('train', 'v
root: Directory where the datasets are saved. Default: ".data"
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
- data_select: a string or tupel for the returned datasets. Default: ('train', 'valid','test')
+ split: a string or tuple for the returned datasets. Default: ('train', 'valid','test')
By default, all the three datasets (train, test, valid) are generated. Users
could also choose any one or two of them, for example ('train', 'test') or
just a string 'train'. If 'train' is not in the tuple or string, a vocab
@@ -116,13 +116,13 @@ def WikiText2(tokenizer=None, root='.data', vocab=None, data_select=('train', 'v
>>> train_dataset, valid_dataset, test_dataset = WikiText2(tokenizer=tokenizer)
>>> vocab = train_dataset.get_vocab()
>>> valid_dataset, = WikiText2(tokenizer=tokenizer, vocab=vocab,
- data_select='valid')
+ split='valid')

"""
- return _setup_datasets("WikiText2", tokenizer, root, vocab, data_select, None, None)
+ return _setup_datasets("WikiText2", tokenizer, root, vocab, split, None, None)


- def WikiText103(tokenizer=None, root='.data', vocab=None, data_select=('train', 'valid', 'test')):
+ def WikiText103(tokenizer=None, root='.data', vocab=None, split=('train', 'valid', 'test')):
""" Defines WikiText103 datasets.

Create language modeling dataset: WikiText103
@@ -136,7 +136,7 @@ def WikiText103(tokenizer=None, root='.data', vocab=None, data_select=('train',
root: Directory where the datasets are saved. Default: ".data"
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
- data_select: a string or tupel for the returned datasets. Default: ('train', 'valid', 'test')
+ split: a string or tuple for the returned datasets. Default: ('train', 'valid', 'test')
By default, all the three datasets (train, test, valid) are generated. Users
could also choose any one or two of them, for example ('train', 'test') or
just a string 'train'. If 'train' is not in the tuple or string, a vocab
@@ -150,14 +150,14 @@ def WikiText103(tokenizer=None, root='.data', vocab=None, data_select=('train',
>>> train_dataset, valid_dataset, test_dataset = WikiText103(tokenizer=tokenizer)
>>> vocab = train_dataset.get_vocab()
>>> valid_dataset, = WikiText103(tokenizer=tokenizer, vocab=vocab,
- data_select='valid')
+ split='valid')

"""

- return _setup_datasets("WikiText103", tokenizer, root, vocab, data_select, None, None)
+ return _setup_datasets("WikiText103", tokenizer, root, vocab, split, None, None)


- def PennTreebank(tokenizer=None, root='.data', vocab=None, data_select=('train', 'valid', 'test')):
+ def PennTreebank(tokenizer=None, root='.data', vocab=None, split=('train', 'valid', 'test')):
""" Defines PennTreebank datasets.

Create language modeling dataset: PennTreebank
@@ -171,7 +171,7 @@ def PennTreebank(tokenizer=None, root='.data', vocab=None, data_select=('train',
root: Directory where the datasets are saved. Default: ".data"
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
- data_select: a string or tupel for the returned datasets. Default: ('train', 'valid', 'test')
+ split: a string or tuple for the returned datasets. Default: ('train', 'valid', 'test')
By default, all the three datasets (train, test, valid) are generated. Users
could also choose any one or two of them, for example ('train', 'test') or
just a string 'train'. If 'train' is not in the tuple or string, a vocab
@@ -185,14 +185,14 @@ def PennTreebank(tokenizer=None, root='.data', vocab=None, data_select=('train',
>>> train_dataset, valid_dataset, test_dataset = PennTreebank(tokenizer=tokenizer)
>>> vocab = train_dataset.get_vocab()
>>> valid_dataset, = PennTreebank(tokenizer=tokenizer, vocab=vocab,
- data_select='valid')
+ split='valid')

"""

- return _setup_datasets("PennTreebank", tokenizer, root, vocab, data_select, None, None)
+ return _setup_datasets("PennTreebank", tokenizer, root, vocab, split, None, None)


- def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, data_select=('train'), year=2010, language='en'):
+ def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, split=('train'), year=2010, language='en'):
""" Defines WMTNewsCrawl datasets.

Create language modeling dataset: WMTNewsCrawl
@@ -206,7 +206,7 @@ def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, data_select=('train')
root: Directory where the datasets are saved. Default: ".data"
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
- data_select: a string or tuple for the returned datasets
+ split: a string or tuple for the returned datasets
(Default: ('train',))
year: the year of the dataset (Default: 2010)
language: the language of the dataset (Default: 'en')
@@ -215,12 +215,12 @@ def WMTNewsCrawl(tokenizer=None, root='.data', vocab=None, data_select=('train')
>>> from torchtext.experimental.datasets import WMTNewsCrawl
>>> from torchtext.data.utils import get_tokenizer
>>> tokenizer = get_tokenizer("spacy")
- >>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, data_select='train')
+ >>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, split='train')

Note: WMTNewsCrawl provides datasets based on the year and language instead of train/valid/test.
"""

- return _setup_datasets("WMTNewsCrawl", tokenizer, root, vocab, data_select, year, language)
+ return _setup_datasets("WMTNewsCrawl", tokenizer, root, vocab, split, year, language)
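WMTNewsCrawl keeps its year/language selectors alongside the renamed keyword; a minimal sketch under the same assumptions as above (2010 English is just the documented default):

>>> from torchtext.experimental.datasets import WMTNewsCrawl
>>> from torchtext.data.utils import get_tokenizer
>>> tokenizer = get_tokenizer('basic_english')
>>> # the dataset is keyed by year and language rather than train/valid/test
>>> train_dataset, = WMTNewsCrawl(tokenizer=tokenizer, split='train', year=2010, language='en')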


DATASETS = {