Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 14fef0f

Browse files
authored
Add YelpReviewFull Mocked Unit Test (#1568)
* add test_yelpreviewfull.py to mock YelpReviewFull
1 parent 339804f commit 14fef0f

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
import random
3+
import string
4+
import tarfile
5+
from collections import defaultdict
6+
from unittest.mock import patch
7+
8+
from parameterized import parameterized
9+
from torchtext.datasets.yelpreviewfull import YelpReviewFull
10+
11+
from ..common.case_utils import TempDirMixin, zip_equal
12+
from ..common.torchtext_test_case import TorchtextTestCase
13+
14+
15+
def _get_mock_dataset(root_dir):
16+
"""
17+
root_dir: directory to the mocked dataset
18+
"""
19+
base_dir = os.path.join(root_dir, "YelpReviewFull")
20+
temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
21+
os.makedirs(temp_dataset_dir, exist_ok=True)
22+
23+
seed = 1
24+
mocked_data = defaultdict(list)
25+
for file_name in ("train.csv", "test.csv"):
26+
csv_file = os.path.join(temp_dataset_dir, file_name)
27+
mocked_lines = mocked_data[os.path.splitext(file_name)[0]]
28+
with open(csv_file, "w") as f:
29+
for i in range(5):
30+
label = seed % 5 + 1
31+
rand_string = " ".join(
32+
random.choice(string.ascii_letters) for i in range(seed)
33+
)
34+
dataset_line = (label, f"{rand_string}")
35+
f.write(f'"{label}","{rand_string}"\n')
36+
37+
# append line to correct dataset split
38+
mocked_lines.append(dataset_line)
39+
seed += 1
40+
41+
compressed_dataset_path = os.path.join(base_dir, "yelp_review_full_csv.tar.gz")
42+
# create gz file from dataset folder
43+
with tarfile.open(compressed_dataset_path, "w:gz") as tar:
44+
tar.add(temp_dataset_dir, arcname="yelp_review_full_csv")
45+
46+
return mocked_data
47+
48+
49+
class TestYelpReviewFull(TempDirMixin, TorchtextTestCase):
50+
root_dir = None
51+
samples = []
52+
53+
@classmethod
54+
def setUpClass(cls):
55+
super().setUpClass()
56+
cls.root_dir = cls.get_base_temp_dir()
57+
cls.samples = _get_mock_dataset(cls.root_dir)
58+
cls.patcher = patch(
59+
"torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
60+
)
61+
cls.patcher.start()
62+
63+
@classmethod
64+
def tearDownClass(cls):
65+
cls.patcher.stop()
66+
super().tearDownClass()
67+
68+
@parameterized.expand(["train", "test"])
69+
def test_yelpreviewfull(self, split):
70+
dataset = YelpReviewFull(root=self.root_dir, split=split)
71+
72+
samples = list(dataset)
73+
expected_samples = self.samples[split]
74+
for sample, expected_sample in zip_equal(samples, expected_samples):
75+
self.assertEqual(sample, expected_sample)
76+
77+
@parameterized.expand(["train", "test"])
78+
def test_yelpreviewfull_split_argument(self, split):
79+
dataset1 = YelpReviewFull(root=self.root_dir, split=split)
80+
(dataset2,) = YelpReviewFull(root=self.root_dir, split=(split,))
81+
82+
for d1, d2 in zip_equal(dataset1, dataset2):
83+
self.assertEqual(d1, d2)

0 commit comments

Comments
 (0)