diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index 95f2894634..67d168d02a 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -58,9 +58,9 @@ def test_wikitext2(self):
 
         # Add test for the subset of the standard datasets
         train_iter, valid_iter, test_iter = torchtext.experimental.datasets.raw.WikiText2(split=('train', 'valid', 'test'))
-        self._helper_test_func(len(train_iter), 36718, next(iter(train_iter)), ' \n')
-        self._helper_test_func(len(valid_iter), 3760, next(iter(valid_iter)), ' \n')
-        self._helper_test_func(len(test_iter), 4358, next(iter(test_iter)), ' \n')
+        self._helper_test_func(len(train_iter), 36718, next(train_iter), ' \n')
+        self._helper_test_func(len(valid_iter), 3760, next(valid_iter), ' \n')
+        self._helper_test_func(len(test_iter), 4358, next(test_iter), ' \n')
         del train_iter, valid_iter, test_iter
         train_dataset, test_dataset = WikiText2(split=('train', 'test'))
         train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
@@ -113,8 +113,8 @@ def test_penntreebank(self):
         self._helper_test_func(len(test_data), 82114, test_data[30:35], [397, 93, 4, 16, 7])
 
         train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(split=('train', 'test'))
-        self._helper_test_func(len(train_iter), 42068, next(iter(train_iter))[:15], ' aer banknote b')
-        self._helper_test_func(len(test_iter), 3761, next(iter(test_iter))[:25], " no it was n't black mond")
+        self._helper_test_func(len(train_iter), 42068, next(train_iter)[:15], ' aer banknote b')
+        self._helper_test_func(len(test_iter), 3761, next(test_iter)[:25], " no it was n't black mond")
         del train_iter, test_iter
 
     def test_text_classification(self):
@@ -134,8 +134,8 @@ def test_text_classification(self):
         self._helper_test_func(len(train_dataset), 120000, train_dataset[-1][1][:10], [2155, 223, 2405, 30, 3010, 2204, 54, 3603, 4930, 2405])
 
         train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS()
-        self._helper_test_func(len(train_iter), 120000, next(iter(train_iter))[1][:25], 'Wall St. Bears Claw Back ')
-        self._helper_test_func(len(test_iter), 7600, next(iter(test_iter))[1][:25], 'Fears for T N pension aft')
+        self._helper_test_func(len(train_iter), 120000, next(train_iter)[1][:25], 'Wall St. Bears Claw Back ')
+        self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft')
         del train_iter, test_iter
 
     def test_num_lines_of_dataset(self):
@@ -151,6 +151,19 @@ def test_offset_dataset(self):
                                'Non-OPEC Nations Sho', 'Google IPO Auction O', 'Dollar Falls Broadly'])
 
+    def test_next_method_dataset(self):
+        train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS()
+        for_count = 0
+        next_count = 0
+        for line in train_iter:
+            for_count += 1
+            try:
+                next(train_iter)
+                next_count += 1
+            except StopIteration:
+                break
+        self.assertEqual((for_count, next_count), (60000, 60000))
+
     def test_imdb(self):
         from torchtext.experimental.datasets import IMDB
         from torchtext.vocab import Vocab
@@ -171,8 +184,8 @@ def test_imdb(self):
         self._helper_test_func(len(train_dataset), 25000, train_dataset[0][1][:10], [13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92])
 
         train_iter, test_iter = torchtext.experimental.datasets.raw.IMDB()
-        self._helper_test_func(len(train_iter), 25000, next(iter(train_iter))[1][:25], 'I rented I AM CURIOUS-YEL')
-        self._helper_test_func(len(test_iter), 25000, next(iter(test_iter))[1][:25], 'I love sci-fi and am will')
+        self._helper_test_func(len(train_iter), 25000, next(train_iter)[1][:25], 'I rented I AM CURIOUS-YEL')
+        self._helper_test_func(len(test_iter), 25000, next(test_iter)[1][:25], 'I love sci-fi and am will')
         del train_iter, test_iter
 
     def test_iwslt(self):
@@ -248,10 +261,10 @@ def test_multi30k(self):
 
         # Add test for the subset of the standard datasets
         train_iter, valid_iter = torchtext.experimental.datasets.raw.Multi30k(split=('train', 'valid'))
-        self._helper_test_func(len(train_iter), 29000, ' '.join(next(iter(train_iter))),
+        self._helper_test_func(len(train_iter), 29000, ' '.join(next(train_iter)),
                                ' '.join(['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n',
                                          'Two young, White males are outside near many bushes.\n']))
-        self._helper_test_func(len(valid_iter), 1014, ' '.join(next(iter(valid_iter))),
+        self._helper_test_func(len(valid_iter), 1014, ' '.join(next(valid_iter)),
                                ' '.join(['Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen\n',
                                          'A group of men are loading cotton onto a truck\n']))
         del train_iter, valid_iter
@@ -323,9 +336,9 @@ def test_udpos_sequence_tagging(self):
                                ([262, 16, 5728, 45, 289, 701, 1160, 4436, 10660, 585],
                                 [6, 20, 8, 10, 8, 8, 24, 13, 8, 15]))
 
         train_iter, valid_iter = torchtext.experimental.datasets.raw.UDPOS(split=('train', 'valid'))
-        self._helper_test_func(len(train_iter), 12543, ' '.join(next(iter(train_iter))[0][:5]),
+        self._helper_test_func(len(train_iter), 12543, ' '.join(next(train_iter)[0][:5]),
                                ' '.join(['Al', '-', 'Zaman', ':', 'American']))
-        self._helper_test_func(len(valid_iter), 2002, ' '.join(next(iter(valid_iter))[0][:5]),
+        self._helper_test_func(len(valid_iter), 2002, ' '.join(next(valid_iter)[0][:5]),
                                ' '.join(['From', 'the', 'AP', 'comes', 'this']))
         del train_iter, valid_iter
@@ -376,9 +389,9 @@ def test_conll_sequence_tagging(self):
                                [18, 17, 12, 19, 10, 6, 3, 3, 4, 4],
                                [3, 5, 7, 7, 3, 2, 6, 6, 3, 2]))
 
         train_iter, test_iter = torchtext.experimental.datasets.raw.CoNLL2000Chunking()
-        self._helper_test_func(len(train_iter), 8936, ' '.join(next(iter(train_iter))[0][:5]),
+        self._helper_test_func(len(train_iter), 8936, ' '.join(next(train_iter)[0][:5]),
                                ' '.join(['Confidence', 'in', 'the', 'pound', 'is']))
-        self._helper_test_func(len(test_iter), 2012, ' '.join(next(iter(test_iter))[0][:5]),
+        self._helper_test_func(len(test_iter), 2012, ' '.join(next(test_iter)[0][:5]),
                                ' '.join(['Rockwell', 'International', 'Corp.', "'s", 'Tulsa']))
         del train_iter, test_iter
@@ -405,9 +418,9 @@ def test_squad1(self):
         self._helper_test_func(len(train_dataset), 87599, (question[:5], ans_pos[0]),
                                ([7, 24, 86, 52, 2], [72, 72]))
 
         train_iter, dev_iter = torchtext.experimental.datasets.raw.SQuAD1()
-        self._helper_test_func(len(train_iter), 87599, next(iter(train_iter))[0][:50],
+        self._helper_test_func(len(train_iter), 87599, next(train_iter)[0][:50],
                                'Architecturally, the school has a Catholic charact')
-        self._helper_test_func(len(dev_iter), 10570, next(iter(dev_iter))[0][:50],
+        self._helper_test_func(len(dev_iter), 10570, next(dev_iter)[0][:50],
                                'Super Bowl 50 was an American football game to det')
         del train_iter, dev_iter
@@ -434,8 +447,8 @@ def test_squad2(self):
         self._helper_test_func(len(train_dataset), 130319, (question[:5], ans_pos[0]),
                                ([84, 50, 1421, 12, 5439], [9, 9]))
         train_iter, dev_iter = torchtext.experimental.datasets.raw.SQuAD2()
-        self._helper_test_func(len(train_iter), 130319, next(iter(train_iter))[0][:50],
+        self._helper_test_func(len(train_iter), 130319, next(train_iter)[0][:50],
                                'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-Y')
-        self._helper_test_func(len(dev_iter), 11873, next(iter(dev_iter))[0][:50],
+        self._helper_test_func(len(dev_iter), 11873, next(dev_iter)[0][:50],
                                'The Normans (Norman: Nourmands; French: Normands; ')
         del train_iter, dev_iter
diff --git a/torchtext/experimental/datasets/raw/common.py b/torchtext/experimental/datasets/raw/common.py
index 1267b703ff..e53786ae54 100644
--- a/torchtext/experimental/datasets/raw/common.py
+++ b/torchtext/experimental/datasets/raw/common.py
@@ -32,6 +32,10 @@ def __iter__(self):
                 break
             yield item
 
+    def __next__(self):
+        item = next(self._iterator)
+        return item
+
     def __len__(self):
         return self.num_lines
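
With __next__ defined on the raw dataset wrapper in common.py, next() can be called on the dataset object directly instead of going through next(iter(...)), which is what the updated assertions above exercise. A minimal usage sketch, not part of the patch; it assumes the AG_NEWS data is available locally and relies only on the item indexing shown in the tests:

    from torchtext.experimental.datasets.raw import AG_NEWS

    # Both splits come back as raw iterable datasets.
    train_iter, test_iter = AG_NEWS()

    # Before this change, stepping the dataset by hand required
    # next(iter(train_iter)); with __next__ it can be advanced directly.
    item = next(train_iter)
    print(item[1][:25])   # 'Wall St. Bears Claw Back ', per the test above

Since __iter__ and __next__ both read from the same underlying self._iterator, the interleaved for/next() loop in test_next_method_dataset consumes two items per pass and ends with counts of (60000, 60000) over the 120000-line train split.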