Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 2e5431d

Browse files
committed
add initial pass at migrating AG_NEWS to datapipes.
1 parent 826a051 commit 2e5431d

File tree

2 files changed

+21
-41
lines changed

2 files changed

+21
-41
lines changed

test/data/test_builtin_datasets.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/user/bin/env python3
22
# Note that all the tests in this module require dataset (either network access or cached)
3-
import torch
43
import torchtext
54
import json
65
import hashlib
@@ -24,20 +23,6 @@ class TestDataset(TorchtextTestCase):
2423
def setUpClass(cls):
2524
check_cache_status()
2625

27-
def _helper_test_func(self, length, target_length, results, target_results):
28-
self.assertEqual(length, target_length)
29-
if isinstance(target_results, list):
30-
target_results = torch.tensor(target_results, dtype=torch.int64)
31-
if isinstance(target_results, tuple):
32-
target_results = tuple(torch.tensor(item, dtype=torch.int64) for item in target_results)
33-
self.assertEqual(results, target_results)
34-
35-
def test_raw_ag_news(self):
36-
train_iter, test_iter = torchtext.datasets.AG_NEWS()
37-
self._helper_test_func(len(train_iter), 120000, next(train_iter)[1][:25], 'Wall St. Bears Claw Back ')
38-
self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft')
39-
del train_iter, test_iter
40-
4126
@parameterized.expand(
4227
load_params('raw_datasets.jsonl'),
4328
name_func=_raw_text_custom_name_func)
@@ -74,16 +59,3 @@ def test_raw_datasets_split_argument(self, dataset_name):
7459
break
7560
# Exercise default constructor
7661
_ = dataset()
77-
78-
def test_next_method_dataset(self):
79-
train_iter, test_iter = torchtext.datasets.AG_NEWS()
80-
for_count = 0
81-
next_count = 0
82-
for line in train_iter:
83-
for_count += 1
84-
try:
85-
next(train_iter)
86-
next_count += 1
87-
except:
88-
break
89-
self.assertEqual((for_count, next_count), (60000, 60000))

torchtext/datasets/ag_news.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from torchtext.utils import (
2-
download_from_url,
3-
)
1+
from torchtext._internal.module_utils import is_module_available
2+
from typing import Union, Tuple
3+
4+
if is_module_available("torchdata"):
5+
from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
6+
47
from torchtext.data.datasets_utils import (
5-
_RawTextIterableDataset,
68
_wrap_split_argument,
79
_add_docstring_header,
810
_create_dataset_directory,
9-
_create_data_from_csv,
1011
)
1112
import os
1213

@@ -30,11 +31,18 @@
3031

3132
@_add_docstring_header(num_lines=NUM_LINES, num_classes=4)
3233
@_create_dataset_directory(dataset_name=DATASET_NAME)
33-
@_wrap_split_argument(('train', 'test'))
34-
def AG_NEWS(root, split):
35-
path = download_from_url(URL[split], root=root,
36-
path=os.path.join(root, split + ".csv"),
37-
hash_value=MD5[split],
38-
hash_type='md5')
39-
return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
40-
_create_data_from_csv(path))
34+
@_wrap_split_argument(("train", "test"))
35+
def AG_NEWS(root: str, split: Union[Tuple[str], str]):
36+
if not is_module_available("torchdata"):
37+
raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
38+
39+
url_dp = IterableWrapper([URL[split]])
40+
http_dp = HttpReader(url_dp)
41+
cache_dp = http_dp.on_disk_cache(
42+
filepath_fn=lambda x: os.path.join(root, split + ".csv"),
43+
hash_dict={os.path.join(root, split + ".csv"): MD5[split]},
44+
hash_type="md5"
45+
).end_caching(mode="w", same_filepath_fn=True)
46+
47+
cache_dp = FileOpener(cache_dp, mode="r")
48+
return cache_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))

0 commit comments

Comments
 (0)