@@ -1,37 +1,54 @@
+import itertools
 import os
 import random
 import shutil
 import string
 import tarfile
-import itertools
 import tempfile
 from collections import defaultdict
 from unittest.mock import patch
 
 from parameterized import parameterized
-from torchtext.datasets.iwslt2016 import DATASET_NAME, IWSLT2016, SUPPORTED_DATASETS, SET_NOT_EXISTS
 from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split
+from torchtext.datasets.iwslt2016 import (
+    DATASET_NAME,
+    IWSLT2016,
+    SUPPORTED_DATASETS,
+    SET_NOT_EXISTS,
+)
 
 from ..common.case_utils import zip_equal
 from ..common.torchtext_test_case import TorchtextTestCase
 
-SUPPORTED_LANGPAIRS = [(k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v]
+SUPPORTED_LANGPAIRS = [
+    (k, e) for k, v in SUPPORTED_DATASETS["language_pair"].items() for e in v
+]
 SUPPORTED_DEVTEST_SPLITS = SUPPORTED_DATASETS["valid_test"]
-DEV_TEST_SPLITS = [(dev, test) for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2) if dev != test]
+DEV_TEST_SPLITS = [
+    (dev, test)
+    for dev, test in itertools.product(SUPPORTED_DEVTEST_SPLITS, repeat=2)
+    if dev != test
+]
 
 
 def _generate_uncleaned_train():
     """Generate tags files"""
     file_contents = []
     examples = []
     xml_tags = [
-        '<url', '<keywords', '<talkid', '<description', '<reviewer',
-        '<translator', '<title', '<speaker', '<doc', '</doc'
+        "<url",
+        "<keywords",
+        "<talkid",
+        "<description",
+        "<reviewer",
+        "<translator",
+        "<title",
+        "<speaker",
+        "<doc",
+        "</doc",
     ]
     for i in range(100):
-        rand_string = " ".join(
-            random.choice(string.ascii_letters) for i in range(10)
-        )
+        rand_string = " ".join(random.choice(string.ascii_letters) for i in range(10))
         # With a 10% chance, add one of the XML tags which is cleaned
         # to ensure cleaning happens appropriately
         if random.random() < 0.1:
@@ -79,7 +96,7 @@ def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
     """
 
     base_dir = os.path.join(root_dir, DATASET_NAME)
-    temp_dataset_dir = os.path.join(base_dir, 'temp_dataset_dir')
+    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
     outer_temp_dataset_dir = os.path.join(temp_dataset_dir, f"texts/{src}/{tgt}/")
     inner_temp_dataset_dir = os.path.join(outer_temp_dataset_dir, f"{src}-{tgt}")
 
@@ -88,7 +105,9 @@ def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
 
     mocked_data = defaultdict(lambda: defaultdict(list))
 
-    cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(16, src, tgt, valid_set, test_set)
+    cleaned_file_names, uncleaned_file_names = _generate_iwslt_files_for_lang_and_split(
+        16, src, tgt, valid_set, test_set
+    )
     uncleaned_src_file = uncleaned_file_names[src][split]
     uncleaned_tgt_file = uncleaned_file_names[tgt][split]
 
@@ -97,7 +116,7 @@ def _get_mock_dataset(root_dir, split, src, tgt, valid_set, test_set):
 
     for (unclean_file_name, clean_file_name) in [
         (uncleaned_src_file, cleaned_src_file),
-        (uncleaned_tgt_file, cleaned_tgt_file)
+        (uncleaned_tgt_file, cleaned_tgt_file),
     ]:
         # Get file extension (i.e., the language) without the . prefix (.en -> en)
         lang = os.path.splitext(unclean_file_name)[1][1:]
@@ -144,20 +163,31 @@ def tearDownClass(cls):
         cls.patcher.stop()
         super().tearDownClass()
 
-    @parameterized.expand([
-        (split, src, tgt, dev_set, test_set)
-        for split in ("train", "valid", "test")
-        for dev_set, test_set in DEV_TEST_SPLITS
-        for src, tgt in SUPPORTED_LANGPAIRS
-        if (dev_set not in SET_NOT_EXISTS[(src, tgt)] and test_set not in SET_NOT_EXISTS[(src, tgt)])
-    ])
+    @parameterized.expand(
+        [
+            (split, src, tgt, dev_set, test_set)
+            for split in ("train", "valid", "test")
+            for dev_set, test_set in DEV_TEST_SPLITS
+            for src, tgt in SUPPORTED_LANGPAIRS
+            if (
+                dev_set not in SET_NOT_EXISTS[(src, tgt)]
+                and test_set not in SET_NOT_EXISTS[(src, tgt)]
+            )
+        ]
+    )
     def test_iwslt2016(self, split, src, tgt, dev_set, test_set):
 
         with tempfile.TemporaryDirectory() as root_dir:
-            expected_samples = _get_mock_dataset(root_dir, split, src, tgt, dev_set, test_set)
+            expected_samples = _get_mock_dataset(
+                root_dir, split, src, tgt, dev_set, test_set
+            )
 
             dataset = IWSLT2016(
-                root=root_dir, split=split, language_pair=(src, tgt), valid_set=dev_set, test_set=test_set
+                root=root_dir,
+                split=split,
+                language_pair=(src, tgt),
+                valid_set=dev_set,
+                test_set=test_set,
             )
 
             samples = list(dataset)
@@ -171,9 +201,23 @@ def test_iwslt2016_split_argument(self, split):
             language_pair = ("de", "en")
             valid_set = "tst2013"
             test_set = "tst2014"
-            _ = _get_mock_dataset(root_dir, split, language_pair[0], language_pair[1], valid_set, test_set)
-            dataset1 = IWSLT2016(root=root_dir, split=split, language_pair=language_pair, valid_set=valid_set, test_set=test_set)
-            (dataset2,) = IWSLT2016(root=root_dir, split=(split,), language_pair=language_pair, valid_set=valid_set, test_set=test_set)
+            _ = _get_mock_dataset(
+                root_dir, split, language_pair[0], language_pair[1], valid_set, test_set
+            )
+            dataset1 = IWSLT2016(
+                root=root_dir,
+                split=split,
+                language_pair=language_pair,
+                valid_set=valid_set,
+                test_set=test_set,
+            )
+            (dataset2,) = IWSLT2016(
+                root=root_dir,
+                split=(split,),
+                language_pair=language_pair,
+                valid_set=valid_set,
+                test_set=test_set,
+            )
 
             for d1, d2 in zip_equal(dataset1, dataset2):
                 self.assertEqual(d1, d2)