Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 58c3eb0

Browse files
committed
add initial pass at migrating AG_NEWS to datapipes.
1 parent 1a05269 commit 58c3eb0

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

torchtext/datasets/ag_news.py

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,12 @@
1-
from torchtext.utils import (
2-
download_from_url,
3-
)
1+
from torchtext._internal.module_utils import is_module_available
2+
3+
if is_module_available("torchdata"):
4+
from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
5+
46
from torchtext.data.datasets_utils import (
5-
_RawTextIterableDataset,
67
_wrap_split_argument,
78
_add_docstring_header,
89
_create_dataset_directory,
9-
_create_data_from_csv,
1010
)
1111
import os
1212

@@ -30,11 +30,15 @@
3030

3131
@_add_docstring_header(num_lines=NUM_LINES, num_classes=4)
3232
@_create_dataset_directory(dataset_name=DATASET_NAME)
33-
@_wrap_split_argument(('train', 'test'))
33+
@_wrap_split_argument(("train", "test"))
3434
def AG_NEWS(root, split):
35-
path = download_from_url(URL[split], root=root,
36-
path=os.path.join(root, split + ".csv"),
37-
hash_value=MD5[split],
38-
hash_type='md5')
39-
return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
40-
_create_data_from_csv(path))
35+
url_dp = IterableWrapper([URL[split]])
36+
http_dp = HttpReader(url_dp)
37+
cache_dp = http_dp.on_disk_cache(
38+
filepath_fn=lambda x: os.path.join(root, split + ".csv"),
39+
hash_dict={os.path.join(root, split + ".csv"): MD5[split]},
40+
hash_type="md5"
41+
).end_caching(mode="w", same_filepath_fn=True)
42+
43+
cache_dp = FileOpener(cache_dp, mode="r")
44+
return cache_dp.parse_csv().map(fn=lambda t: (int(t[0]), t[1]))

0 commit comments

Comments
 (0)