From d26aad7b939ff82644649fadc097b3428a7555e6 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 4 Feb 2021 12:28:26 -0800 Subject: [PATCH 1/5] add __next__ method to RawTextIterableDataset --- test/data/test_builtin_datasets.py | 38 +++++++++---------- torchtext/experimental/datasets/raw/common.py | 4 ++ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index fcef4f1507..cae5b28c18 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -58,9 +58,9 @@ def test_wikitext2(self): # Add test for the subset of the standard datasets train_iter, valid_iter, test_iter = torchtext.experimental.datasets.raw.WikiText2(data_select=('train', 'valid', 'test')) - self._helper_test_func(len(train_iter), 36718, next(iter(train_iter)), ' \n') - self._helper_test_func(len(valid_iter), 3760, next(iter(valid_iter)), ' \n') - self._helper_test_func(len(test_iter), 4358, next(iter(test_iter)), ' \n') + self._helper_test_func(len(train_iter), 36718, next(train_iter), ' \n') + self._helper_test_func(len(valid_iter), 3760, next(valid_iter), ' \n') + self._helper_test_func(len(test_iter), 4358, next(test_iter), ' \n') del train_iter, valid_iter, test_iter train_dataset, test_dataset = WikiText2(data_select=('train', 'test')) train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset))) @@ -113,8 +113,8 @@ def test_penntreebank(self): self._helper_test_func(len(test_data), 82114, test_data[30:35], [397, 93, 4, 16, 7]) train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(data_select=('train', 'test')) - self._helper_test_func(len(train_iter), 42068, next(iter(train_iter))[:15], ' aer banknote b') - self._helper_test_func(len(test_iter), 3761, next(iter(test_iter))[:25], " no it was n't black mond") + self._helper_test_func(len(train_iter), 42068, next(train_iter)[:15], ' aer banknote b') + self._helper_test_func(len(test_iter), 3761, next(test_iter)[:25], " no it was n't black mond") del train_iter, test_iter def test_text_classification(self): @@ -134,8 +134,8 @@ def test_text_classification(self): self._helper_test_func(len(train_dataset), 120000, train_dataset[-1][1][:10], [2155, 223, 2405, 30, 3010, 2204, 54, 3603, 4930, 2405]) train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS() - self._helper_test_func(len(train_iter), 120000, next(iter(train_iter))[1][:25], 'Wall St. Bears Claw Back ') - self._helper_test_func(len(test_iter), 7600, next(iter(test_iter))[1][:25], 'Fears for T N pension aft') + self._helper_test_func(len(train_iter), 120000, next(train_iter)[1][:25], 'Wall St. Bears Claw Back ') + self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft') del train_iter, test_iter def test_imdb(self): @@ -158,8 +158,8 @@ def test_imdb(self): self._helper_test_func(len(train_dataset), 25000, train_dataset[0][1][:10], [13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92]) train_iter, test_iter = torchtext.experimental.datasets.raw.IMDB() - self._helper_test_func(len(train_iter), 25000, next(iter(train_iter))[1][:25], 'I rented I AM CURIOUS-YEL') - self._helper_test_func(len(test_iter), 25000, next(iter(test_iter))[1][:25], 'I love sci-fi and am will') + self._helper_test_func(len(train_iter), 25000, next(train_iter)[1][:25], 'I rented I AM CURIOUS-YEL') + self._helper_test_func(len(test_iter), 25000, next(test_iter)[1][:25], 'I love sci-fi and am will') del train_iter, test_iter def test_iwslt(self): @@ -226,10 +226,10 @@ def test_multi30k(self): # Add test for the subset of the standard datasets train_iter, valid_iter = torchtext.experimental.datasets.raw.Multi30k(data_select=('train', 'valid')) - self._helper_test_func(len(train_iter), 29000, ' '.join(next(iter(train_iter))), + self._helper_test_func(len(train_iter), 29000, ' '.join(next(train_iter)), ' '.join(['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n', 'Two young, White males are outside near many bushes.\n'])) - self._helper_test_func(len(valid_iter), 1014, ' '.join(next(iter(valid_iter))), + self._helper_test_func(len(valid_iter), 1014, ' '.join(next(valid_iter)), ' '.join(['Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen\n', 'A group of men are loading cotton onto a truck\n'])) del train_iter, valid_iter @@ -298,9 +298,9 @@ def test_udpos_sequence_tagging(self): ([262, 16, 5728, 45, 289, 701, 1160, 4436, 10660, 585], [6, 20, 8, 10, 8, 8, 24, 13, 8, 15])) train_iter, valid_iter = torchtext.experimental.datasets.raw.UDPOS(data_select=('train', 'valid')) - self._helper_test_func(len(train_iter), 12543, ' '.join(next(iter(train_iter))[0][:5]), + self._helper_test_func(len(train_iter), 12543, ' '.join(next(train_iter)[0][:5]), ' '.join(['Al', '-', 'Zaman', ':', 'American'])) - self._helper_test_func(len(valid_iter), 2002, ' '.join(next(iter(valid_iter))[0][:5]), + self._helper_test_func(len(valid_iter), 2002, ' '.join(next(valid_iter)[0][:5]), ' '.join(['From', 'the', 'AP', 'comes', 'this'])) del train_iter, valid_iter @@ -351,9 +351,9 @@ def test_conll_sequence_tagging(self): [18, 17, 12, 19, 10, 6, 3, 3, 4, 4], [3, 5, 7, 7, 3, 2, 6, 6, 3, 2])) train_iter, test_iter = torchtext.experimental.datasets.raw.CoNLL2000Chunking() - self._helper_test_func(len(train_iter), 8936, ' '.join(next(iter(train_iter))[0][:5]), + self._helper_test_func(len(train_iter), 8936, ' '.join(next(train_iter)[0][:5]), ' '.join(['Confidence', 'in', 'the', 'pound', 'is'])) - self._helper_test_func(len(test_iter), 2012, ' '.join(next(iter(test_iter))[0][:5]), + self._helper_test_func(len(test_iter), 2012, ' '.join(next(test_iter)[0][:5]), ' '.join(['Rockwell', 'International', 'Corp.', "'s", 'Tulsa'])) del train_iter, test_iter @@ -380,9 +380,9 @@ def test_squad1(self): self._helper_test_func(len(train_dataset), 87599, (question[:5], ans_pos[0]), ([7, 24, 86, 52, 2], [72, 72])) train_iter, dev_iter = torchtext.experimental.datasets.raw.SQuAD1() - self._helper_test_func(len(train_iter), 87599, next(iter(train_iter))[0][:50], + self._helper_test_func(len(train_iter), 87599, next(train_iter)[0][:50], 'Architecturally, the school has a Catholic charact') - self._helper_test_func(len(dev_iter), 10570, next(iter(dev_iter))[0][:50], + self._helper_test_func(len(dev_iter), 10570, next(dev_iter)[0][:50], 'Super Bowl 50 was an American football game to det') del train_iter, dev_iter @@ -409,8 +409,8 @@ def test_squad2(self): self._helper_test_func(len(train_dataset), 130319, (question[:5], ans_pos[0]), ([84, 50, 1421, 12, 5439], [9, 9])) train_iter, dev_iter = torchtext.experimental.datasets.raw.SQuAD2() - self._helper_test_func(len(train_iter), 130319, next(iter(train_iter))[0][:50], + self._helper_test_func(len(train_iter), 130319, next(train_iter)[0][:50], 'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-Y') - self._helper_test_func(len(dev_iter), 11873, next(iter(dev_iter))[0][:50], + self._helper_test_func(len(dev_iter), 11873, next(dev_iter)[0][:50], 'The Normans (Norman: Nourmands; French: Normands; ') del train_iter, dev_iter diff --git a/torchtext/experimental/datasets/raw/common.py b/torchtext/experimental/datasets/raw/common.py index 9be6505168..96d293b98d 100644 --- a/torchtext/experimental/datasets/raw/common.py +++ b/torchtext/experimental/datasets/raw/common.py @@ -44,6 +44,10 @@ def __iter__(self): break yield item + def __next__(self): + item = self._iterator.__next__() + return item + def __len__(self): if self.has_setup: return self.num_lines From 7b887f3aec807533829f138b5052a9fdde430545 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 5 Feb 2021 07:29:31 -0800 Subject: [PATCH 2/5] add a CI test --- test/data/test_builtin_datasets.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 88df04479b..9e4acf5cfc 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -144,6 +144,15 @@ def test_num_lines_of_setup_iter_dataset(self): _data = [item for item in train_iter] self.assertEqual(len(_data), 100) + def test_next_method_dataset(self): + train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS() + dataset_len = len(train_iter) + container = [] + for line in train_iter: + container.append(line) + container.append(next(train_iter)) + self.assertEqual(len(container), dataset_len) + def test_imdb(self): from torchtext.experimental.datasets import IMDB from torchtext.vocab import Vocab From d15ad30a8857f18c55f9ab7160162feacb9646ba Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 8 Feb 2021 16:16:46 -0800 Subject: [PATCH 3/5] refactor the test --- test/data/test_builtin_datasets.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 37d21fdbeb..524974d3f2 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -146,12 +146,16 @@ def test_num_lines_of_setup_iter_dataset(self): def test_next_method_dataset(self): train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS() - dataset_len = len(train_iter) - container = [] + for_count = 0 + next_count = 0 for line in train_iter: - container.append(line) - container.append(next(train_iter)) - self.assertEqual(len(container), dataset_len) + for_count += 1 + try: + next(train_iter) + next_count += 1 + except: + print(for_count, next_count) + self.assertEqual((for_count, next_count), (60000, 60000)) def test_imdb(self): from torchtext.experimental.datasets import IMDB From ed9aa015ad89f6a3e6f25d119f722586b22f8482 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 8 Feb 2021 16:18:50 -0800 Subject: [PATCH 4/5] checkpoint --- test/data/test_builtin_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 524974d3f2..5d6131c1a1 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -154,7 +154,7 @@ def test_next_method_dataset(self): next(train_iter) next_count += 1 except: - print(for_count, next_count) + break self.assertEqual((for_count, next_count), (60000, 60000)) def test_imdb(self): From 54e6f97e936eeb89451e81c74ea1a47aa9bf2b24 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 9 Feb 2021 10:40:51 -0800 Subject: [PATCH 5/5] switch to next() method --- torchtext/experimental/datasets/raw/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/experimental/datasets/raw/common.py b/torchtext/experimental/datasets/raw/common.py index b62708ee3e..e53786ae54 100644 --- a/torchtext/experimental/datasets/raw/common.py +++ b/torchtext/experimental/datasets/raw/common.py @@ -33,7 +33,7 @@ def __iter__(self): yield item def __next__(self): - item = self._iterator.__next__() + item = next(self._iterator) return item def __len__(self):