Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 15c4222

Browse files
authored
mock up AG NEWS test for faster testing. (#1553)
1 parent 9561cde commit 15c4222

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

test/datasets/test_agnews.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
import random
3+
import string
4+
from collections import defaultdict
5+
from unittest.mock import patch
6+
7+
from parameterized import parameterized
8+
from torchtext.datasets.ag_news import AG_NEWS
9+
10+
from ..common.case_utils import TempDirMixin, zip_equal
11+
from ..common.torchtext_test_case import TorchtextTestCase
12+
13+
14+
def _get_mock_dataset(root_dir):
    """Write a small mocked AG_NEWS dataset to disk and return its samples.

    Creates ``<root_dir>/AG_NEWS/train.csv`` and ``<root_dir>/AG_NEWS/test.csv``.
    Each file holds five rows of the form ``"<label>","<text>","<text>"``,
    where ``<text>`` is a run of space-separated random ASCII letters.

    Args:
        root_dir: directory under which the ``AG_NEWS`` folder is created.

    Returns:
        A ``defaultdict(list)`` mapping the split name (``"train"`` or
        ``"test"``) to the list of ``(label, "<text> <text>")`` tuples that
        were written for that split, in file order.
    """
    temp_dataset_dir = os.path.join(root_dir, "AG_NEWS")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    # Running counter shared by both files: it selects the label (cycling
    # through 1..4) and the number of random letters in each row's text.
    seed = 1
    mocked_data = defaultdict(list)
    for file_name in ("train.csv", "test.csv"):
        # "train.csv" -> "train"; computed once per file, not once per row.
        split = os.path.splitext(file_name)[0]
        txt_file = os.path.join(temp_dataset_dir, file_name)
        # Explicit encoding keeps the fixture independent of the host locale.
        with open(txt_file, "w", encoding="utf-8") as f:
            for _ in range(5):
                label = seed % 4 + 1
                rand_string = " ".join(
                    random.choice(string.ascii_letters) for _ in range(seed)
                )
                # Record the sample under its split: the label plus both
                # text columns joined by a single space.
                mocked_data[split].append((label, f"{rand_string} {rand_string}"))
                f.write(f'"{label}","{rand_string}","{rand_string}"\n')
                seed += 1

    return mocked_data
38+
39+
40+
class TestAGNews(TempDirMixin, TorchtextTestCase):
    """Tests for ``AG_NEWS`` run against a small mocked dataset on disk."""

    # Populated once per class in setUpClass.
    root_dir = None  # base temp directory passed to AG_NEWS as ``root``
    samples = []  # replaced by the split -> expected-samples mapping
    patcher = None  # active mock.patch handle, stopped in tearDownClass

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        # Make torchdata's _hash_check always return True for the lifetime of
        # this class.
        # NOTE(review): presumably needed so the mocked files are accepted in
        # place of the real download -- confirm against torchdata.
        cls.patcher = patch(
            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
        )
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        # Undo the patch before the base class tears down its own state.
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "test"])
    def test_agnews(self, split):
        # Every sample yielded for the split must equal the mocked sample
        # written for that split, in order.
        dataset = AG_NEWS(root=self.root_dir, split=split)

        samples = list(dataset)
        expected_samples = self.samples[split]
        for sample, expected_sample in zip_equal(samples, expected_samples):
            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "test"])
    def test_agnews_split_argument(self, split):
        # Passing the split as a string or as a one-element tuple must yield
        # the same samples.
        dataset1 = AG_NEWS(root=self.root_dir, split=split)
        (dataset2,) = AG_NEWS(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)

0 commit comments

Comments
 (0)