@@ -164,7 +164,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
164164 Examples:
165165 >>> from torchtext.datasets import IWSLT2016
166166 >>> train_iter, valid_iter, test_iter = IWSLT2016()
167- >>> src_sentence, tgt_sentence = next(train_iter)
167+ >>> src_sentence, tgt_sentence = next(iter( train_iter) )
168168
169169 """
170170 if not is_module_available ("torchdata" ):
@@ -204,6 +204,17 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
204204 src_eval , tgt_eval = valid_filenames
205205 src_test , tgt_test = test_filenames
206206
207+ uncleaned_train_filenames = ('train.tags.{}-{}.{}' .format (src_language , tgt_language , src_language ),
208+ 'train.tags.{}-{}.{}' .format (src_language , tgt_language , tgt_language ))
209+ uncleaed_valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml' .format (SUPPORTED_DATASETS ['year' ], valid_set , src_language , tgt_language , src_language ),
210+ 'IWSLT{}.TED.{}.{}-{}.{}.xml' .format (SUPPORTED_DATASETS ['year' ], valid_set , src_language , tgt_language , tgt_language ))
211+ uncleaned_test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml' .format (SUPPORTED_DATASETS ['year' ], test_set , src_language , tgt_language , src_language ),
212+ 'IWSLT{}.TED.{}.{}-{}.{}.xml' .format (SUPPORTED_DATASETS ['year' ], test_set , src_language , tgt_language , tgt_language ))
213+
214+ uncleaned_src_train , uncleaned_tgt_train = uncleaned_train_filenames
215+ uncleaned_src_eval , uncleaned_tgt_eval = uncleaed_valid_filenames
216+ uncleaned_src_test , uncleaned_tgt_test = uncleaned_test_filenames
217+
207218 url_dp = IterableWrapper ([URL ])
208219 cache_compressed_dp = url_dp .on_disk_cache (
209220 filepath_fn = lambda x : os .path .join (root , _PATH ),
@@ -215,14 +226,13 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
215226
216227 languages = "-" .join ([src_language , tgt_language ])
217228
218- inner_iwslt_tar = os .path .join (
219- root , os .path .splitext (_PATH )[0 ], "texts" , src_language , tgt_language , languages
220- ) + ".tgz"
229+ # We create the whole filepath here, but the filter below matches only on the literal
230+ # filename, because we're lazily extracting from the outer tarfile: the on-disk path
231+ # /root/2016-01/texts/.../src-tgt.tgz is never a substring of the archive-entry path /root/2016-01.tgz/texts/.../src-tgt.tgz
232+ inner_iwslt_tar = os .path .join (root , os .path .splitext (_PATH )[0 ], "texts" , src_language , tgt_language , languages ) + ".tgz"
221233
222- cache_decompressed_dp = cache_compressed_dp .on_disk_cache (
223- filepath_fn = lambda x : inner_iwslt_tar
224- )
225- cache_decompressed_dp = FileOpener (cache_decompressed_dp , mode = "b" ).read_from_tar ().filter (lambda x : inner_iwslt_tar in x [0 ])
234+ cache_decompressed_dp = cache_compressed_dp .on_disk_cache (filepath_fn = lambda x : inner_iwslt_tar )
235+ cache_decompressed_dp = FileOpener (cache_decompressed_dp , mode = "b" ).read_from_tar ().filter (lambda x : os .path .basename (inner_iwslt_tar ) in x [0 ])
226236 cache_decompressed_dp = cache_decompressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
227237
228238 file_path_by_lang_and_split = {
@@ -238,28 +248,49 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
238248 }
239249 }
240250
251+ uncleaned_filenames = {
252+ src_language : {
253+ "train" : uncleaned_src_train ,
254+ "valid" : uncleaned_src_eval ,
255+ "test" : uncleaned_src_test ,
256+ },
257+ tgt_language : {
258+ "train" : uncleaned_tgt_train ,
259+ "valid" : uncleaned_tgt_eval ,
260+ "test" : uncleaned_tgt_test ,
261+ }
262+ }
263+
241264 src_filename = file_path_by_lang_and_split [src_language ][split ]
265+ uncleaned_src_filename = uncleaned_filenames [src_language ][split ]
266+
267+ # We create the whole filepath here, but only check for the literal filename in the filter
268+ # because we're lazily extracting from the outer tarfile.
242269 full_src_filepath = os .path .join (root , "2016-01/texts/" , src_language , tgt_language , languages , src_filename )
243270
244271 cache_inner_src_decompressed_dp = cache_decompressed_dp .on_disk_cache (filepath_fn = lambda x : full_src_filepath )
245272 cache_inner_src_decompressed_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "b" ).read_from_tar ()
246- cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .map (lambda x : _clean_files ( x [ 0 ], os .path .splitext ( os . path . dirname ( os . path . dirname ( x [ 0 ])))[ 0 ], x [ 1 ]) )
247- cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .filter (lambda x : full_src_filepath in x )
273+ cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .filter (lambda x : os .path .basename ( uncleaned_src_filename ) in x [ 0 ] )
274+ cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .map (lambda x : _clean_files ( full_src_filepath , x [ 0 ], x [ 1 ]) )
248275 cache_inner_src_decompressed_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "b" )
249276 cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
250277
251278 tgt_filename = file_path_by_lang_and_split [tgt_language ][split ]
279+ uncleaned_tgt_filename = uncleaned_filenames [tgt_language ][split ]
280+
281+ # We create the whole filepath here, but only check for the literal filename in the filter
282+ # because we're lazily extracting from the outer tarfile.
252283 full_tgt_filepath = os .path .join (root , "2016-01/texts/" , src_language , tgt_language , languages , tgt_filename )
253284
254285 cache_inner_tgt_decompressed_dp = cache_decompressed_dp .on_disk_cache (filepath_fn = lambda x : full_tgt_filepath )
255286 cache_inner_tgt_decompressed_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "b" ).read_from_tar ()
256- cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .map (lambda x : _clean_files ( x [ 0 ], os .path .splitext ( os . path . dirname ( os . path . dirname ( x [ 0 ])))[ 0 ], x [ 1 ]) )
257- cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .filter (lambda x : full_tgt_filepath in x )
287+ cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .filter (lambda x : os .path .basename ( uncleaned_tgt_filename ) in x [ 0 ] )
288+ cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .map (lambda x : _clean_files ( full_tgt_filepath , x [ 0 ], x [ 1 ]) )
258289 cache_inner_tgt_decompressed_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "b" )
259290 cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
260291
261- tgt_data_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "r " )
262- src_data_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "r " )
292+ tgt_data_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "b " )
293+ src_data_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "b " )
263294
264295 src_lines = src_data_dp .readlines (return_path = False , strip_newline = False , decode = True )
265296 tgt_lines = tgt_data_dp .readlines (return_path = False , strip_newline = False , decode = True )
0 commit comments