From 072b234eaf575f11dbf66cdf8f5a5755628db6bc Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Mon, 10 Jan 2022 16:10:55 -0500 Subject: [PATCH 1/2] add initial pass at migrating AG_NEWS to datapipes. --- test/data/test_builtin_datasets.py | 28 ------------------------ torchtext/datasets/ag_news.py | 34 ++++++++++++++++++------------ 2 files changed, 21 insertions(+), 41 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 41ede61819..437d73c6d6 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -1,6 +1,5 @@ #!/user/bin/env python3 # Note that all the tests in this module require dataset (either network access or cached) -import torch import torchtext import json import hashlib @@ -24,20 +23,6 @@ class TestDataset(TorchtextTestCase): def setUpClass(cls): check_cache_status() - def _helper_test_func(self, length, target_length, results, target_results): - self.assertEqual(length, target_length) - if isinstance(target_results, list): - target_results = torch.tensor(target_results, dtype=torch.int64) - if isinstance(target_results, tuple): - target_results = tuple(torch.tensor(item, dtype=torch.int64) for item in target_results) - self.assertEqual(results, target_results) - - def test_raw_ag_news(self): - train_iter, test_iter = torchtext.datasets.AG_NEWS() - self._helper_test_func(len(train_iter), 120000, next(train_iter)[1][:25], 'Wall St. Bears Claw Back ') - self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft') - del train_iter, test_iter - @parameterized.expand( load_params('raw_datasets.jsonl'), name_func=_raw_text_custom_name_func) @@ -74,16 +59,3 @@ def test_raw_datasets_split_argument(self, dataset_name): break # Exercise default constructor _ = dataset() - - def test_next_method_dataset(self): - train_iter, test_iter = torchtext.datasets.AG_NEWS() - for_count = 0 - next_count = 0 - for line in train_iter: - for_count += 1 - try: - next(train_iter) - next_count += 1 - except: - break - self.assertEqual((for_count, next_count), (60000, 60000)) diff --git a/torchtext/datasets/ag_news.py b/torchtext/datasets/ag_news.py index 136a0069c6..34570a3e8b 100644 --- a/torchtext/datasets/ag_news.py +++ b/torchtext/datasets/ag_news.py @@ -1,12 +1,13 @@ -from torchtext.utils import ( - download_from_url, -) +from torchtext._internal.module_utils import is_module_available +from typing import Union, Tuple + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper + from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _add_docstring_header, _create_dataset_directory, - _create_data_from_csv, ) import os @@ -30,11 +31,18 @@ @_add_docstring_header(num_lines=NUM_LINES, num_classes=4) @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'test')) -def AG_NEWS(root, split): - path = download_from_url(URL[split], root=root, - path=os.path.join(root, split + ".csv"), - hash_value=MD5[split], - hash_type='md5') - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], - _create_data_from_csv(path)) +@_wrap_split_argument(("train", "test")) +def AG_NEWS(root: str, split: Union[Tuple[str], str]): + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") + + url_dp = IterableWrapper([URL[split]]) + http_dp = HttpReader(url_dp) + cache_dp = http_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, split + ".csv"), + hash_dict={os.path.join(root, split + ".csv"): MD5[split]}, + hash_type="md5" + ).end_caching(mode="w", same_filepath_fn=True) + + cache_dp = FileOpener(cache_dp, mode="r") + return cache_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) From a9ef573c4f356fff9fd719a8cc07b6c4f8f1f867 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Tue, 11 Jan 2022 13:43:12 -0500 Subject: [PATCH 2/2] corrects caching order. --- torchtext/datasets/ag_news.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtext/datasets/ag_news.py b/torchtext/datasets/ag_news.py index 34570a3e8b..baf0a8914e 100644 --- a/torchtext/datasets/ag_news.py +++ b/torchtext/datasets/ag_news.py @@ -37,12 +37,12 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]): raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") url_dp = IterableWrapper([URL[split]]) - http_dp = HttpReader(url_dp) - cache_dp = http_dp.on_disk_cache( + cache_dp = url_dp.on_disk_cache( filepath_fn=lambda x: os.path.join(root, split + ".csv"), hash_dict={os.path.join(root, split + ".csv"): MD5[split]}, hash_type="md5" - ).end_caching(mode="w", same_filepath_fn=True) - + ) + cache_dp = HttpReader(cache_dp) + cache_dp = cache_dp.end_caching(mode="w", same_filepath_fn=True) cache_dp = FileOpener(cache_dp, mode="r") return cache_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))