Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 223584b

Browse files
authored
add CC100 mocking test (#1583)
1 parent d8f9559 commit 223584b

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

test/datasets/test_cc100.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import os
2+
import random
3+
import string
4+
import lzma
5+
from parameterized import parameterized
6+
from collections import defaultdict
7+
from unittest.mock import patch
8+
9+
from torchtext.datasets import CC100
10+
11+
from ..common.case_utils import TempDirMixin, zip_equal
12+
from ..common.torchtext_test_case import TorchtextTestCase
13+
14+
from torchtext.datasets.cc100 import VALID_CODES
15+
16+
17+
def _get_mock_dataset(root_dir):
18+
"""
19+
root_dir: directory to the mocked dataset
20+
"""
21+
base_dir = os.path.join(root_dir, "CC100")
22+
os.makedirs(base_dir, exist_ok=True)
23+
24+
seed = 1
25+
mocked_data = defaultdict(list)
26+
27+
for language_code in VALID_CODES:
28+
file_name = f"{language_code}.txt.xz"
29+
compressed_file = os.path.join(base_dir, file_name)
30+
with lzma.open(compressed_file, "wt") as f:
31+
for i in range(5):
32+
rand_string = " ".join(
33+
random.choice(string.ascii_letters) for i in range(seed)
34+
)
35+
content = f"{rand_string}\n"
36+
f.write(content)
37+
mocked_data[language_code].append((language_code, rand_string))
38+
seed += 1
39+
40+
return mocked_data
41+
42+
43+
class TestCC100(TempDirMixin, TorchtextTestCase):
44+
@classmethod
45+
def setUpClass(cls):
46+
super().setUpClass()
47+
cls.root_dir = cls.get_base_temp_dir()
48+
cls.samples = _get_mock_dataset(cls.root_dir)
49+
cls.patcher = patch(
50+
"torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
51+
)
52+
cls.patcher.start()
53+
54+
@classmethod
55+
def tearDownClass(cls):
56+
cls.patcher.stop()
57+
super().tearDownClass()
58+
59+
@parameterized.expand(VALID_CODES)
60+
def test_cc100(self, language_code):
61+
dataset = CC100(root=self.root_dir, split="train", language_code=language_code)
62+
63+
samples = list(dataset)
64+
expected_samples = self.samples[language_code]
65+
for sample, expected_sample in zip_equal(samples, expected_samples):
66+
self.assertEqual(sample, expected_sample)

0 commit comments

Comments
 (0)