diff --git a/test/data/test_builtin_datasets.py b/test/data/test_builtin_datasets.py index 02ddf1447e..41ede61819 100644 --- a/test/data/test_builtin_datasets.py +++ b/test/data/test_builtin_datasets.py @@ -38,20 +38,6 @@ def test_raw_ag_news(self): self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft') del train_iter, test_iter - @parameterized.expand( - load_params('raw_datasets.jsonl'), - name_func=_raw_text_custom_name_func) - def test_raw_text_name_property(self, info): - dataset_name = info['dataset_name'] - split = info['split'] - - if dataset_name == 'WMT14': - return - else: - data_iter = torchtext.datasets.DATASETS[dataset_name](split=split) - - self.assertEqual(str(data_iter), dataset_name) - @parameterized.expand( load_params('raw_datasets.jsonl'), name_func=_raw_text_custom_name_func) @@ -63,8 +49,7 @@ def test_raw_text_classification(self, info): return else: data_iter = torchtext.datasets.DATASETS[dataset_name](split=split) - self.assertEqual(len(data_iter), info['NUM_LINES']) - self.assertEqual(hashlib.md5(json.dumps(next(data_iter), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line']) + self.assertEqual(hashlib.md5(json.dumps(next(iter(data_iter)), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line']) if dataset_name == "AG_NEWS": self.assertEqual(torchtext.datasets.URLS[dataset_name][split], info['URL']) self.assertEqual(torchtext.datasets.MD5[dataset_name][split], info['MD5']) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 12297931f9..903d45c97e 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -209,8 +209,7 @@ def _wrap_split_argument_with_fn(fn, splits): argspec.args[1] == "split" and argspec.varargs is None and argspec.varkw is None and - len(argspec.kwonlyargs) == 0 and - len(argspec.annotations) == 0 + len(argspec.kwonlyargs) == 0 ): raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn)) @@ -240,23 +239,22 @@ def new_fn(fn): def _create_dataset_directory(dataset_name): - def decorator(func): - argspec = inspect.getfullargspec(func) + def decorator(fn): + argspec = inspect.getfullargspec(fn) if not (argspec.args[0] == "root" and argspec.args[1] == "split" and argspec.varargs is None and argspec.varkw is None and - len(argspec.kwonlyargs) == 0 and - len(argspec.annotations) == 0 + len(argspec.kwonlyargs) == 0 ): raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn)) - @functools.wraps(func) + @functools.wraps(fn) def wrapper(root=_CACHE_DIR, *args, **kwargs): new_root = os.path.join(root, dataset_name) if not os.path.exists(new_root): os.makedirs(new_root) - return func(root=new_root, *args, **kwargs) + return fn(root=new_root, *args, **kwargs) return wrapper diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index c143677fb7..6a0497c3a1 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -1,13 +1,15 @@ +from torchtext._internal.module_utils import is_module_available +from typing import Union, Tuple +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper + from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _add_docstring_header, - _download_extract_validate, _create_dataset_directory, - _create_data_from_csv, ) + import os -import logging URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM' @@ -25,20 +27,24 @@ 'test': f'{os.sep}'.join(['amazon_review_polarity_csv', 'test.csv']), } -_EXTRACTED_FILES_MD5 = { - 'train': "520937107c39a2d1d1f66cd410e9ed9e", - 'test': "f4c8bded2ecbde5f996b675db6228f16" -} DATASET_NAME = "AmazonReviewPolarity" @_add_docstring_header(num_lines=NUM_LINES, num_classes=2) @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'test')) -def AmazonReviewPolarity(root, split): - path = _download_extract_validate(root, URL, MD5, os.path.join(root, _PATH), os.path.join(root, _EXTRACTED_FILES[split]), - _EXTRACTED_FILES_MD5[split], hash_type="md5") - logging.info('Creating {} data'.format(split)) - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], - _create_data_from_csv(path)) +@_wrap_split_argument(("train", "test")) +def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]): + # TODO Remove this after removing conditional dependency + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") + + url_dp = IterableWrapper([URL]) + cache_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5" + ) + cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) + cache_dp = FileOpener(cache_dp, mode="b") + extracted_files = cache_dp.read_from_tar() + filter_extracted_files = extracted_files.filter(lambda x: split in x[0]) + return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:])))