@@ -164,7 +164,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
164164 Examples:
165165 >>> from torchtext.datasets import IWSLT2016
166166 >>> train_iter, valid_iter, test_iter = IWSLT2016()
167- >>> src_sentence, tgt_sentence = next(train_iter)
167+ >>> src_sentence, tgt_sentence = next(iter( train_iter) )
168168
169169 """
170170 if not is_module_available ("torchdata" ):
@@ -204,6 +204,17 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
204204 src_eval , tgt_eval = valid_filenames
205205 src_test , tgt_test = test_filenames
206206
207+ uncleaned_train_filenames = ('train.tags.{}-{}.{}' .format (src_language , tgt_language , src_language ),
208+ 'train.tags.{}-{}.{}' .format (src_language , tgt_language , tgt_language ))
209+ uncleaed_valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml' .format (SUPPORTED_DATASETS ['year' ], valid_set , src_language , tgt_language , src_language ),
210+ 'IWSLT{}.TED.{}.{}-{}.{}.xml' .format (SUPPORTED_DATASETS ['year' ], valid_set , src_language , tgt_language , tgt_language ))
211+ uncleaned_test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml' .format (SUPPORTED_DATASETS ['year' ], test_set , src_language , tgt_language , src_language ),
212+ 'IWSLT{}.TED.{}.{}-{}.{}.xml' .format (SUPPORTED_DATASETS ['year' ], test_set , src_language , tgt_language , tgt_language ))
213+
214+ uncleaned_src_train , uncleaned_tgt_train = uncleaned_train_filenames
215+ uncleaned_src_eval , uncleaned_tgt_eval = uncleaed_valid_filenames
216+ uncleaned_src_test , uncleaned_tgt_test = uncleaned_test_filenames
217+
207218 url_dp = IterableWrapper ([URL ])
208219 cache_compressed_dp = url_dp .on_disk_cache (
209220 filepath_fn = lambda x : os .path .join (root , _PATH ),
@@ -215,14 +226,12 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
215226
216227 languages = "-" .join ([src_language , tgt_language ])
217228
218- inner_iwslt_tar = os . path . join (
219- root , os . path . splitext ( _PATH )[ 0 ], " texts" , src_language , tgt_language , languages
220- ) + ".tgz"
229+ # Intentionally don't use root here because we're inspecting contents of a not-fully
230+ # extracted tgz, so / root/.../ texts/... will not match /root/.../2016-01.tgz/texts/...
231+ inner_iwslt_tar = os . path . join ( root , os . path . splitext ( _PATH )[ 0 ], "texts" , src_language , tgt_language , languages ) + ".tgz"
221232
222- cache_decompressed_dp = cache_compressed_dp .on_disk_cache (
223- filepath_fn = lambda x : inner_iwslt_tar
224- )
225- cache_decompressed_dp = FileOpener (cache_decompressed_dp , mode = "b" ).read_from_tar ().filter (lambda x : inner_iwslt_tar in x [0 ])
233+ cache_decompressed_dp = cache_compressed_dp .on_disk_cache (filepath_fn = lambda x : inner_iwslt_tar )
234+ cache_decompressed_dp = FileOpener (cache_decompressed_dp , mode = "b" ).read_from_tar ().filter (lambda x : os .path .basename (inner_iwslt_tar ) in x [0 ])
226235 cache_decompressed_dp = cache_decompressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
227236
228237 file_path_by_lang_and_split = {
@@ -238,28 +247,45 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
238247 }
239248 }
240249
250+ uncleaned_filenames = {
251+ src_language : {
252+ "train" : uncleaned_src_train ,
253+ "valid" : uncleaned_src_eval ,
254+ "test" : uncleaned_src_test ,
255+ },
256+ tgt_language : {
257+ "train" : uncleaned_tgt_train ,
258+ "valid" : uncleaned_tgt_eval ,
259+ "test" : uncleaned_tgt_test ,
260+ }
261+ }
262+
241263 src_filename = file_path_by_lang_and_split [src_language ][split ]
264+ uncleaned_src_filename = uncleaned_filenames [src_language ][split ]
265+
242266 full_src_filepath = os .path .join (root , "2016-01/texts/" , src_language , tgt_language , languages , src_filename )
243267
244268 cache_inner_src_decompressed_dp = cache_decompressed_dp .on_disk_cache (filepath_fn = lambda x : full_src_filepath )
245269 cache_inner_src_decompressed_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "b" ).read_from_tar ()
246- cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .map (lambda x : _clean_files ( x [ 0 ], os .path .splitext ( os . path . dirname ( os . path . dirname ( x [ 0 ])))[ 0 ], x [ 1 ]) )
247- cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .filter (lambda x : full_src_filepath in x )
270+ cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .filter (lambda x : os .path .basename ( uncleaned_src_filename ) in x [ 0 ] )
271+ cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .map (lambda x : _clean_files ( full_src_filepath , x [ 0 ], x [ 1 ]) )
248272 cache_inner_src_decompressed_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "b" )
249273 cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
250274
251275 tgt_filename = file_path_by_lang_and_split [tgt_language ][split ]
276+ uncleaned_tgt_filename = uncleaned_filenames [tgt_language ][split ]
277+
252278 full_tgt_filepath = os .path .join (root , "2016-01/texts/" , src_language , tgt_language , languages , tgt_filename )
253279
254280 cache_inner_tgt_decompressed_dp = cache_decompressed_dp .on_disk_cache (filepath_fn = lambda x : full_tgt_filepath )
255281 cache_inner_tgt_decompressed_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "b" ).read_from_tar ()
256- cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .map (lambda x : _clean_files ( x [ 0 ], os .path .splitext ( os . path . dirname ( os . path . dirname ( x [ 0 ])))[ 0 ], x [ 1 ]) )
257- cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .filter (lambda x : full_tgt_filepath in x )
282+ cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .filter (lambda x : os .path .basename ( uncleaned_tgt_filename ) in x [ 0 ] )
283+ cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .map (lambda x : _clean_files ( full_tgt_filepath , x [ 0 ], x [ 1 ]) )
258284 cache_inner_tgt_decompressed_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "b" )
259285 cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
260286
261- tgt_data_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "r " )
262- src_data_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "r " )
287+ tgt_data_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "b " )
288+ src_data_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "b " )
263289
264290 src_lines = src_data_dp .readlines (return_path = False , strip_newline = False , decode = True )
265291 tgt_lines = tgt_data_dp .readlines (return_path = False , strip_newline = False , decode = True )
0 commit comments