1818MD5 = 'c393ed3fc2a1b0f004b3331043f615ae'
1919
2020SUPPORTED_DATASETS = {
21-
2221 'valid_test' : ['dev2010' , 'tst2010' , 'tst2011' , 'tst2012' , 'tst2013' , 'tst2014' ],
2322 'language_pair' : {
2423 'en' : ['ar' , 'de' , 'fr' , 'cs' ],
2827 'cs' : ['en' ],
2928 },
3029 'year' : 16 ,
31-
3230}
3331
3432NUM_LINES = {
127125 ('cs' , 'en' ): ['tst2014' ]
128126}
129127
130-
131- def _construct_filenames (filename , languages ):
132- filenames = []
133- for lang in languages :
134- filenames .append (filename + "." + lang )
135- return filenames
136-
137-
138- def _construct_filepath (path , src_filename , tgt_filename ):
139- src_path = None
140- tgt_path = None
141- src_path = path if src_filename in path else src_path
142- tgt_path = path if tgt_filename in path else tgt_path
143- return src_path , tgt_path
144-
145-
146- def _construct_filepaths (paths , src_filename , tgt_filename ):
147- src_path = None
148- tgt_path = None
149- for p in paths :
150- src_path , tgt_path = _construct_filepath (p , src_filename , tgt_filename )
151- return src_path , tgt_path
152-
153-
154128DATASET_NAME = "IWSLT2016"
155129
156130
@@ -247,6 +221,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
247221 "texts" , src_language , tgt_language , languages
248222 )
249223 cache_decompressed_dp = cache_compressed_dp .on_disk_cache (
224+ # Convert /path/to/downloaded/foo.tgz to /path/to/downloaded/foo/rest/of/path
250225 filepath_fn = lambda x : os .path .join (os .path .splitext (x [0 ])[0 ], iwslt_tar )
251226 )
252227 cache_decompressed_dp = cache_decompressed_dp .read_from_tar ()
@@ -263,27 +238,28 @@ def clean_files(fname):
263238
264239 cache_decompressed_dp = cache_decompressed_dp .flatmap (FileLister )
265240
266- def get_filepath (f ):
267- src_file , tgt_file = {
268- "train" : _construct_filepath (f , src_train , tgt_train ),
269- "valid" : _construct_filepath (f , src_eval , tgt_eval ),
270- "test" : _construct_filepath (f , src_test , tgt_test )
271- }[split ]
272-
273- return src_file , tgt_file
274-
275- cleaned_cache_decompressed_dp = cache_decompressed_dp .map (clean_files ).map (get_filepath )
276-
277- # pairs of filenames are either both None or one of src/tgt is None.
278- # filter out both None since they're not relevant
279- cleaned_cache_decompressed_dp = cleaned_cache_decompressed_dp .filter (lambda x : x != (None , None ))
280-
281- # (None, tgt) => 1, (src, None) => 0
282- tgt_data_dp , src_data_dp = cleaned_cache_decompressed_dp .demux (2 , lambda x : x .index (None ))
283-
284- # Pull out the non-None element (i.e., filename) from the tuple
285- tgt_data_dp = FileOpener (tgt_data_dp .map (lambda x : x [1 ]), mode = "r" )
286- src_data_dp = FileOpener (src_data_dp .map (lambda x : x [0 ]), mode = "r" )
241+ def get_filepath (split , lang ):
242+ return {
243+ src_language : {
244+ "train" : src_train ,
245+ "valid" : src_eval ,
246+ "test" : src_test ,
247+ },
248+ tgt_language : {
249+ "train" : tgt_train ,
250+ "valid" : tgt_eval ,
251+ "test" : tgt_test ,
252+ }
253+ }[lang ][split ]
254+
255+ cleaned_cache_decompressed_dp = cache_decompressed_dp .map (clean_files )
256+
257+ # Filters out irrelevant file given the filename templates filled with split and src/tgt codes
258+ src_data_dp = cleaned_cache_decompressed_dp .filter (lambda x : get_filepath (split , src_language ) in x )
259+ tgt_data_dp = cleaned_cache_decompressed_dp .filter (lambda x : get_filepath (split , tgt_language ) in x )
260+
261+ tgt_data_dp = FileOpener (tgt_data_dp , mode = "r" )
262+ src_data_dp = FileOpener (src_data_dp , mode = "r" )
287263
288264 src_lines = src_data_dp .readlines (return_path = False , strip_newline = False )
289265 tgt_lines = tgt_data_dp .readlines (return_path = False , strip_newline = False )
0 commit comments