From c6118e2222afc6b0c6ae4925e45c0cec40adbda6 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Thu, 6 Jan 2022 11:38:31 -0500 Subject: [PATCH 1/8] update dataset --- torchtext/datasets/amazonreviewpolarity.py | 37 ++++++++++++++-------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index c143677fb7..af7476a2c8 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -1,13 +1,12 @@ +from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper + from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _add_docstring_header, - _download_extract_validate, _create_dataset_directory, - _create_data_from_csv, ) + import os -import logging URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM' @@ -25,20 +24,30 @@ 'test': f'{os.sep}'.join(['amazon_review_polarity_csv', 'test.csv']), } -_EXTRACTED_FILES_MD5 = { - 'train': "520937107c39a2d1d1f66cd410e9ed9e", - 'test': "f4c8bded2ecbde5f996b675db6228f16" -} DATASET_NAME = "AmazonReviewPolarity" @_add_docstring_header(num_lines=NUM_LINES, num_classes=2) @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'test')) +@_wrap_split_argument(("train", "test")) def AmazonReviewPolarity(root, split): - path = _download_extract_validate(root, URL, MD5, os.path.join(root, _PATH), os.path.join(root, _EXTRACTED_FILES[split]), - _EXTRACTED_FILES_MD5[split], hash_type="md5") - logging.info('Creating {} data'.format(split)) - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], - _create_data_from_csv(path)) + + url_dp = IterableWrapper([URL]) + + # cache data on-disk with sanity check + cache_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5" + ) + cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) + + cache_dp = FileOpener(cache_dp, mode="b") + + # stack TAR extractor on top of loader DP + extracted_files = cache_dp.read_from_tar() + + # filter files as necessary + filter_extracted_files = extracted_files.filter(lambda x: split in x[0]) + + # stack CSV reader and do some mapping + return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), t[1])) From 7e50e69fdc9bda71bf52a362fc009d39735f75e5 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Thu, 6 Jan 2022 13:17:23 -0500 Subject: [PATCH 2/8] update amazon dataset --- test/data/test_builtin_datasets.py | 3 +-- torchtext/datasets/amazonreviewpolarity.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 02ddf1447e..eafede6010 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -63,8 +63,7 @@ def test_raw_text_classification(self, info): return else: data_iter = torchtext.datasets.DATASETS[dataset_name](split=split) - self.assertEqual(len(data_iter), info['NUM_LINES']) - self.assertEqual(hashlib.md5(json.dumps(next(data_iter), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line']) + self.assertEqual(hashlib.md5(json.dumps(next(iter(data_iter)), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line']) if dataset_name == "AG_NEWS": self.assertEqual(torchtext.datasets.URLS[dataset_name][split], info['URL']) self.assertEqual(torchtext.datasets.MD5[dataset_name][split], info['MD5']) diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index af7476a2c8..53cfe174a1 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -50,4 +50,4 @@ def AmazonReviewPolarity(root, split): filter_extracted_files = extracted_files.filter(lambda x: split in x[0]) # stack CSV reader and do some mapping - return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), t[1])) + return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:]))) From b646bac5fa702444e868e045eee40c42a185a2fd Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 7 Jan 2022 11:45:40 -0500 Subject: [PATCH 3/8] remove representation test --- test/data/test_builtin_datasets.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index eafede6010..41ede61819 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -38,20 +38,6 @@ def test_raw_ag_news(self): self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft') del train_iter, test_iter - @parameterized.expand( - load_params('raw_datasets.jsonl'), - name_func=_raw_text_custom_name_func) - def test_raw_text_name_property(self, info): - dataset_name = info['dataset_name'] - split = info['split'] - - if dataset_name == 'WMT14': - return - else: - data_iter = torchtext.datasets.DATASETS[dataset_name](split=split) - - self.assertEqual(str(data_iter), dataset_name) - @parameterized.expand( load_params('raw_datasets.jsonl'), name_func=_raw_text_custom_name_func) From 033889b8b2d479df4e7891fb35f404bb231ff2a5 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 7 Jan 2022 12:11:57 -0500 Subject: [PATCH 4/8] conditional import --- torchtext/datasets/amazonreviewpolarity.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index 53cfe174a1..44af60e412 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -1,4 +1,7 @@ -from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper +from torchtext._internal.module_utils import is_module_available + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper from torchtext.data.datasets_utils import ( _wrap_split_argument, From a4669a49daef604c27a0eaf5e98ccf651a6592d4 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Mon, 10 Jan 2022 10:11:32 -0500 Subject: [PATCH 5/8] add modulenotfounderror --- torchtext/datasets/amazonreviewpolarity.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index 44af60e412..1a20c51e03 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -35,6 +35,9 @@ @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def AmazonReviewPolarity(root, split): + # TODO Remove this after removing conditional dependency + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") url_dp = IterableWrapper([URL]) From e56eafb584cd66027e634c45c10b389a821eac54 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Mon, 10 Jan 2022 10:29:08 -0500 Subject: [PATCH 6/8] remove comments in code --- torchtext/datasets/amazonreviewpolarity.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index 1a20c51e03..425c7f9821 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -40,20 +40,11 @@ def AmazonReviewPolarity(root, split): raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") url_dp = IterableWrapper([URL]) - - # cache data on-disk with sanity check cache_dp = url_dp.on_disk_cache( filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5" ) cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_dp = FileOpener(cache_dp, mode="b") - - # stack TAR extractor on top of loader DP extracted_files = cache_dp.read_from_tar() - - # filter files as necessary filter_extracted_files = extracted_files.filter(lambda x: split in x[0]) - - # stack CSV reader and do some mapping return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:]))) From 97ee0c13db433bb0cb8a0876476618ffe642916b Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Mon, 10 Jan 2022 12:54:34 -0500 Subject: [PATCH 7/8] added type annotation --- torchtext/data/datasets_utils.py | 8 +++----- torchtext/datasets/amazonreviewpolarity.py | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 12297931f9..3b30fd865b 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -209,8 +209,7 @@ def _wrap_split_argument_with_fn(fn, splits): argspec.args[1] == "split" and argspec.varargs is None and argspec.varkw is None and - len(argspec.kwonlyargs) == 0 and - len(argspec.annotations) == 0 + len(argspec.kwonlyargs) == 0 ): raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn)) @@ -246,10 +245,9 @@ def decorator(func): argspec.args[1] == "split" and argspec.varargs is None and argspec.varkw is None and - len(argspec.kwonlyargs) == 0 and - len(argspec.annotations) == 0 + len(argspec.kwonlyargs) == 0 ): - raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn)) + raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(func)) @functools.wraps(func) def wrapper(root=_CACHE_DIR, *args, **kwargs): diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index 425c7f9821..6a0497c3a1 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -1,5 +1,5 @@ from torchtext._internal.module_utils import is_module_available - +from typing import Union, Tuple if is_module_available("torchdata"): from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper @@ -34,7 +34,7 @@ @_add_docstring_header(num_lines=NUM_LINES, num_classes=2) @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) -def AmazonReviewPolarity(root, split): +def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]): # TODO Remove this after removing conditional dependency if not is_module_available("torchdata"): raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") From 928f3a91f6cec091e8165f4f7eb6f49361633fd7 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Mon, 10 Jan 2022 15:48:55 -0500 Subject: [PATCH 8/8] minor edit --- torchtext/data/datasets_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 3b30fd865b..903d45c97e 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -239,22 +239,22 @@ def new_fn(fn): def _create_dataset_directory(dataset_name): - def decorator(func): - argspec = inspect.getfullargspec(func) + def decorator(fn): + argspec = inspect.getfullargspec(fn) if not (argspec.args[0] == "root" and argspec.args[1] == "split" and argspec.varargs is None and argspec.varkw is None and len(argspec.kwonlyargs) == 0 ): - raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(func)) + raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn)) - @functools.wraps(func) + @functools.wraps(fn) def wrapper(root=_CACHE_DIR, *args, **kwargs): new_root = os.path.join(root, dataset_name) if not os.path.exists(new_root): os.makedirs(new_root) - return func(root=new_root, *args, **kwargs) + return fn(root=new_root, *args, **kwargs) return wrapper