import os
import random
+ import shutil
import string
+ import tarfile
+ import itertools
+ import tempfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
- from torchtext.datasets.iwslt2016 import IWSLT2016
+ from torchtext.datasets.iwslt2016 import DATASET_NAME, IWSLT2016, SUPPORTED_DATASETS, SET_NOT_EXISTS
from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split

- from ..common.case_utils import TempDirMixin, zip_equal
+ from ..common.case_utils import zip_equal
from ..common.torchtext_test_case import TorchtextTestCase

-
- def _get_mock_dataset(root_dir, split, src, tgt):
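+ # Build the test matrix from the dataset's own metadata: every supported
+ # (src, tgt) language pair, crossed with every ordered pair of distinct
+ # dev/test sets.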
+ SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v]
+ SUPPORTED_DEVTEST_SPLITS = SUPPORTED_DATASETS["valid_test"]
+ DEV_TEST_SPLITS = [(dev, test) for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2) if dev != test]
+
+
+ def _generate_uncleaned_train():
+     """Generate raw train-file contents with XML tags mixed in, plus the cleaned examples."""
+     file_contents = []
+     examples = []
+     xml_tags = [
+         '<url', '<keywords', '<talkid', '<description', '<reviewer',
+         '<translator', '<title', '<speaker', '<doc', '</doc'
+     ]
+     for i in range(100):
+         rand_string = " ".join(
+             random.choice(string.ascii_letters) for i in range(10)
+         )
+         # With a 10% chance, add one of the XML tags that gets cleaned,
+         # to ensure cleaning happens appropriately
+         if random.random() < 0.1:
+             open_tag = random.choice(xml_tags) + ">"
+             close_tag = "</" + open_tag[1:] + ">"
+             file_contents.append(open_tag + rand_string + close_tag)
+         else:
+             examples.append(rand_string + "\n")
+             file_contents.append(rand_string)
+     return examples, "\n".join(file_contents)
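+ # A tagged line such as "<url>a b c</url>" ends up only in the raw file
+ # contents, not in the returned examples, since the dataset's cleaning step
+ # is expected to drop tagged lines from the train split.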
+
+
+ def _generate_uncleaned_valid():
+     file_contents = ["<root>"]
+     examples = []
+
+     for doc_id in range(5):
+         file_contents.append(f'<doc docid="{doc_id}" genre="lectures">')
+         for seg_id in range(100):
+             rand_string = " ".join(
+                 random.choice(string.ascii_letters) for i in range(10)
+             )
+             examples.append(rand_string)
+             file_contents.append(f"<seg>{rand_string}</seg>" + "\n")
+         file_contents.append("</doc>")
+     file_contents.append("</root>")
+     return examples, " ".join(file_contents)
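+ # Valid/test files are XML documents; cleaning is expected to extract each
+ # <seg> element's text as one example.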
+
+
+ def _generate_uncleaned_test():
+     return _generate_uncleaned_valid()
+
+
+ def _generate_uncleaned_contents(split):
+     return {
+         "train": _generate_uncleaned_train(),
+         "valid": _generate_uncleaned_valid(),
+         "test": _generate_uncleaned_test(),
+     }[split]
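+ # Note: each call regenerates fresh random contents, so the src and tgt
+ # files of a pair are written with independent data.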
+
+
+ def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
    """
    root_dir: directory to the mocked dataset
    """
-     temp_dataset_dir = os.path.join(root_dir, f"IWSLT2016/2016-01/texts/{src}/{tgt}/{src}-{tgt}/")
-     os.makedirs(temp_dataset_dir, exist_ok=True)

-     seed = 1
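+     # Recreate the layout of the real archive: a staging directory mirroring
+     # 2016-01/texts/<src>/<tgt>/<src>-<tgt>/ that is tarred up further below.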
+     base_dir = os.path.join(root_dir, DATASET_NAME)
+     temp_dataset_dir = os.path.join(base_dir, 'temp_dataset_dir')
+     outer_temp_dataset_dir = os.path.join(temp_dataset_dir, f"texts/{src}/{tgt}/")
+     inner_temp_dataset_dir = os.path.join(outer_temp_dataset_dir, f"{src}-{tgt}")
+
+     os.makedirs(outer_temp_dataset_dir, exist_ok=True)
+     os.makedirs(inner_temp_dataset_dir, exist_ok=True)
+
    mocked_data = defaultdict(lambda: defaultdict(list))
-     valid_set = "tst2013"
-     test_set = "tst2014"
-
-     files_for_split, _ = _generate_iwslt_files_for_lang_and_split(16, src, tgt, valid_set, test_set)
-     src_file = files_for_split[src][split]
-     tgt_file = files_for_split[tgt][split]
-     for file_name in (src_file, tgt_file):
-         txt_file = os.path.join(temp_dataset_dir, file_name)
-         with open(txt_file, "w") as f:
-             # Get file extension (i.e., the language) without the . prefix (.en -> en)
-             lang = os.path.splitext(file_name)[1][1:]
-             for i in range(5):
-                 rand_string = " ".join(
-                     random.choice(string.ascii_letters) for i in range(seed)
-                 )
-                 dataset_line = f"{rand_string} {rand_string}\n"
-                 # append line to correct dataset split
-                 mocked_data[split][lang].append(dataset_line)
-                 f.write(f'{rand_string} {rand_string}\n')
-                 seed += 1
+
+     cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(16, src, tgt, valid_set, test_set)
+     uncleaned_src_file = uncleaned_file_names[src][split]
+     uncleaned_tgt_file = uncleaned_file_names[tgt][split]
+
+     cleaned_src_file = cleaned_file_names[src][split]
+     cleaned_tgt_file = cleaned_file_names[tgt][split]
+
+     for (unclean_file_name, clean_file_name) in [
+         (uncleaned_src_file, cleaned_src_file),
+         (uncleaned_tgt_file, cleaned_tgt_file)
+     ]:
+         # Get file extension (i.e., the language) without the . prefix (.en -> en)
+         lang = os.path.splitext(unclean_file_name)[1][1:]
+
+         out_file = os.path.join(inner_temp_dataset_dir, unclean_file_name)
+         with open(out_file, "w") as f:
+             mocked_data_for_split, file_contents = _generate_uncleaned_contents(split)
+             mocked_data[split][lang] = mocked_data_for_split
+             f.write(file_contents)
+
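+     # The real download is a nested archive: an outer 2016-01.tgz holding an
+     # inner <src>-<tgt>.tgz per language pair, so the mock builds both layers.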
+     inner_compressed_dataset_path = os.path.join(
+         outer_temp_dataset_dir, f"{src}-{tgt}.tgz"
+     )
+
+     # create tar file from dataset folder
+     with tarfile.open(inner_compressed_dataset_path, "w:gz") as tar:
+         tar.add(inner_temp_dataset_dir, arcname=f"{src}-{tgt}")
+
+     # this is necessary so that the outer tarball only includes the inner tarball
+     shutil.rmtree(inner_temp_dataset_dir)
+
+     outer_temp_dataset_path = os.path.join(base_dir, "2016-01.tgz")
+
+     with tarfile.open(outer_temp_dataset_path, "w:gz") as tar:
+         tar.add(temp_dataset_dir, arcname="2016-01")

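+     # Pair source and target lines to form the expected (src, tgt) samples.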
    return list(zip(mocked_data[split][src], mocked_data[split][tgt]))


- class TestIWSLT2016(TempDirMixin, TorchtextTestCase):
+ class TestIWSLT2016(TorchtextTestCase):
    root_dir = None
    patcher = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
-         cls.root_dir = cls.get_base_temp_dir()
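+         # Patch torchdata's hash check to always pass so the mocked archives
+         # are accepted in place of real downloads.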
        cls.patcher = patch(
-             "torchdata.datapipes.iter.util.cacheholder.OnDiskCacheHolderIterDataPipe._cache_check_fn", return_value=True
+             "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
        )
        cls.patcher.start()

@@ -63,21 +144,36 @@ def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

-     @parameterized.expand([("train", "de", "en"), ("valid", "de", "en")])
-     def test_iwslt2016(self, split, src, tgt):
-         expected_samples = _get_mock_dataset(self.root_dir, split, src, tgt)
+     @parameterized.expand([
+         (split, src, tgt, dev_set, test_set)
+         for split in ("train", "valid", "test")
+         for dev_set, test_set in DEV_TEST_SPLITS
+         for src, tgt in SUPPORTED_LANGPAIRS
+         if (dev_set not in SET_NOT_EXISTS[(src, tgt)] and test_set not in SET_NOT_EXISTS[(src, tgt)])
+     ])
+     def test_iwslt2016(self, split, src, tgt, dev_set, test_set):

-         dataset = IWSLT2016(root=self.root_dir, split=split)
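+         # A fresh temporary root per parameterization keeps cached artifacts
+         # from one configuration from leaking into another.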
+         with tempfile.TemporaryDirectory() as root_dir:
+             expected_samples = _get_mock_dataset(root_dir, split, src, tgt, dev_set, test_set)

-         samples = list(dataset)
+             dataset = IWSLT2016(
+                 root=root_dir, split=split, language_pair=(src, tgt), valid_set=dev_set, test_set=test_set
+             )

-         for sample, expected_sample in zip_equal(samples, expected_samples):
-             self.assertEqual(sample, expected_sample)
+             samples = list(dataset)

-     @parameterized.expand(["train", "valid"])
-     def test_iwslt2016_split_argument(self, split):
-         dataset1 = IWSLT2016(root=self.root_dir, split=split)
-         (dataset2,) = IWSLT2016(root=self.root_dir, split=(split,))
+             for sample, expected_sample in zip_equal(samples, expected_samples):
+                 self.assertEqual(sample, expected_sample)

-         for d1, d2 in zip_equal(dataset1, dataset2):
-             self.assertEqual(d1, d2)
+     @parameterized.expand(["train", "valid", "test"])
+     def test_iwslt2016_split_argument(self, split):
+         with tempfile.TemporaryDirectory() as root_dir:
+             language_pair = ("de", "en")
+             valid_set = "tst2013"
+             test_set = "tst2014"
+             _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set)
+             dataset1 = IWSLT2016(root=root_dir, split=split, language_pair=language_pair, valid_set=valid_set, test_set=test_set)
+             (dataset2,) = IWSLT2016(root=root_dir, split=(split,), language_pair=language_pair, valid_set=valid_set, test_set=test_set)
+
+             for d1, d2 in zip_equal(dataset1, dataset2):
+                 self.assertEqual(d1, d2)