66import os
77from torchtext .data .datasets_utils import (
88 _wrap_split_argument ,
9- _clean_xml_file ,
10- _clean_tags_file ,
9+ _clean_inner_xml_file ,
10+ _clean_inner_tags_file ,
11+ _create_dataset_directory ,
12+ _rewrite_text_file ,
1113)
12- from torchtext .data .datasets_utils import _create_dataset_directory
1314
1415URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8'
1516
@@ -211,32 +212,27 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
211212 hash_dict = {os .path .join (root , _PATH ): MD5 },
212213 hash_type = "md5"
213214 )
214- cache_compressed_dp = GDriveReader (cache_compressed_dp ).end_caching (mode = "wb" , same_filepath_fn = True )
215- cache_compressed_dp = FileOpener (cache_compressed_dp , mode = "b" )
216- src_language = train_filenames [0 ].split ("." )[- 1 ]
217- tgt_language = train_filenames [1 ].split ("." )[- 1 ]
215+ cache_compressed_dp = GDriveReader (cache_compressed_dp )
216+ cache_compressed_dp = cache_compressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
217+
218218 languages = "-" .join ([src_language , tgt_language ])
219219
220220 iwslt_tar = os .path .join (
221221 "texts" , src_language , tgt_language , languages
222- )
222+ ) + ".tgz"
223+
223224 cache_decompressed_dp = cache_compressed_dp .on_disk_cache (
224- # Convert /path/to/downloaded/foo.tgz to /path/to/downloaded/foo/rest/of/path
225- filepath_fn = lambda x : os .path .join (os .path .splitext (x [0 ])[0 ], iwslt_tar )
225+ filepath_fn = lambda x : os .path .join (root , os .path .splitext (_PATH )[0 ], iwslt_tar )
226226 )
227- cache_decompressed_dp = cache_decompressed_dp .read_from_tar ()
228- cache_decompressed_dp = cache_decompressed_dp .end_caching (mode = "wb" )
227+ cache_decompressed_dp = FileOpener ( cache_decompressed_dp , mode = "b" ) .read_from_tar (). filter ( lambda x : iwslt_tar in x [ 0 ] )
228+ cache_decompressed_dp = cache_decompressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
229229
230- def clean_files (fname ):
230+ def clean_files (fname , base , stream ):
231231 if 'xml' in fname :
232- _clean_xml_file (fname )
233- return os .path .splitext (fname )[0 ]
232+ return _clean_inner_xml_file (fname , base , stream )
234233 elif "tags" in fname :
235- _clean_tags_file (fname )
236- return fname .replace ('.tags' , '' )
237- return fname
238-
239- cache_decompressed_dp = cache_decompressed_dp .flatmap (FileLister )
234+ return _clean_inner_tags_file (fname , base , stream )
235+ return _rewrite_text_file (fname , base , stream )
240236
241237 def get_filepath (split , lang ):
242238 return {
@@ -252,15 +248,28 @@ def get_filepath(split, lang):
252248 }
253249 }[lang ][split ]
254250
255- cleaned_cache_decompressed_dp = cache_decompressed_dp .map (clean_files )
251+ cache_inner_src_decompressed_dp = cache_decompressed_dp .on_disk_cache (
252+ filepath_fn = lambda x : os .path .join (root , "2016-01/texts/" , src_language , tgt_language , languages , get_filepath (split , src_language ))
253+ )
254+ cache_inner_src_decompressed_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "b" ).read_from_tar ()
255+ cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .map (lambda x : clean_files (x [0 ], os .path .splitext (os .path .dirname (os .path .dirname (x [0 ])))[0 ], x [1 ]))
256+ cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .filter (lambda x : get_filepath (split , src_language ) in x )
257+ cache_inner_src_decompressed_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "b" )
258+ cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
259+
260+ cache_inner_tgt_decompressed_dp = cache_decompressed_dp .on_disk_cache (
261+ filepath_fn = lambda x : os .path .join (root , "2016-01/texts/" , src_language , tgt_language , languages , get_filepath (split , tgt_language ))
262+ )
263+ cache_inner_tgt_decompressed_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "b" ).read_from_tar ()
264+ cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .map (lambda x : clean_files (x [0 ], os .path .splitext (os .path .dirname (os .path .dirname (x [0 ])))[0 ], x [1 ]))
265+ cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .filter (lambda x : get_filepath (split , tgt_language ) in x )
266+ cache_inner_tgt_decompressed_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "b" )
267+ cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp .end_caching (mode = "wb" , same_filepath_fn = True )
256268
257- # Filters out irrelevant files given the filename templates filled with split and src/tgt codes
258- src_data_dp = cleaned_cache_decompressed_dp .filter (lambda x : get_filepath (split , src_language ) in x )
259- tgt_data_dp = cleaned_cache_decompressed_dp .filter (lambda x : get_filepath (split , tgt_language ) in x )
269+ tgt_data_dp = FileOpener (cache_inner_tgt_decompressed_dp , mode = "r" )
270+ src_data_dp = FileOpener (cache_inner_src_decompressed_dp , mode = "r" )
260271
261- tgt_data_dp = FileOpener ( tgt_data_dp , mode = "r" )
262- src_data_dp = FileOpener ( src_data_dp , mode = "r" )
272+ src_lines = src_data_dp . readlines ( return_path = False , strip_newline = False , decode = True )
273+ tgt_lines = tgt_data_dp . readlines ( return_path = False , strip_newline = False , decode = True )
263274
264- src_lines = src_data_dp .readlines (return_path = False , strip_newline = False )
265- tgt_lines = tgt_data_dp .readlines (return_path = False , strip_newline = False )
266275 return src_lines .zip (tgt_lines )
0 commit comments