Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 5dd29ea

Browse files
committed
simplify logic for IWSLT2016.
1 parent 05ec62f commit 5dd29ea

File tree

1 file changed

+37
-28
lines changed

1 file changed

+37
-28
lines changed

torchtext/datasets/iwslt2016.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
import os
77
from torchtext.data.datasets_utils import (
88
_wrap_split_argument,
9-
_clean_xml_file,
10-
_clean_tags_file,
9+
_clean_inner_xml_file,
10+
_clean_inner_tags_file,
11+
_create_dataset_directory,
12+
_rewrite_text_file,
1113
)
12-
from torchtext.data.datasets_utils import _create_dataset_directory
1314

1415
URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8'
1516

@@ -211,32 +212,27 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
211212
hash_dict={os.path.join(root, _PATH): MD5},
212213
hash_type="md5"
213214
)
214-
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
215-
cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b")
216-
src_language = train_filenames[0].split(".")[-1]
217-
tgt_language = train_filenames[1].split(".")[-1]
215+
cache_compressed_dp = GDriveReader(cache_compressed_dp)
216+
cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True)
217+
218218
languages = "-".join([src_language, tgt_language])
219219

220220
iwslt_tar = os.path.join(
221221
"texts", src_language, tgt_language, languages
222-
)
222+
) + ".tgz"
223+
223224
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
224-
# Convert /path/to/downloaded/foo.tgz to /path/to/downloaded/foo/rest/of/path
225-
filepath_fn=lambda x: os.path.join(os.path.splitext(x[0])[0], iwslt_tar)
225+
filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0], iwslt_tar)
226226
)
227-
cache_decompressed_dp = cache_decompressed_dp.read_from_tar()
228-
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")
227+
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: iwslt_tar in x[0])
228+
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
229229

230-
def clean_files(fname):
230+
def clean_files(fname, base, stream):
231231
if 'xml' in fname:
232-
_clean_xml_file(fname)
233-
return os.path.splitext(fname)[0]
232+
return _clean_inner_xml_file(fname, base, stream)
234233
elif "tags" in fname:
235-
_clean_tags_file(fname)
236-
return fname.replace('.tags', '')
237-
return fname
238-
239-
cache_decompressed_dp = cache_decompressed_dp.flatmap(FileLister)
234+
return _clean_inner_tags_file(fname, base, stream)
235+
return _rewrite_text_file(fname, base, stream)
240236

241237
def get_filepath(split, lang):
242238
return {
@@ -252,15 +248,28 @@ def get_filepath(split, lang):
252248
}
253249
}[lang][split]
254250

255-
cleaned_cache_decompressed_dp = cache_decompressed_dp.map(clean_files)
251+
cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache(
252+
filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, get_filepath(split, src_language))
253+
)
254+
cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b").read_from_tar()
255+
cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1]))
256+
cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: get_filepath(split, src_language) in x)
257+
cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b")
258+
cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
259+
260+
cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache(
261+
filepath_fn=lambda x: os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, get_filepath(split, tgt_language))
262+
)
263+
cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b").read_from_tar()
264+
cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1]))
265+
cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: get_filepath(split, tgt_language) in x)
266+
cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b")
267+
cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
256268

257-
# Filters out irrelevant file given the filename templates filled with split and src/tgt codes
258-
src_data_dp = cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, src_language) in x)
259-
tgt_data_dp = cleaned_cache_decompressed_dp.filter(lambda x: get_filepath(split, tgt_language) in x)
269+
tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r")
270+
src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r")
260271

261-
tgt_data_dp = FileOpener(tgt_data_dp, mode="r")
262-
src_data_dp = FileOpener(src_data_dp, mode="r")
272+
src_lines = src_data_dp.readlines(return_path=False, strip_newline=False, decode=True)
273+
tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False, decode=True)
263274

264-
src_lines = src_data_dp.readlines(return_path=False, strip_newline=False)
265-
tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False)
266275
return src_lines.zip(tgt_lines)

0 commit comments

Comments
 (0)