From 6a3a8b4ea6575633c3a2d87bda4e350e08b1fbcd Mon Sep 17 00:00:00 2001
From: Guanheng Zhang
Date: Thu, 4 Feb 2021 12:41:20 -0800
Subject: [PATCH] fix a bug in setup_iter func in RawTextIterableDataset

---
 test/data/test_builtin_datasets.py            | 6 ++++++
 torchtext/experimental/datasets/raw/common.py | 2 +-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py
index fcef4f1507..ca53feff6a 100644
--- a/test/data/test_builtin_datasets.py
+++ b/test/data/test_builtin_datasets.py
@@ -138,6 +138,12 @@ def test_text_classification(self):
         self._helper_test_func(len(test_iter), 7600, next(iter(test_iter))[1][:25], 'Fears for T N pension aft')
         del train_iter, test_iter
 
+    def test_num_lines_of_setup_iter_dataset(self):
+        train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS()
+        train_iter.setup_iter(start=10, num_lines=100)
+        _data = [item for item in train_iter]
+        self.assertEqual(len(_data), 100)
+
     def test_imdb(self):
         from torchtext.experimental.datasets import IMDB
         from torchtext.vocab import Vocab
diff --git a/torchtext/experimental/datasets/raw/common.py b/torchtext/experimental/datasets/raw/common.py
index 9be6505168..06415830c3 100644
--- a/torchtext/experimental/datasets/raw/common.py
+++ b/torchtext/experimental/datasets/raw/common.py
@@ -40,7 +40,7 @@ def __iter__(self):
         for i, item in enumerate(self._iterator):
             if i < self.start:
                 continue
-            if self.num_lines and i > (self.start + self.num_lines):
+            if self.num_lines and i >= (self.start + self.num_lines):
                 break
             yield item
 