From f5dc22f23c21853c0e17e6cf7d47d65ef96f0718 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sat, 6 Feb 2021 12:30:33 -0800 Subject: [PATCH 1/7] checkpoint --- test/data/test_builtin_datasets.py | 13 +++++ torchtext/experimental/datasets/raw/common.py | 8 ++- .../datasets/raw/language_modeling.py | 24 ++++---- .../datasets/raw/question_answer.py | 14 +++-- .../datasets/raw/sequence_tagging.py | 14 +++-- .../datasets/raw/text_classification.py | 58 ++++++++++++------- .../experimental/datasets/raw/translation.py | 19 +++--- 7 files changed, 98 insertions(+), 52 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 8ed1a0c04e..b431ea60db 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -144,6 +144,19 @@ def test_num_lines_of_setup_iter_dataset(self): _data = [item for item in train_iter] self.assertEqual(len(_data), 100) + def test_offset_stride_dataset(self): + train_iter, test_iter = AG_NEWS(data_select=('train', 'test'), offset=10, stride=1) + container = [text[:20] for idx, (label, text) in enumerate(train_iter) if idx < 5] + self.assertEqual(container, ['Oil and Economy Clou', 'No Need for OPEC to ', + 'Non-OPEC Nations Sho', 'Google IPO Auction O', + 'Dollar Falls Broadly']) + + train_iter, test_iter = AG_NEWS(data_select=('train', 'test'), offset=100, stride=5) + container = [text[:20] for idx, (label, text) in enumerate(test_iter) if idx < 5] + self.assertEqual(container, ['Olympic history for ', 'Edwards Banned from ', + 'Yahoo! Ups Ante for ', 'Buckeyes have lots t', + 'Oil prices bubble to']) + def test_imdb(self): from torchtext.experimental.datasets import IMDB from torchtext.vocab import Vocab diff --git a/torchtext/experimental/datasets/raw/common.py b/torchtext/experimental/datasets/raw/common.py index 06415830c3..89bf3961db 100644 --- a/torchtext/experimental/datasets/raw/common.py +++ b/torchtext/experimental/datasets/raw/common.py @@ -14,7 +14,7 @@ class RawTextIterableDataset(torch.utils.data.IterableDataset): """Defines an abstraction for raw text iterable datasets. """ - def __init__(self, name, full_num_lines, iterator): + def __init__(self, name, full_num_lines, iterator, offset=0, stride=1): """Initiate text-classification dataset. 
""" super(RawTextIterableDataset, self).__init__() @@ -22,8 +22,10 @@ def __init__(self, name, full_num_lines, iterator): self.full_num_lines = full_num_lines self._iterator = iterator self.has_setup = False - self.start = 0 + self.start = offset self.num_lines = None + self.setup_iter(offset) + self.stride = stride def setup_iter(self, start=0, num_lines=None): self.start = start @@ -40,6 +42,8 @@ def __iter__(self): for i, item in enumerate(self._iterator): if i < self.start: continue + if (i - self.start) % self.stride != 0: + continue if self.num_lines and i >= (self.start + self.num_lines): break yield item diff --git a/torchtext/experimental/datasets/raw/language_modeling.py b/torchtext/experimental/datasets/raw/language_modeling.py index b14c127316..14b61329b3 100644 --- a/torchtext/experimental/datasets/raw/language_modeling.py +++ b/torchtext/experimental/datasets/raw/language_modeling.py @@ -17,7 +17,7 @@ } -def _setup_datasets(dataset_name, root, data_select, year, language): +def _setup_datasets(dataset_name, root, data_select, year, language, offset): data_select = check_default_set(data_select, ('train', 'test', 'valid')) if isinstance(data_select, str): data_select = [data_select] @@ -55,10 +55,10 @@ def _setup_datasets(dataset_name, root, data_select, year, language): data[item] = iter(io.open(_path[item], encoding="utf8")) return tuple(RawTextIterableDataset(dataset_name, - NUM_LINES[dataset_name][item], data[item]) for item in data_select) + NUM_LINES[dataset_name][item], data[item], offset=offset) for item in data_select) -def WikiText2(root='.data', data_select=('train', 'valid', 'test')): +def WikiText2(root='.data', data_select=('train', 'valid', 'test'), offset=0): """ Defines WikiText2 datasets. Create language modeling dataset: WikiText2 @@ -72,6 +72,7 @@ def WikiText2(root='.data', data_select=('train', 'valid', 'test')): just a string 'train'. If 'train' is not in the tuple or string, a vocab object should be provided which will be used to process valid and/or test data. + offset: the number of the starting line. Default: 0 Examples: >>> from torchtext.experimental.raw.datasets import WikiText2 @@ -80,10 +81,10 @@ def WikiText2(root='.data', data_select=('train', 'valid', 'test')): """ - return _setup_datasets("WikiText2", root, data_select, None, None) + return _setup_datasets("WikiText2", root, data_select, None, None, offset) -def WikiText103(root='.data', data_select=('train', 'valid', 'test')): +def WikiText103(root='.data', data_select=('train', 'valid', 'test'), offset=0): """ Defines WikiText103 datasets. Create language modeling dataset: WikiText103 @@ -96,6 +97,7 @@ def WikiText103(root='.data', data_select=('train', 'valid', 'test')): could also choose any one or two of them, for example ('train', 'test'). If 'train' is not in the tuple, an vocab object should be provided which will be used to process valid and/or test data. + offset: the number of the starting line. Default: 0 Examples: >>> from torchtext.experimental.datasets.raw import WikiText103 @@ -103,10 +105,10 @@ def WikiText103(root='.data', data_select=('train', 'valid', 'test')): >>> valid_dataset, = WikiText103(data_select='valid') """ - return _setup_datasets("WikiText103", root, data_select, None, None) + return _setup_datasets("WikiText103", root, data_select, None, None, offset) -def PennTreebank(root='.data', data_select=('train', 'valid', 'test')): +def PennTreebank(root='.data', data_select=('train', 'valid', 'test'), offset=0): """ Defines PennTreebank datasets. 
Create language modeling dataset: PennTreebank @@ -121,6 +123,7 @@ def PennTreebank(root='.data', data_select=('train', 'valid', 'test')): just a string 'train'. If 'train' is not in the tuple or string, a vocab object should be provided which will be used to process valid and/or test data. + offset: the number of the starting line. Default: 0 Examples: >>> from torchtext.experimental.datasets.raw import PennTreebank @@ -129,10 +132,10 @@ def PennTreebank(root='.data', data_select=('train', 'valid', 'test')): """ - return _setup_datasets("PennTreebank", root, data_select, None, None) + return _setup_datasets("PennTreebank", root, data_select, None, None, offset) -def WMTNewsCrawl(root='.data', data_select=('train'), year=2010, language='en'): +def WMTNewsCrawl(root='.data', data_select=('train'), year=2010, language='en', offset=0): """ Defines WMT News Crawl. Create language modeling dataset: WMTNewsCrawl @@ -143,11 +146,12 @@ def WMTNewsCrawl(root='.data', data_select=('train'), year=2010, language='en'): (Default: 'train') year: the year of the dataset (Default: 2010) language: the language of the dataset (Default: 'en') + offset: the number of the starting line. Default: 0 Note: WMTNewsCrawl provides datasets based on the year and language instead of train/valid/test. """ - return _setup_datasets("WMTNewsCrawl", root, data_select, year, language) + return _setup_datasets("WMTNewsCrawl", root, data_select, year, language, offset) DATASETS = { diff --git a/torchtext/experimental/datasets/raw/question_answer.py b/torchtext/experimental/datasets/raw/question_answer.py index d21dbdbc55..ba7aeebb44 100644 --- a/torchtext/experimental/datasets/raw/question_answer.py +++ b/torchtext/experimental/datasets/raw/question_answer.py @@ -29,15 +29,15 @@ def _create_data_from_json(data_path): yield (_context, _question, _answers, _answer_start) -def _setup_datasets(dataset_name, root, data_select): +def _setup_datasets(dataset_name, root, data_select, offset): data_select = check_default_set(data_select, ('train', 'dev')) extracted_files = {key: download_from_url(URLS[dataset_name][key], root=root, hash_value=MD5[dataset_name][key], hash_type='md5') for key in data_select} return tuple(RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item], - _create_data_from_json(extracted_files[item])) for item in data_select) + _create_data_from_json(extracted_files[item]), offset=offset) for item in data_select) -def SQuAD1(root='.data', data_select=('train', 'dev')): +def SQuAD1(root='.data', data_select=('train', 'dev'), offset=0): """ A dataset iterator yields the data of Stanford Question Answering dataset - SQuAD1.0. The iterator yields a tuple of (raw context, raw question, a list of raw answer, a list of answer positions in the raw context). @@ -51,6 +51,7 @@ def SQuAD1(root='.data', data_select=('train', 'dev')): data_select: a string or tuple for the returned datasets (Default: ('train', 'dev')) By default, both datasets (train, dev) are generated. Users could also choose any one or two of them, for example ('train', 'dev') or just a string 'train'. + offset: the number of the starting line. 
Default: 0 Examples: >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD1() @@ -58,10 +59,10 @@ def SQuAD1(root='.data', data_select=('train', 'dev')): >>> print(idx, (context, question, answer, ans_pos)) """ - return _setup_datasets("SQuAD1", root, data_select) + return _setup_datasets("SQuAD1", root, data_select, offset) -def SQuAD2(root='.data', data_select=('train', 'dev')): +def SQuAD2(root='.data', data_select=('train', 'dev'), offset=0): """ A dataset iterator yields the data of Stanford Question Answering dataset - SQuAD2.0. The iterator yields a tuple of (raw context, raw question, a list of raw answer, a list of answer positions in the raw context). @@ -75,6 +76,7 @@ def SQuAD2(root='.data', data_select=('train', 'dev')): data_select: a string or tuple for the returned datasets (Default: ('train', 'dev')) By default, both datasets (train, dev) are generated. Users could also choose any one or two of them, for example ('train', 'dev') or just a string 'train'. + offset: the number of the starting line. Default: 0 Examples: >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD2() @@ -82,7 +84,7 @@ def SQuAD2(root='.data', data_select=('train', 'dev')): >>> print(idx, (context, question, answer, ans_pos)) """ - return _setup_datasets("SQuAD2", root, data_select) + return _setup_datasets("SQuAD2", root, data_select, offset) DATASETS = { diff --git a/torchtext/experimental/datasets/raw/sequence_tagging.py b/torchtext/experimental/datasets/raw/sequence_tagging.py index c1a67261f1..9d10c1f540 100644 --- a/torchtext/experimental/datasets/raw/sequence_tagging.py +++ b/torchtext/experimental/datasets/raw/sequence_tagging.py @@ -39,7 +39,7 @@ def _construct_filepath(paths, file_suffix): return None -def _setup_datasets(dataset_name, separator, root, data_select): +def _setup_datasets(dataset_name, separator, root, data_select, offset): data_select = check_default_set(data_select, target_select=('train', 'valid', 'test')) extracted_files = [] if isinstance(URLS[dataset_name], dict): @@ -60,11 +60,11 @@ def _setup_datasets(dataset_name, separator, root, data_select): "test": _construct_filepath(extracted_files, "test.txt") } return tuple(RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item], - _create_data_from_iob(data_filenames[item], separator)) + _create_data_from_iob(data_filenames[item], separator), offset=offset) if data_filenames[item] is not None else None for item in data_select) -def UDPOS(root=".data", data_select=('train', 'valid', 'test')): +def UDPOS(root=".data", data_select=('train', 'valid', 'test'), offset=0): """ Universal Dependencies English Web Treebank Separately returns the training and test dataset @@ -75,15 +75,16 @@ def UDPOS(root=".data", data_select=('train', 'valid', 'test')): By default, all the datasets (train, valid, test) are generated. Users could also choose any one or two of them, for example ('train', 'valid', 'test') or just a string 'train'. + offset: the number of the starting line. 
Default: 0 Examples: >>> from torchtext.experimental.datasets.raw import UDPOS >>> train_dataset, valid_dataset, test_dataset = UDPOS() """ - return _setup_datasets("UDPOS", "\t", root, data_select) + return _setup_datasets("UDPOS", "\t", root, data_select, offset) -def CoNLL2000Chunking(root=".data", data_select=('train', 'test')): +def CoNLL2000Chunking(root=".data", data_select=('train', 'test'), offset=0): """ CoNLL 2000 Chunking Dataset Separately returns the training and test dataset @@ -93,12 +94,13 @@ def CoNLL2000Chunking(root=".data", data_select=('train', 'test')): data_select: a string or tuple for the returned datasets (Default: ('train', 'test')) By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 Examples: >>> from torchtext.experimental.datasets.raw import CoNLL2000Chunking >>> train_dataset, test_dataset = CoNLL2000Chunking() """ - return _setup_datasets("CoNLL2000Chunking", " ", root, data_select) + return _setup_datasets("CoNLL2000Chunking", " ", root, data_select, offset) DATASETS = { diff --git a/torchtext/experimental/datasets/raw/text_classification.py b/torchtext/experimental/datasets/raw/text_classification.py index 694e4d5d29..f03370c47c 100644 --- a/torchtext/experimental/datasets/raw/text_classification.py +++ b/torchtext/experimental/datasets/raw/text_classification.py @@ -33,7 +33,7 @@ def _create_data_from_csv(data_path): yield int(row[0]), ' '.join(row[1:]) -def _setup_datasets(dataset_name, root, data_select): +def _setup_datasets(dataset_name, root, data_select, offset, stride): data_select = check_default_set(data_select, target_select=('train', 'test')) if dataset_name == 'AG_NEWS': extracted_files = [download_from_url(URLS[dataset_name][item], root=root, @@ -51,10 +51,10 @@ def _setup_datasets(dataset_name, root, data_select): if fname.endswith('test.csv'): cvs_path['test'] = fname return tuple(RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item], - _create_data_from_csv(cvs_path[item])) for item in data_select) + _create_data_from_csv(cvs_path[item]), offset, stride) for item in data_select) -def AG_NEWS(root='.data', data_select=('train', 'test')): +def AG_NEWS(root='.data', data_select=('train', 'test'), offset=0, stride=1): """ Defines AG_NEWS datasets. Create supervised learning dataset: AG_NEWS @@ -66,15 +66,17 @@ def AG_NEWS(root='.data', data_select=('train', 'test')): data_select: a string or tuple for the returned datasets. Default: ('train', 'test') By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> train, test = torchtext.experimental.datasets.raw.AG_NEWS() """ - return _setup_datasets("AG_NEWS", root, data_select) + return _setup_datasets("AG_NEWS", root, data_select, offset, stride) -def SogouNews(root='.data', data_select=('train', 'test')): +def SogouNews(root='.data', data_select=('train', 'test'), offset=0, stride=1): """ Defines SogouNews datasets. Create supervised learning dataset: SogouNews @@ -86,15 +88,17 @@ def SogouNews(root='.data', data_select=('train', 'test')): data_select: a string or tuple for the returned datasets. Default: ('train', 'test') By default, both datasets (train, test) are generated. 
Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> train, test = torchtext.experimental.datasets.raw.SogouNews() """ - return _setup_datasets("SogouNews", root, data_select) + return _setup_datasets("SogouNews", root, data_select, offset, stride) -def DBpedia(root='.data', data_select=('train', 'test')): +def DBpedia(root='.data', data_select=('train', 'test'), offset=0, stride=1): """ Defines DBpedia datasets. Create supervised learning dataset: DBpedia @@ -106,15 +110,17 @@ def DBpedia(root='.data', data_select=('train', 'test')): data_select: a string or tuple for the returned datasets. Default: ('train', 'test') By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> train, test = torchtext.experimental.datasets.raw.DBpedia() """ - return _setup_datasets("DBpedia", root, data_select) + return _setup_datasets("DBpedia", root, data_select, offset, stride) -def YelpReviewPolarity(root='.data', data_select=('train', 'test')): +def YelpReviewPolarity(root='.data', data_select=('train', 'test'), offset=0, stride=1): """ Defines YelpReviewPolarity datasets. Create supervised learning dataset: YelpReviewPolarity @@ -126,15 +132,17 @@ def YelpReviewPolarity(root='.data', data_select=('train', 'test')): data_select: a string or tuple for the returned datasets. Default: ('train', 'test') By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> train, test = torchtext.experimental.datasets.raw.YelpReviewPolarity() """ - return _setup_datasets("YelpReviewPolarity", root, data_select) + return _setup_datasets("YelpReviewPolarity", root, data_select, offset, stride) -def YelpReviewFull(root='.data', data_select=('train', 'test')): +def YelpReviewFull(root='.data', data_select=('train', 'test'), offset=0, stride=1): """ Defines YelpReviewFull datasets. Create supervised learning dataset: YelpReviewFull @@ -146,15 +154,17 @@ def YelpReviewFull(root='.data', data_select=('train', 'test')): data_select: a string or tuple for the returned datasets. Default: ('train', 'test') By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> train, test = torchtext.experimental.datasets.raw.YelpReviewFull() """ - return _setup_datasets("YelpReviewFull", root, data_select) + return _setup_datasets("YelpReviewFull", root, data_select, offset, stride) -def YahooAnswers(root='.data', data_select=('train', 'test')): +def YahooAnswers(root='.data', data_select=('train', 'test'), offset=0, stride=1): """ Defines YahooAnswers datasets. Create supervised learning dataset: YahooAnswers @@ -166,15 +176,17 @@ def YahooAnswers(root='.data', data_select=('train', 'test')): data_select: a string or tuple for the returned datasets. 
Default: ('train', 'test') By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> train, test = torchtext.experimental.datasets.raw.YahooAnswers() """ - return _setup_datasets("YahooAnswers", root, data_select) + return _setup_datasets("YahooAnswers", root, data_select, offset, stride) -def AmazonReviewPolarity(root='.data', data_select=('train', 'test')): +def AmazonReviewPolarity(root='.data', data_select=('train', 'test'), offset=0, stride=1): """ Defines AmazonReviewPolarity datasets. Create supervised learning dataset: AmazonReviewPolarity @@ -186,15 +198,17 @@ def AmazonReviewPolarity(root='.data', data_select=('train', 'test')): data_select: a string or tuple for the returned datasets. Default: ('train', 'test') By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> train, test = torchtext.experimental.datasets.raw.AmazonReviewPolarity() """ - return _setup_datasets("AmazonReviewPolarity", root, data_select) + return _setup_datasets("AmazonReviewPolarity", root, data_select, offset, stride) -def AmazonReviewFull(root='.data', data_select=('train', 'test')): +def AmazonReviewFull(root='.data', data_select=('train', 'test'), offset=0, stride=1): """ Defines AmazonReviewFull datasets. Create supervised learning dataset: AmazonReviewFull @@ -206,12 +220,14 @@ def AmazonReviewFull(root='.data', data_select=('train', 'test')): data_select: a string or tuple for the returned datasets. Default: ('train', 'test') By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> train, test = torchtext.experimental.datasets.raw.AmazonReviewFull() """ - return _setup_datasets("AmazonReviewFull", root, data_select) + return _setup_datasets("AmazonReviewFull", root, data_select, offset, stride) def generate_imdb_data(key, extracted_files): @@ -224,7 +240,7 @@ def generate_imdb_data(key, extracted_files): yield label, f.read() -def IMDB(root='.data', data_select=('train', 'test')): +def IMDB(root='.data', data_select=('train', 'test'), offset=0, stride=1): """ Defines raw IMDB datasets. Create supervised learning dataset: IMDB @@ -236,6 +252,8 @@ def IMDB(root='.data', data_select=('train', 'test')): data_select: a string or tuple for the returned datasets. Default: ('train', 'test') By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. + offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. 
Default: 1 Examples: >>> train, test = torchtext.experimental.datasets.raw.IMDB() @@ -246,7 +264,7 @@ def IMDB(root='.data', data_select=('train', 'test')): extracted_files = extract_archive(dataset_tar) return tuple(RawTextIterableDataset("IMDB", NUM_LINES["IMDB"][item], generate_imdb_data(item, - extracted_files)) for item in data_select) + extracted_files), offset, stride) for item in data_select) DATASETS = { diff --git a/torchtext/experimental/datasets/raw/translation.py b/torchtext/experimental/datasets/raw/translation.py index 25960bead8..0c2d603e37 100644 --- a/torchtext/experimental/datasets/raw/translation.py +++ b/torchtext/experimental/datasets/raw/translation.py @@ -116,7 +116,7 @@ def _construct_filepaths(paths, src_filename, tgt_filename): def _setup_datasets(dataset_name, train_filenames, valid_filenames, test_filenames, - data_select, root): + data_select, root, offset): data_select = check_default_set(data_select, ('train', 'valid', 'test')) if not isinstance(train_filenames, tuple) and not isinstance(valid_filenames, tuple) \ and not isinstance(test_filenames, tuple): @@ -184,7 +184,7 @@ def _iter(src_data_iter, tgt_data_iter): yield item datasets.append( - RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][key], _iter(src_data_iter, tgt_data_iter))) + RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][key], _iter(src_data_iter, tgt_data_iter), offset)) return tuple(datasets) @@ -192,7 +192,7 @@ def _iter(src_data_iter, tgt_data_iter): def Multi30k(train_filenames=("train.de", "train.en"), valid_filenames=("val.de", "val.en"), test_filenames=("test_2016_flickr.de", "test_2016_flickr.en"), - data_select=('train', 'valid', 'test'), root='.data'): + data_select=('train', 'valid', 'test'), root='.data', offset=0): """ Define translation datasets: Multi30k Separately returns train/valid/test datasets as a tuple The available dataset include: @@ -259,12 +259,13 @@ def Multi30k(train_filenames=("train.de", "train.en"), just a string 'train'. If 'train' is not in the tuple or string, a vocab object should be provided which will be used to process valid and/or test data. root: Directory where the datasets are saved. Default: ".data" + offset: the number of the starting line. Default: 0 Examples: >>> from torchtext.experimental.datasets.raw import Multi30k >>> train_dataset, valid_dataset, test_dataset = Multi30k() """ - return _setup_datasets("Multi30k", train_filenames, valid_filenames, test_filenames, data_select, root) + return _setup_datasets("Multi30k", train_filenames, valid_filenames, test_filenames, data_select, root, offset) def IWSLT(train_filenames=('train.de-en.de', 'train.de-en.en'), @@ -272,7 +273,7 @@ def IWSLT(train_filenames=('train.de-en.de', 'train.de-en.en'), 'IWSLT16.TED.tst2013.de-en.en'), test_filenames=('IWSLT16.TED.tst2014.de-en.de', 'IWSLT16.TED.tst2014.de-en.en'), - data_select=('train', 'valid', 'test'), root='.data'): + data_select=('train', 'valid', 'test'), root='.data', offset=0): """ Define translation datasets: IWSLT Separately returns train/valid/test datasets The available datasets include: @@ -425,12 +426,13 @@ def IWSLT(train_filenames=('train.de-en.de', 'train.de-en.en'), just a string 'train'. If 'train' is not in the tuple or string, a vocab object should be provided which will be used to process valid and/or test data. root: Directory where the datasets are saved. Default: ".data" + offset: the number of the starting line. 
Default: 0 Examples: >>> from torchtext.experimental.datasets.raw import IWSLT >>> train_dataset, valid_dataset, test_dataset = IWSLT() """ - return _setup_datasets("IWSLT", train_filenames, valid_filenames, test_filenames, data_select, root) + return _setup_datasets("IWSLT", train_filenames, valid_filenames, test_filenames, data_select, root, offset) def WMT14(train_filenames=('train.tok.clean.bpe.32000.de', @@ -439,7 +441,7 @@ def WMT14(train_filenames=('train.tok.clean.bpe.32000.de', 'newstest2013.tok.bpe.32000.en'), test_filenames=('newstest2014.tok.bpe.32000.de', 'newstest2014.tok.bpe.32000.en'), - data_select=('train', 'valid', 'test'), root='.data'): + data_select=('train', 'valid', 'test'), root='.data', offset=0): """ Define translation datasets: WMT14 Separately returns train/valid/test datasets The available datasets include: @@ -507,12 +509,13 @@ def WMT14(train_filenames=('train.tok.clean.bpe.32000.de', just a string 'train'. If 'train' is not in the tuple or string, a vocab object should be provided which will be used to process valid and/or test data. root: Directory where the datasets are saved. Default: ".data" + offset: the number of the starting line. Default: 0 Examples: >>> from torchtext.experimental.datasets.raw import WMT14 >>> train_dataset, valid_dataset, test_dataset = WMT14() """ - return _setup_datasets("WMT14", train_filenames, valid_filenames, test_filenames, data_select, root) + return _setup_datasets("WMT14", train_filenames, valid_filenames, test_filenames, data_select, root, offset) DATASETS = { From 12ea82eaf61576230ed6c1f67bb34d140d25238c Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sat, 6 Feb 2021 12:41:52 -0800 Subject: [PATCH 2/7] checkpoint --- .../datasets/raw/language_modeling.py | 24 +++++++++++-------- .../datasets/raw/question_answer.py | 14 ++++++----- .../datasets/raw/sequence_tagging.py | 14 ++++++----- .../experimental/datasets/raw/translation.py | 19 ++++++++------- 4 files changed, 41 insertions(+), 30 deletions(-) diff --git a/torchtext/experimental/datasets/raw/language_modeling.py b/torchtext/experimental/datasets/raw/language_modeling.py index 14b61329b3..0e23aa0adb 100644 --- a/torchtext/experimental/datasets/raw/language_modeling.py +++ b/torchtext/experimental/datasets/raw/language_modeling.py @@ -17,7 +17,7 @@ } -def _setup_datasets(dataset_name, root, data_select, year, language, offset): +def _setup_datasets(dataset_name, root, data_select, year, language, offset, stride): data_select = check_default_set(data_select, ('train', 'test', 'valid')) if isinstance(data_select, str): data_select = [data_select] @@ -55,10 +55,10 @@ def _setup_datasets(dataset_name, root, data_select, year, language, offset): data[item] = iter(io.open(_path[item], encoding="utf8")) return tuple(RawTextIterableDataset(dataset_name, - NUM_LINES[dataset_name][item], data[item], offset=offset) for item in data_select) + NUM_LINES[dataset_name][item], data[item], offset=offset, stride=stride) for item in data_select) -def WikiText2(root='.data', data_select=('train', 'valid', 'test'), offset=0): +def WikiText2(root='.data', data_select=('train', 'valid', 'test'), offset=0, stride=1): """ Defines WikiText2 datasets. Create language modeling dataset: WikiText2 @@ -73,6 +73,7 @@ def WikiText2(root='.data', data_select=('train', 'valid', 'test'), offset=0): object should be provided which will be used to process valid and/or test data. offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. 
Default: 1 Examples: >>> from torchtext.experimental.raw.datasets import WikiText2 @@ -81,10 +82,10 @@ def WikiText2(root='.data', data_select=('train', 'valid', 'test'), offset=0): """ - return _setup_datasets("WikiText2", root, data_select, None, None, offset) + return _setup_datasets("WikiText2", root, data_select, None, None, offset, stride) -def WikiText103(root='.data', data_select=('train', 'valid', 'test'), offset=0): +def WikiText103(root='.data', data_select=('train', 'valid', 'test'), offset=0, stride=1): """ Defines WikiText103 datasets. Create language modeling dataset: WikiText103 @@ -98,6 +99,7 @@ def WikiText103(root='.data', data_select=('train', 'valid', 'test'), offset=0): If 'train' is not in the tuple, an vocab object should be provided which will be used to process valid and/or test data. offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> from torchtext.experimental.datasets.raw import WikiText103 @@ -105,10 +107,10 @@ def WikiText103(root='.data', data_select=('train', 'valid', 'test'), offset=0): >>> valid_dataset, = WikiText103(data_select='valid') """ - return _setup_datasets("WikiText103", root, data_select, None, None, offset) + return _setup_datasets("WikiText103", root, data_select, None, None, offset, stride) -def PennTreebank(root='.data', data_select=('train', 'valid', 'test'), offset=0): +def PennTreebank(root='.data', data_select=('train', 'valid', 'test'), offset=0, stride=1): """ Defines PennTreebank datasets. Create language modeling dataset: PennTreebank @@ -124,6 +126,7 @@ def PennTreebank(root='.data', data_select=('train', 'valid', 'test'), offset=0) object should be provided which will be used to process valid and/or test data. offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> from torchtext.experimental.datasets.raw import PennTreebank @@ -132,10 +135,10 @@ def PennTreebank(root='.data', data_select=('train', 'valid', 'test'), offset=0) """ - return _setup_datasets("PennTreebank", root, data_select, None, None, offset) + return _setup_datasets("PennTreebank", root, data_select, None, None, offset, stride) -def WMTNewsCrawl(root='.data', data_select=('train'), year=2010, language='en', offset=0): +def WMTNewsCrawl(root='.data', data_select=('train'), year=2010, language='en', offset=0, stride=1): """ Defines WMT News Crawl. Create language modeling dataset: WMTNewsCrawl @@ -147,11 +150,12 @@ def WMTNewsCrawl(root='.data', data_select=('train'), year=2010, language='en', year: the year of the dataset (Default: 2010) language: the language of the dataset (Default: 'en') offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Note: WMTNewsCrawl provides datasets based on the year and language instead of train/valid/test. 
""" - return _setup_datasets("WMTNewsCrawl", root, data_select, year, language, offset) + return _setup_datasets("WMTNewsCrawl", root, data_select, year, language, offset, stride) DATASETS = { diff --git a/torchtext/experimental/datasets/raw/question_answer.py b/torchtext/experimental/datasets/raw/question_answer.py index ba7aeebb44..06a1c3afc7 100644 --- a/torchtext/experimental/datasets/raw/question_answer.py +++ b/torchtext/experimental/datasets/raw/question_answer.py @@ -29,15 +29,15 @@ def _create_data_from_json(data_path): yield (_context, _question, _answers, _answer_start) -def _setup_datasets(dataset_name, root, data_select, offset): +def _setup_datasets(dataset_name, root, data_select, offset, stride): data_select = check_default_set(data_select, ('train', 'dev')) extracted_files = {key: download_from_url(URLS[dataset_name][key], root=root, hash_value=MD5[dataset_name][key], hash_type='md5') for key in data_select} return tuple(RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item], - _create_data_from_json(extracted_files[item]), offset=offset) for item in data_select) + _create_data_from_json(extracted_files[item]), offset=offset, stride=stride) for item in data_select) -def SQuAD1(root='.data', data_select=('train', 'dev'), offset=0): +def SQuAD1(root='.data', data_select=('train', 'dev'), offset=0, stride=1): """ A dataset iterator yields the data of Stanford Question Answering dataset - SQuAD1.0. The iterator yields a tuple of (raw context, raw question, a list of raw answer, a list of answer positions in the raw context). @@ -52,6 +52,7 @@ def SQuAD1(root='.data', data_select=('train', 'dev'), offset=0): By default, both datasets (train, dev) are generated. Users could also choose any one or two of them, for example ('train', 'dev') or just a string 'train'. offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD1() @@ -59,10 +60,10 @@ def SQuAD1(root='.data', data_select=('train', 'dev'), offset=0): >>> print(idx, (context, question, answer, ans_pos)) """ - return _setup_datasets("SQuAD1", root, data_select, offset) + return _setup_datasets("SQuAD1", root, data_select, offset, stride) -def SQuAD2(root='.data', data_select=('train', 'dev'), offset=0): +def SQuAD2(root='.data', data_select=('train', 'dev'), offset=0, stride=1): """ A dataset iterator yields the data of Stanford Question Answering dataset - SQuAD2.0. The iterator yields a tuple of (raw context, raw question, a list of raw answer, a list of answer positions in the raw context). @@ -77,6 +78,7 @@ def SQuAD2(root='.data', data_select=('train', 'dev'), offset=0): By default, both datasets (train, dev) are generated. Users could also choose any one or two of them, for example ('train', 'dev') or just a string 'train'. offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. 
Default: 1 Examples: >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD2() @@ -84,7 +86,7 @@ def SQuAD2(root='.data', data_select=('train', 'dev'), offset=0): >>> print(idx, (context, question, answer, ans_pos)) """ - return _setup_datasets("SQuAD2", root, data_select, offset) + return _setup_datasets("SQuAD2", root, data_select, offset, stride) DATASETS = { diff --git a/torchtext/experimental/datasets/raw/sequence_tagging.py b/torchtext/experimental/datasets/raw/sequence_tagging.py index 9d10c1f540..4c1a8754ab 100644 --- a/torchtext/experimental/datasets/raw/sequence_tagging.py +++ b/torchtext/experimental/datasets/raw/sequence_tagging.py @@ -39,7 +39,7 @@ def _construct_filepath(paths, file_suffix): return None -def _setup_datasets(dataset_name, separator, root, data_select, offset): +def _setup_datasets(dataset_name, separator, root, data_select, offset, stride): data_select = check_default_set(data_select, target_select=('train', 'valid', 'test')) extracted_files = [] if isinstance(URLS[dataset_name], dict): @@ -60,11 +60,11 @@ def _setup_datasets(dataset_name, separator, root, data_select, offset): "test": _construct_filepath(extracted_files, "test.txt") } return tuple(RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item], - _create_data_from_iob(data_filenames[item], separator), offset=offset) + _create_data_from_iob(data_filenames[item], separator), offset=offset, stride=stride) if data_filenames[item] is not None else None for item in data_select) -def UDPOS(root=".data", data_select=('train', 'valid', 'test'), offset=0): +def UDPOS(root=".data", data_select=('train', 'valid', 'test'), offset=0, stride=1): """ Universal Dependencies English Web Treebank Separately returns the training and test dataset @@ -76,15 +76,16 @@ def UDPOS(root=".data", data_select=('train', 'valid', 'test'), offset=0): Users could also choose any one or two of them, for example ('train', 'valid', 'test') or just a string 'train'. offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> from torchtext.experimental.datasets.raw import UDPOS >>> train_dataset, valid_dataset, test_dataset = UDPOS() """ - return _setup_datasets("UDPOS", "\t", root, data_select, offset) + return _setup_datasets("UDPOS", "\t", root, data_select, offset, stride) -def CoNLL2000Chunking(root=".data", data_select=('train', 'test'), offset=0): +def CoNLL2000Chunking(root=".data", data_select=('train', 'test'), offset=0, stride=1): """ CoNLL 2000 Chunking Dataset Separately returns the training and test dataset @@ -95,12 +96,13 @@ def CoNLL2000Chunking(root=".data", data_select=('train', 'test'), offset=0): By default, both datasets (train, test) are generated. Users could also choose any one or two of them, for example ('train', 'test') or just a string 'train'. offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. 
Default: 1 Examples: >>> from torchtext.experimental.datasets.raw import CoNLL2000Chunking >>> train_dataset, test_dataset = CoNLL2000Chunking() """ - return _setup_datasets("CoNLL2000Chunking", " ", root, data_select, offset) + return _setup_datasets("CoNLL2000Chunking", " ", root, data_select, offset, stride) DATASETS = { diff --git a/torchtext/experimental/datasets/raw/translation.py b/torchtext/experimental/datasets/raw/translation.py index 0c2d603e37..914b4db92c 100644 --- a/torchtext/experimental/datasets/raw/translation.py +++ b/torchtext/experimental/datasets/raw/translation.py @@ -116,7 +116,7 @@ def _construct_filepaths(paths, src_filename, tgt_filename): def _setup_datasets(dataset_name, train_filenames, valid_filenames, test_filenames, - data_select, root, offset): + data_select, root, offset, stride): data_select = check_default_set(data_select, ('train', 'valid', 'test')) if not isinstance(train_filenames, tuple) and not isinstance(valid_filenames, tuple) \ and not isinstance(test_filenames, tuple): @@ -184,7 +184,7 @@ def _iter(src_data_iter, tgt_data_iter): yield item datasets.append( - RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][key], _iter(src_data_iter, tgt_data_iter), offset)) + RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][key], _iter(src_data_iter, tgt_data_iter), offset, stride)) return tuple(datasets) @@ -192,7 +192,7 @@ def _iter(src_data_iter, tgt_data_iter): def Multi30k(train_filenames=("train.de", "train.en"), valid_filenames=("val.de", "val.en"), test_filenames=("test_2016_flickr.de", "test_2016_flickr.en"), - data_select=('train', 'valid', 'test'), root='.data', offset=0): + data_select=('train', 'valid', 'test'), root='.data', offset=0, stride=1): """ Define translation datasets: Multi30k Separately returns train/valid/test datasets as a tuple The available dataset include: @@ -260,12 +260,13 @@ def Multi30k(train_filenames=("train.de", "train.en"), object should be provided which will be used to process valid and/or test data. root: Directory where the datasets are saved. Default: ".data" offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> from torchtext.experimental.datasets.raw import Multi30k >>> train_dataset, valid_dataset, test_dataset = Multi30k() """ - return _setup_datasets("Multi30k", train_filenames, valid_filenames, test_filenames, data_select, root, offset) + return _setup_datasets("Multi30k", train_filenames, valid_filenames, test_filenames, data_select, root, offset, stride) def IWSLT(train_filenames=('train.de-en.de', 'train.de-en.en'), @@ -273,7 +274,7 @@ def IWSLT(train_filenames=('train.de-en.de', 'train.de-en.en'), 'IWSLT16.TED.tst2013.de-en.en'), test_filenames=('IWSLT16.TED.tst2014.de-en.de', 'IWSLT16.TED.tst2014.de-en.en'), - data_select=('train', 'valid', 'test'), root='.data', offset=0): + data_select=('train', 'valid', 'test'), root='.data', offset=0, stride=1): """ Define translation datasets: IWSLT Separately returns train/valid/test datasets The available datasets include: @@ -427,12 +428,13 @@ def IWSLT(train_filenames=('train.de-en.de', 'train.de-en.en'), object should be provided which will be used to process valid and/or test data. root: Directory where the datasets are saved. Default: ".data" offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. 
Default: 1 Examples: >>> from torchtext.experimental.datasets.raw import IWSLT >>> train_dataset, valid_dataset, test_dataset = IWSLT() """ - return _setup_datasets("IWSLT", train_filenames, valid_filenames, test_filenames, data_select, root, offset) + return _setup_datasets("IWSLT", train_filenames, valid_filenames, test_filenames, data_select, root, offset, stride) def WMT14(train_filenames=('train.tok.clean.bpe.32000.de', @@ -441,7 +443,7 @@ def WMT14(train_filenames=('train.tok.clean.bpe.32000.de', 'newstest2013.tok.bpe.32000.en'), test_filenames=('newstest2014.tok.bpe.32000.de', 'newstest2014.tok.bpe.32000.en'), - data_select=('train', 'valid', 'test'), root='.data', offset=0): + data_select=('train', 'valid', 'test'), root='.data', offset=0, stride=1): """ Define translation datasets: WMT14 Separately returns train/valid/test datasets The available datasets include: @@ -510,12 +512,13 @@ def WMT14(train_filenames=('train.tok.clean.bpe.32000.de', object should be provided which will be used to process valid and/or test data. root: Directory where the datasets are saved. Default: ".data" offset: the number of the starting line. Default: 0 + stride: stride - 1 is the number of the lines to skip. Default: 1 Examples: >>> from torchtext.experimental.datasets.raw import WMT14 >>> train_dataset, valid_dataset, test_dataset = WMT14() """ - return _setup_datasets("WMT14", train_filenames, valid_filenames, test_filenames, data_select, root, offset) + return _setup_datasets("WMT14", train_filenames, valid_filenames, test_filenames, data_select, root, offset, stride) DATASETS = { From 16d57c8b04226ce1abb01ba175eda251d73eea47 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sat, 6 Feb 2021 12:45:57 -0800 Subject: [PATCH 3/7] checkpoint --- torchtext/experimental/datasets/raw/text_classification.py | 4 ++-- torchtext/experimental/datasets/raw/translation.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtext/experimental/datasets/raw/text_classification.py b/torchtext/experimental/datasets/raw/text_classification.py index f03370c47c..9e2f44f4a4 100644 --- a/torchtext/experimental/datasets/raw/text_classification.py +++ b/torchtext/experimental/datasets/raw/text_classification.py @@ -51,7 +51,7 @@ def _setup_datasets(dataset_name, root, data_select, offset, stride): if fname.endswith('test.csv'): cvs_path['test'] = fname return tuple(RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item], - _create_data_from_csv(cvs_path[item]), offset, stride) for item in data_select) + _create_data_from_csv(cvs_path[item]), offset=offset, stride=stride) for item in data_select) def AG_NEWS(root='.data', data_select=('train', 'test'), offset=0, stride=1): @@ -264,7 +264,7 @@ def IMDB(root='.data', data_select=('train', 'test'), offset=0, stride=1): extracted_files = extract_archive(dataset_tar) return tuple(RawTextIterableDataset("IMDB", NUM_LINES["IMDB"][item], generate_imdb_data(item, - extracted_files), offset, stride) for item in data_select) + extracted_files), offset=offset, stride=stride) for item in data_select) DATASETS = { diff --git a/torchtext/experimental/datasets/raw/translation.py b/torchtext/experimental/datasets/raw/translation.py index 914b4db92c..75c390327c 100644 --- a/torchtext/experimental/datasets/raw/translation.py +++ b/torchtext/experimental/datasets/raw/translation.py @@ -184,7 +184,7 @@ def _iter(src_data_iter, tgt_data_iter): yield item datasets.append( - RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][key], _iter(src_data_iter, 
tgt_data_iter), offset, stride)) + RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][key], _iter(src_data_iter, tgt_data_iter), offset=offset, stride=stride)) return tuple(datasets) From d89ac48ca4a45f562383c21f98a62b43164f9615 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sat, 6 Feb 2021 13:06:37 -0800 Subject: [PATCH 4/7] checkpoint --- test/data/test_builtin_datasets.py | 6 ++++-- torchtext/experimental/datasets/raw/common.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index b431ea60db..e46a3f6182 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -145,13 +145,15 @@ def test_num_lines_of_setup_iter_dataset(self): self.assertEqual(len(_data), 100) def test_offset_stride_dataset(self): - train_iter, test_iter = AG_NEWS(data_select=('train', 'test'), offset=10, stride=1) + train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS(data_select=('train', 'test'), + offset=10, stride=1) container = [text[:20] for idx, (label, text) in enumerate(train_iter) if idx < 5] self.assertEqual(container, ['Oil and Economy Clou', 'No Need for OPEC to ', 'Non-OPEC Nations Sho', 'Google IPO Auction O', 'Dollar Falls Broadly']) - train_iter, test_iter = AG_NEWS(data_select=('train', 'test'), offset=100, stride=5) + train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS(data_select=('train', 'test'), + offset=100, stride=5) container = [text[:20] for idx, (label, text) in enumerate(test_iter) if idx < 5] self.assertEqual(container, ['Olympic history for ', 'Edwards Banned from ', 'Yahoo! Ups Ante for ', 'Buckeyes have lots t', diff --git a/torchtext/experimental/datasets/raw/common.py b/torchtext/experimental/datasets/raw/common.py index 89bf3961db..a701afc211 100644 --- a/torchtext/experimental/datasets/raw/common.py +++ b/torchtext/experimental/datasets/raw/common.py @@ -24,7 +24,7 @@ def __init__(self, name, full_num_lines, iterator, offset=0, stride=1): self.has_setup = False self.start = offset self.num_lines = None - self.setup_iter(offset) + self.setup_iter(start=offset, num_lines=full_num_lines - offset) self.stride = stride def setup_iter(self, start=0, num_lines=None): From 30cfdb65458e0a7d84a4131d2b1dfe5f4e1dc971 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 8 Feb 2021 15:28:27 -0800 Subject: [PATCH 5/7] remove stride --- test/data/test_builtin_datasets.py | 11 ++--------- torchtext/experimental/datasets/raw/common.py | 5 +---- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index eb0f81e994..5953ef7878 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -144,21 +144,14 @@ def test_num_lines_of_setup_iter_dataset(self): _data = [item for item in train_iter] self.assertEqual(len(_data), 100) - def test_offset_stride_dataset(self): + def test_offset_dataset(self): train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS(data_select=('train', 'test'), - offset=10, stride=1) + offset=10) container = [text[:20] for idx, (label, text) in enumerate(train_iter) if idx < 5] self.assertEqual(container, ['Oil and Economy Clou', 'No Need for OPEC to ', 'Non-OPEC Nations Sho', 'Google IPO Auction O', 'Dollar Falls Broadly']) - train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS(data_select=('train', 'test'), - offset=100, stride=5) - container = [text[:20] for idx, (label, 
text) in enumerate(test_iter) if idx < 5] - self.assertEqual(container, ['Olympic history for ', 'Edwards Banned from ', - 'Yahoo! Ups Ante for ', 'Buckeyes have lots t', - 'Oil prices bubble to']) - def test_imdb(self): from torchtext.experimental.datasets import IMDB from torchtext.vocab import Vocab diff --git a/torchtext/experimental/datasets/raw/common.py b/torchtext/experimental/datasets/raw/common.py index a701afc211..d34627e278 100644 --- a/torchtext/experimental/datasets/raw/common.py +++ b/torchtext/experimental/datasets/raw/common.py @@ -14,7 +14,7 @@ class RawTextIterableDataset(torch.utils.data.IterableDataset): """Defines an abstraction for raw text iterable datasets. """ - def __init__(self, name, full_num_lines, iterator, offset=0, stride=1): + def __init__(self, name, full_num_lines, iterator, offset=0): """Initiate text-classification dataset. """ super(RawTextIterableDataset, self).__init__() @@ -25,7 +25,6 @@ def __init__(self, name, full_num_lines, iterator, offset=0, stride=1): self.start = offset self.num_lines = None self.setup_iter(start=offset, num_lines=full_num_lines - offset) - self.stride = stride def setup_iter(self, start=0, num_lines=None): self.start = start @@ -42,8 +41,6 @@ def __iter__(self): for i, item in enumerate(self._iterator): if i < self.start: continue - if (i - self.start) % self.stride != 0: - continue if self.num_lines and i >= (self.start + self.num_lines): break yield item From 4d8e8a605ae1af7e1b42b169ef94889c5e97f330 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 8 Feb 2021 15:31:57 -0800 Subject: [PATCH 6/7] remove setup_iter func --- torchtext/experimental/datasets/raw/common.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/torchtext/experimental/datasets/raw/common.py b/torchtext/experimental/datasets/raw/common.py index d34627e278..1267b703ff 100644 --- a/torchtext/experimental/datasets/raw/common.py +++ b/torchtext/experimental/datasets/raw/common.py @@ -21,23 +21,10 @@ def __init__(self, name, full_num_lines, iterator, offset=0): self.name = name self.full_num_lines = full_num_lines self._iterator = iterator - self.has_setup = False self.start = offset - self.num_lines = None - self.setup_iter(start=offset, num_lines=full_num_lines - offset) - - def setup_iter(self, start=0, num_lines=None): - self.start = start - self.num_lines = num_lines - if num_lines and self.start + self.num_lines > self.full_num_lines: - raise ValueError("Requested start {} and num_lines {} exceeds available number of lines {}".format( - self.start, self.num_lines, self.full_num_lines)) - self.has_setup = True + self.num_lines = full_num_lines - offset def __iter__(self): - if not self.has_setup: - self.setup_iter() - for i, item in enumerate(self._iterator): if i < self.start: continue @@ -46,9 +33,7 @@ def __iter__(self): yield item def __len__(self): - if self.has_setup: - return self.num_lines - return self.full_num_lines + return self.num_lines def get_iterator(self): return self._iterator From 73f4b86c7e36dbb0c7611311de3f0ba6aca48253 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 8 Feb 2021 16:07:04 -0800 Subject: [PATCH 7/7] fix CI tests --- test/data/test_builtin_datasets.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 5953ef7878..95f2894634 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -138,14 +138,13 @@ def test_text_classification(self): 
self._helper_test_func(len(test_iter), 7600, next(iter(test_iter))[1][:25], 'Fears for T N pension aft') del train_iter, test_iter - def test_num_lines_of_setup_iter_dataset(self): - train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS() - train_iter.setup_iter(start=10, num_lines=100) + def test_num_lines_of_dataset(self): + train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS(offset=10) _data = [item for item in train_iter] - self.assertEqual(len(_data), 100) + self.assertEqual(len(_data), 119990) def test_offset_dataset(self): - train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS(data_select=('train', 'test'), + train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS(split=('train', 'test'), offset=10) container = [text[:20] for idx, (label, text) in enumerate(train_iter) if idx < 5] self.assertEqual(container, ['Oil and Economy Clou', 'No Need for OPEC to ',
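
The series converges on a much simpler iterator than it started with: patch 5 removes stride, patch 6 removes setup_iter and the has_setup bookkeeping, and patch 7 updates the CI tests, so the final RawTextIterableDataset supports offset-only slicing and reports full_num_lines - offset from __len__. The sketch below restates the core of that final logic from common.py and exercises it against a small in-memory iterator; the 'toy' name and the lines list are illustrative stand-ins rather than torchtext data, and the intermediate stride check from patches 1-4 appears only as a comment.

import torch

class RawTextIterableDataset(torch.utils.data.IterableDataset):
    """Defines an abstraction for raw text iterable datasets."""

    def __init__(self, name, full_num_lines, iterator, offset=0):
        super(RawTextIterableDataset, self).__init__()
        self.name = name
        self.full_num_lines = full_num_lines
        self._iterator = iterator
        self.start = offset
        self.num_lines = full_num_lines - offset

    def __iter__(self):
        for i, item in enumerate(self._iterator):
            if i < self.start:
                continue  # skip the first `offset` lines
            # Patches 1-4 additionally skipped lines here whenever
            # (i - self.start) % self.stride != 0; patch 5 removed that.
            if self.num_lines and i >= (self.start + self.num_lines):
                break
            yield item

    def __len__(self):
        return self.num_lines

# Illustrative stand-in for a downloaded split: five (label, text) rows.
lines = [(1, 'line {}'.format(n)) for n in range(5)]
ds = RawTextIterableDataset('toy', full_num_lines=len(lines), iterator=iter(lines), offset=2)
assert len(ds) == 3
assert [text for _, text in ds] == ['line 2', 'line 3', 'line 4']

With the real datasets the same slicing surfaces as, e.g., AG_NEWS(offset=10) in test_num_lines_of_dataset above: the first ten of the 120000 AG_NEWS training rows are skipped, so iterating yields 119990 items.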