From 585096e29b8f5b530dde085c8c329e4996994a9b Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Tue, 11 Jan 2022 08:25:59 -0500 Subject: [PATCH 1/2] add initial pass at migrating SogouNews to datapipes. --- torchtext/datasets/sogounews.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/torchtext/datasets/sogounews.py b/torchtext/datasets/sogounews.py index 6370ddd522..8f5b752930 100644 --- a/torchtext/datasets/sogounews.py +++ b/torchtext/datasets/sogounews.py @@ -1,13 +1,16 @@ +from torchtext._internal.module_utils import is_module_available +from typing import Union, Tuple + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper + from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _add_docstring_header, - _download_extract_validate, _create_dataset_directory, - _create_data_from_csv, ) + import os -import logging URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE' @@ -35,10 +38,17 @@ @_add_docstring_header(num_lines=NUM_LINES, num_classes=5) @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'test')) -def SogouNews(root, split): - path = _download_extract_validate(root, URL, MD5, os.path.join(root, _PATH), os.path.join(root, _EXTRACTED_FILES[split]), - _EXTRACTED_FILES_MD5[split], hash_type="md5") - logging.info('Creating {} data'.format(split)) - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], - _create_data_from_csv(path)) +@_wrap_split_argument(("train", "test")) +def SogouNews(root: str, split: Union[Tuple[str], str]): + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") + + url_dp = IterableWrapper([URL]) + cache_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5" + ) + cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) + cache_dp = FileOpener(cache_dp, mode="b") + extracted_files = cache_dp.read_from_tar() + filter_extracted_files = extracted_files.filter(lambda x: split + ".csv" in x[0]) + return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:]))) From 63a687f0fad913dbab540c73dcfcc6940c3feea5 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Tue, 11 Jan 2022 13:51:43 -0500 Subject: [PATCH 2/2] make filter for specific split more consistent. --- torchtext/datasets/sogounews.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/datasets/sogounews.py b/torchtext/datasets/sogounews.py index 8f5b752930..684c424f8b 100644 --- a/torchtext/datasets/sogounews.py +++ b/torchtext/datasets/sogounews.py @@ -50,5 +50,5 @@ def SogouNews(root: str, split: Union[Tuple[str], str]): cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) cache_dp = FileOpener(cache_dp, mode="b") extracted_files = cache_dp.read_from_tar() - filter_extracted_files = extracted_files.filter(lambda x: split + ".csv" in x[0]) + filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0]) return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:])))