Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3ba62ca

Browse files
authored
mock up IWSLT2016 test for faster testing. (#1563)
* mock up IWSLT2016 test for faster testing. * rename variable for consistency.
1 parent 448a791 commit 3ba62ca

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

test/datasets/test_iwslt2016.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
import random
3+
import string
4+
from collections import defaultdict
5+
from unittest.mock import patch
6+
7+
from parameterized import parameterized
8+
from torchtext.datasets.iwslt2016 import IWSLT2016
9+
from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split
10+
11+
from ..common.case_utils import TempDirMixin, zip_equal
12+
from ..common.torchtext_test_case import TorchtextTestCase
13+
14+
15+
def _get_mock_dataset(root_dir, split, src, tgt):
16+
"""
17+
root_dir: directory to the mocked dataset
18+
"""
19+
temp_dataset_dir = os.path.join(root_dir, f"IWSLT2016/2016-01/texts/{src}/{tgt}/{src}-{tgt}/")
20+
os.makedirs(temp_dataset_dir, exist_ok=True)
21+
22+
seed = 1
23+
mocked_data = defaultdict(lambda: defaultdict(list))
24+
valid_set = "tst2013"
25+
test_set = "tst2014"
26+
27+
files_for_split, _ = _generate_iwslt_files_for_lang_and_split(16, src, tgt, valid_set, test_set)
28+
src_file = files_for_split[src][split]
29+
tgt_file = files_for_split[tgt][split]
30+
for file_name in (src_file, tgt_file):
31+
txt_file = os.path.join(temp_dataset_dir, file_name)
32+
with open(txt_file, "w") as f:
33+
# Get file extension (i.e., the language) without the . prefix (.en -> en)
34+
lang = os.path.splitext(file_name)[1][1:]
35+
for i in range(5):
36+
rand_string = " ".join(
37+
random.choice(string.ascii_letters) for i in range(seed)
38+
)
39+
dataset_line = f"{rand_string} {rand_string}\n"
40+
# append line to correct dataset split
41+
mocked_data[split][lang].append(dataset_line)
42+
f.write(f'{rand_string} {rand_string}\n')
43+
seed += 1
44+
45+
return list(zip(mocked_data[split][src], mocked_data[split][tgt]))
46+
47+
48+
class TestIWSLT2016(TempDirMixin, TorchtextTestCase):
49+
root_dir = None
50+
patcher = None
51+
52+
@classmethod
53+
def setUpClass(cls):
54+
super().setUpClass()
55+
cls.root_dir = cls.get_base_temp_dir()
56+
cls.patcher = patch(
57+
"torchdata.datapipes.iter.util.cacheholder.OnDiskCacheHolderIterDataPipe._cache_check_fn", return_value=True
58+
)
59+
cls.patcher.start()
60+
61+
@classmethod
62+
def tearDownClass(cls):
63+
cls.patcher.stop()
64+
super().tearDownClass()
65+
66+
@parameterized.expand([("train", "de", "en"), ("valid", "de", "en")])
67+
def test_iwslt2016(self, split, src, tgt):
68+
expected_samples = _get_mock_dataset(self.root_dir, split, src, tgt)
69+
70+
dataset = IWSLT2016(root=self.root_dir, split=split)
71+
72+
samples = list(dataset)
73+
74+
for sample, expected_sample in zip_equal(samples, expected_samples):
75+
self.assertEqual(sample, expected_sample)
76+
77+
@parameterized.expand(["train", "valid"])
78+
def test_iwslt2016_split_argument(self, split):
79+
dataset1 = IWSLT2016(root=self.root_dir, split=split)
80+
(dataset2,) = IWSLT2016(root=self.root_dir, split=(split,))
81+
82+
for d1, d2 in zip_equal(dataset1, dataset2):
83+
self.assertEqual(d1, d2)

0 commit comments

Comments
 (0)