Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e31d7f7

Browse files
committed
add initial pass at migrating Amazon Review Full to datapipes.
1 parent 1a05269 commit e31d7f7

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

torchtext/datasets/amazonreviewfull.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
from torchtext._internal.module_utils import is_module_available
2+
3+
if is_module_available("torchdata"):
4+
from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper
5+
16
from torchtext.data.datasets_utils import (
2-
_RawTextIterableDataset,
37
_wrap_split_argument,
48
_add_docstring_header,
5-
_download_extract_validate,
69
_create_dataset_directory,
7-
_create_data_from_csv,
810
)
11+
912
import os
10-
import logging
1113

1214
URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
1315

@@ -35,10 +37,19 @@
3537

3638
@_add_docstring_header(num_lines=NUM_LINES, num_classes=5)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def AmazonReviewFull(root, split):
    """Build a datapipe over the Amazon Review Full dataset.

    Args:
        root: Directory where the downloaded archive is cached.
        split: Dataset split to load, one of ``'train'`` or ``'test'``.

    Returns:
        An iterable datapipe yielding ``(label, fields)`` tuples, where
        ``label`` is the first CSV column converted to ``int`` and ``fields``
        is the list of remaining CSV columns for that row.
    """
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    # The archive is cached at a fixed location; compute the path once since
    # it serves both as the on-disk cache target and as the hash-check key.
    archive_path = os.path.join(root, _PATH)

    url_dp = IterableWrapper([URL])

    # Cache the Google Drive download on disk and validate it against the
    # known MD5 so repeated calls do not re-download the archive.
    cache_dp = url_dp.on_disk_cache(
        filepath_fn=lambda x: archive_path,
        hash_dict={archive_path: MD5},
        hash_type="md5",
    )
    cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
    cache_dp = FileOpener(cache_dp, mode="b")

    # Extract the tar archive lazily and keep only the CSV file belonging to
    # the requested split (member tuples are (file_name, stream)).
    extracted_files = cache_dp.read_from_tar()
    filter_extracted_files = extracted_files.filter(
        lambda x: _EXTRACTED_FILES[split] in x[0]
    )

    # Each CSV row is (label, title, review, ...); convert the label to int
    # and pass the remaining columns through unchanged.
    return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), t[1:]))

0 commit comments

Comments
 (0)