From 929f1e096a81a68fac14d9d8d695b2af4d7ea0a7 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Wed, 12 Jan 2022 07:35:57 -0500 Subject: [PATCH 1/2] add initial pass at migrating YelpReviewPolarity to datapipes. --- torchtext/datasets/yelpreviewpolarity.py | 44 ++++++++++++++++-------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/torchtext/datasets/yelpreviewpolarity.py b/torchtext/datasets/yelpreviewpolarity.py index 68dfdfacdc..107686dc71 100644 --- a/torchtext/datasets/yelpreviewpolarity.py +++ b/torchtext/datasets/yelpreviewpolarity.py @@ -1,14 +1,17 @@ -import os -from torchtext.utils import download_from_url, extract_archive +from torchtext._internal.module_utils import is_module_available +from typing import Union, Tuple + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper + from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _add_docstring_header, - _find_match, _create_dataset_directory, - _create_data_from_csv, ) +import os + URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg' MD5 = '620c8ae4bd5a150b730f1ba9a7c6a4d3' @@ -22,16 +25,29 @@ DATASET_NAME = "YelpReviewPolarity" +_EXTRACTED_FILES = { + 'train': os.path.join('yelp_review_polarity_csv', 'train.csv'), + 'test': os.path.join('yelp_review_polarity_csv', 'test.csv'), +} @_add_docstring_header(num_lines=NUM_LINES, num_classes=2) @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(('train', 'test')) -def YelpReviewPolarity(root, split): - dataset_tar = download_from_url(URL, root=root, - path=os.path.join(root, _PATH), - hash_value=MD5, hash_type='md5') - extracted_files = extract_archive(dataset_tar) - - path = _find_match(split + '.csv', extracted_files) - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], - _create_data_from_csv(path)) +def YelpReviewPolarity(root: str, split: Union[Tuple[str], str]): + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") + + url_dp = IterableWrapper([URL]) + + cache_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _PATH), + hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5" + ) + cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) + cache_dp = FileOpener(cache_dp, mode="b") + + extracted_files = cache_dp.read_from_tar() + + filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + + return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) \ No newline at end of file From cd9dd41a18ec74ea6847230af27abb8a437b49c6 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Wed, 12 Jan 2022 08:10:51 -0500 Subject: [PATCH 2/2] fix flake. --- torchtext/datasets/yelpreviewpolarity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtext/datasets/yelpreviewpolarity.py b/torchtext/datasets/yelpreviewpolarity.py index 107686dc71..a536d6dd0f 100644 --- a/torchtext/datasets/yelpreviewpolarity.py +++ b/torchtext/datasets/yelpreviewpolarity.py @@ -30,6 +30,7 @@ 'test': os.path.join('yelp_review_polarity_csv', 'test.csv'), } + @_add_docstring_header(num_lines=NUM_LINES, num_classes=2) @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(('train', 'test')) @@ -50,4 +51,4 @@ def YelpReviewPolarity(root: str, split: Union[Tuple[str], str]): filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0]) - return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) \ No newline at end of file + return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))