Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit df0ec14

Browse files
authored
add initial pass at migrating Amazon Review Full to datapipes. (#1499)
1 parent d896135 commit df0ec14

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

torchtext/datasets/amazonreviewfull.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
from torchtext._internal.module_utils import is_module_available
2+
from typing import Union, Tuple
3+
4+
if is_module_available("torchdata"):
5+
from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper
6+
17
from torchtext.data.datasets_utils import (
2-
_RawTextIterableDataset,
38
_wrap_split_argument,
49
_add_docstring_header,
5-
_download_extract_validate,
610
_create_dataset_directory,
7-
_create_data_from_csv,
811
)
12+
913
import os
10-
import logging
1114

1215
URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
1316

@@ -35,10 +38,22 @@
3538

3639
@_add_docstring_header(num_lines=NUM_LINES, num_classes=5)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
    """Build the Amazon Review Full datapipe.

    Downloads the archive from Google Drive (cached on disk and md5-verified),
    reads the tarball, keeps only the CSV member for the requested split, and
    yields ``(label, text)`` tuples parsed from each CSV row.
    """
    # Every datapipe used below comes from torchdata; fail fast when absent.
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

    archive_path = os.path.join(root, _PATH)

    # Download once: on_disk_cache short-circuits when the cached file's md5 matches.
    source_dp = IterableWrapper([URL])
    archive_dp = source_dp.on_disk_cache(
        filepath_fn=lambda x: archive_path,
        hash_dict={archive_path: MD5},
        hash_type="md5",
    )
    archive_dp = GDriveReader(archive_dp).end_caching(mode="wb", same_filepath_fn=True)

    # Open the cached tar in binary mode and select the split's CSV member by name.
    split_member_dp = FileOpener(archive_dp, mode="b").read_from_tar().filter(
        lambda entry: _EXTRACTED_FILES[split] in entry[0]
    )

    # Each CSV row is (label, *text columns) -> (int label, space-joined text).
    return split_member_dp.parse_csv().map(fn=lambda row: (int(row[0]), " ".join(row[1:])))

0 commit comments

Comments (0)