Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 0 additions & 28 deletions test/data/test_builtin_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
# Note that all the tests in this module require the dataset (either network access or a cached copy)
import torch
import torchtext
import json
import hashlib
Expand All @@ -24,20 +23,6 @@ class TestDataset(TorchtextTestCase):
def setUpClass(cls):
check_cache_status()

def _helper_test_func(self, length, target_length, results, target_results):
    """Assert a dataset's length and a sample of its output match expectations.

    A list of ints in ``target_results`` is promoted to a single int64 tensor,
    and a tuple is promoted element-wise to int64 tensors, so tensor-producing
    datasets compare cleanly against plain-Python expected values.
    """
    self.assertEqual(length, target_length)
    if isinstance(target_results, list):
        expected = torch.tensor(target_results, dtype=torch.int64)
    elif isinstance(target_results, tuple):
        expected = tuple(torch.tensor(elem, dtype=torch.int64) for elem in target_results)
    else:
        expected = target_results
    self.assertEqual(results, expected)

def test_raw_ag_news(self):
    """Smoke-test the raw AG_NEWS iterators: split sizes and first-sample text prefixes."""
    train_iter, test_iter = torchtext.datasets.AG_NEWS()
    cases = (
        (train_iter, 120000, 'Wall St. Bears Claw Back '),
        (test_iter, 7600, 'Fears for T N pension aft'),
    )
    for split_iter, expected_len, expected_prefix in cases:
        self._helper_test_func(
            len(split_iter), expected_len, next(split_iter)[1][:25], expected_prefix)
    del train_iter, test_iter

@parameterized.expand(
load_params('raw_datasets.jsonl'),
name_func=_raw_text_custom_name_func)
Expand Down Expand Up @@ -74,16 +59,3 @@ def test_raw_datasets_split_argument(self, dataset_name):
break
# Exercise default constructor
_ = dataset()

def test_next_method_dataset(self):
    """Verify the AG_NEWS train iterator supports interleaved for/next() consumption.

    The for-loop and the explicit next() each draw one item per cycle, so the
    120000-line train split should yield exactly 60000 items through each path.
    """
    train_iter, test_iter = torchtext.datasets.AG_NEWS()
    for_count = 0
    next_count = 0
    for line in train_iter:
        for_count += 1
        try:
            next(train_iter)
            next_count += 1
        # Only StopIteration signals exhaustion here; the original bare
        # `except:` also swallowed KeyboardInterrupt/SystemExit and any
        # unexpected error, hiding real failures.
        except StopIteration:
            break
    self.assertEqual((for_count, next_count), (60000, 60000))
34 changes: 21 additions & 13 deletions torchtext/datasets/ag_news.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from torchtext.utils import (
download_from_url,
)
from torchtext._internal.module_utils import is_module_available
from typing import Union, Tuple

if is_module_available("torchdata"):
from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper

from torchtext.data.datasets_utils import (
_RawTextIterableDataset,
_wrap_split_argument,
_add_docstring_header,
_create_dataset_directory,
_create_data_from_csv,
)
import os

Expand All @@ -30,11 +31,18 @@

@_add_docstring_header(num_lines=NUM_LINES, num_classes=4)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'test'))
def AG_NEWS(root, split):
    """Download the AG_NEWS split CSV (if not cached) and return it as a raw-text iterable dataset."""
    # Hoist the cache path so the download call reads cleanly.
    csv_path = os.path.join(root, split + ".csv")
    downloaded_path = download_from_url(
        URL[split],
        root=root,
        path=csv_path,
        hash_value=MD5[split],
        hash_type='md5',
    )
    reader = _create_data_from_csv(downloaded_path)
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], reader)
@_wrap_split_argument(("train", "test"))
def AG_NEWS(root: str, split: Union[Tuple[str], str]):
    """AG_NEWS dataset built on torchdata datapipes.

    Args:
        root: Directory where the split CSV is cached.
        split: Split name(s), 'train' or 'test' (normalized by ``_wrap_split_argument``).

    Returns:
        A datapipe yielding ``(label, text)`` tuples: the first CSV column as an
        int label, the remaining columns joined with spaces into one string.

    Raises:
        ModuleNotFoundError: If ``torchdata`` is not installed.
    """
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

    # Compute the cache target once — it serves both as the on-disk location
    # and as the key for hash verification (previously built twice inline).
    cache_file = os.path.join(root, split + ".csv")
    url_dp = IterableWrapper([URL[split]])
    cache_dp = url_dp.on_disk_cache(
        filepath_fn=lambda x: cache_file,
        hash_dict={cache_file: MD5[split]},
        hash_type="md5",
    )
    # Download-and-cache combined into one chained step, matching the
    # AmazonReviewFull/AmazonReviewPolarity implementations for consistency
    # (as requested in review).
    cache_dp = HttpReader(cache_dp).end_caching(mode="w", same_filepath_fn=True)
    cache_dp = FileOpener(cache_dp, mode="r")
    return cache_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))