
Commit 199af33

add initial pass at migrating Amazon Review Full to datapipes.

1 parent 1a05269

2 files changed: +22 -26 lines


test/data/test_builtin_datasets.py

Lines changed: 1 addition & 16 deletions
@@ -38,20 +38,6 @@ def test_raw_ag_news(self):
         self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft')
         del train_iter, test_iter
 
-    @parameterized.expand(
-        load_params('raw_datasets.jsonl'),
-        name_func=_raw_text_custom_name_func)
-    def test_raw_text_name_property(self, info):
-        dataset_name = info['dataset_name']
-        split = info['split']
-
-        if dataset_name == 'WMT14':
-            return
-        else:
-            data_iter = torchtext.datasets.DATASETS[dataset_name](split=split)
-
-        self.assertEqual(str(data_iter), dataset_name)
-
     @parameterized.expand(
         load_params('raw_datasets.jsonl'),
         name_func=_raw_text_custom_name_func)
@@ -63,8 +49,7 @@ def test_raw_text_classification(self, info):
             return
         else:
            data_iter = torchtext.datasets.DATASETS[dataset_name](split=split)
-        self.assertEqual(len(data_iter), info['NUM_LINES'])
-        self.assertEqual(hashlib.md5(json.dumps(next(data_iter), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line'])
+        self.assertEqual(hashlib.md5(json.dumps(next(iter(data_iter)), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line'])
         if dataset_name == "AG_NEWS":
             self.assertEqual(torchtext.datasets.URLS[dataset_name][split], info['URL'])
             self.assertEqual(torchtext.datasets.MD5[dataset_name][split], info['MD5'])
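The updated assertion wraps data_iter in iter() so it works whether the dataset is an old-style _RawTextIterableDataset or a new-style datapipe, which is iterable but not itself an iterator. As a minimal sketch of how a 'first_line' fingerprint for raw_datasets.jsonl could be regenerated with the same hashing scheme the test uses (the dataset name and split here are illustrative, not from this commit):

import hashlib
import json

import torchtext

# Build the raw iterator for one (dataset, split) pair from the registry.
data_iter = torchtext.datasets.DATASETS["AG_NEWS"](split="train")

# Hash the JSON-serialized first example, exactly as the test assertion does.
first_line = hashlib.md5(
    json.dumps(next(iter(data_iter)), sort_keys=True).encode("utf-8")
).hexdigest()
print(first_line)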

torchtext/datasets/amazonreviewfull.py

Lines changed: 21 additions & 10 deletions
@@ -1,13 +1,15 @@
+from torchtext._internal.module_utils import is_module_available
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper
+
 from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _add_docstring_header,
-    _download_extract_validate,
     _create_dataset_directory,
-    _create_data_from_csv,
 )
+
 import os
-import logging
 
 URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
 
@@ -35,10 +37,19 @@
 
 @_add_docstring_header(num_lines=NUM_LINES, num_classes=5)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train', 'test'))
+@_wrap_split_argument(("train", "test"))
 def AmazonReviewFull(root, split):
-    path = _download_extract_validate(root, URL, MD5, os.path.join(root, _PATH), os.path.join(root, _EXTRACTED_FILES[split]),
-                                      _EXTRACTED_FILES_MD5[split], hash_type="md5")
-    logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
-                                   _create_data_from_csv(path))
+    url_dp = IterableWrapper([URL])
+
+    cache_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _PATH),
+        hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
+    )
+    cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_dp = FileOpener(cache_dp, mode="b")
+
+    extracted_files = cache_dp.read_from_tar()
+
+    filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+
+    return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
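The new pipeline downloads the archive from Google Drive via GDriveReader, caches it under root with an MD5 check, extracts the split's CSV from the tar, and maps each parsed row to a (label, text) pair. A minimal consumption sketch, assuming torchdata is installed and the Google Drive download succeeds:

from torchtext.datasets import AmazonReviewFull

# Each element is (label, text): the label is int(t[0]) from the CSV row,
# and the text joins the remaining fields (review title and body).
train_dp = AmazonReviewFull(split="train")
label, text = next(iter(train_dp))
print(label)      # integer rating label for this 5-class dataset
print(text[:80])  # start of the joined title and review body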
