
Commit 60f0b80

refactors some of the caching logic and cleaners

1 parent b668624

2 files changed (+74 lines, -54 lines)

torchtext/data/datasets_utils.py

Lines changed: 29 additions & 40 deletions
@@ -33,27 +33,23 @@ def _clean_xml_file(f_xml):
                 fd_txt.write(e.text.strip() + '\n')
 
 
-def _clean_inner_xml_file(f_xml, base, stream):
-    """Accepts an XML filename within a tarball and a stream of the byte contents
-    within that file and writes the cleaned contents to a new, untarred file
-    found in the provided base directory.
+def _clean_inner_xml_file(outfile, stream):
+    """Accepts an output filename and a stream of the byte contents of an XML file
+    within a tarball and writes the cleaned contents to a new, untarred file.
 
     Args:
-        f_orig: the full path of the XML file in the archive
-        base: the directory to which the new file should be written
-        stream: the byte datapipe of the contents of f_orig
+        outfile: the path to which the modified stream should be written
+        stream: the byte datapipe of the contents of the archived XML file
 
     Returns: the path to the newly-written file
     """
-    f_txt = os.path.basename(os.path.splitext(f_xml)[0])
-    os.makedirs(base, exist_ok=True)
-    out_file = os.path.join(base, f_txt)
-    with codecs.open(out_file, mode='w', encoding='utf-8') as fd_txt:
+    os.makedirs(os.path.dirname(outfile), exist_ok=True)
+    with codecs.open(outfile, mode='w', encoding='utf-8') as fd_txt:
         root = ET.fromstring(stream.read().decode("utf-8"))[0]
         for doc in root.findall('doc'):
             for e in doc.findall('seg'):
                 fd_txt.write(e.text.strip() + '\n')
-    return os.path.join(base, f_txt)
+    return outfile
 
 
 def _clean_tags_file(f_orig):
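
Note on the refactor above: the cleaner no longer derives an output name from the archive member plus a base directory; the caller passes the full destination path, and the function only ensures the parent directory exists. A minimal sketch of the new contract, using a hypothetical in-memory stream in place of the byte datapipe:

    import io

    # Hypothetical mteval-style XML payload standing in for the archived file's bytes.
    xml_bytes = io.BytesIO(
        b'<mteval><srcset><doc docid="1"><seg id="1"> Guten Morgen </seg></doc></srcset></mteval>'
    )
    # The caller now chooses the destination; the cleaner creates parent dirs as needed.
    path = _clean_inner_xml_file('/tmp/.data/cleaned/IWSLT16.TED.tst2013.de-en.de', xml_bytes)
    assert path == '/tmp/.data/cleaned/IWSLT16.TED.tst2013.de-en.de'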
@@ -73,62 +69,55 @@ def _clean_tags_file(f_orig):
                 fd_txt.write(line.strip() + '\n')
 
 
-def _clean_inner_tags_file(f_orig, base, stream):
-    """Accepts a tags filename within a tarball and a stream of the byte contents
-    within that file and writes the cleaned contents to a new, untarred file
-    found in the provided base directory.
+def _clean_inner_tags_file(outfile, stream):
+    """Accepts an output filename and a stream of the byte contents of a tags file
+    within a tarball and writes the cleaned contents to a new, untarred file.
 
     Args:
-        f_orig: the full path of the tags file in the archive
-        base: the directory to which the new file should be written
-        stream: the byte datapipe of the contents of f_orig
+        outfile: the path to which the modified stream should be written
+        stream: the byte datapipe of the contents of the archived tags file
 
     Returns: the path to the newly-written file
     """
     xml_tags = [
         '<url', '<keywords', '<talkid', '<description', '<reviewer',
         '<translator', '<title', '<speaker', '<doc', '</doc'
     ]
-    f_txt = os.path.join(base, os.path.basename(f_orig.replace('.tags', '')))
-    os.makedirs(base, exist_ok=True)
-    with codecs.open(f_txt, mode='w', encoding='utf-8') as fd_txt:
+    os.makedirs(os.path.dirname(outfile), exist_ok=True)
+    with codecs.open(outfile, mode='w', encoding='utf-8') as fd_txt:
         for line in stream.readlines():
             if not any(tag in line.decode("utf-8") for tag in xml_tags):
                 # TODO: Fix utf-8 next line mark
                 # fd_txt.write(l.strip() + '\n')
                 # fd_txt.write(l.strip() + u"\u0085")
                 # fd_txt.write(l.lstrip())
                 fd_txt.write(line.decode("utf-8").strip() + '\n')
-    return f_txt
+    return outfile
 
 
-def _rewrite_text_file(file, base, stream):
-    """Accepts a text filename within a tarball and a stream of the byte contents
-    within that file and writes the cleaned contents to a new, untarred file
-    found in the provided base directory.
+def _rewrite_text_file(outfile, stream):
+    """Accepts an output filename and a stream of the byte contents of a text file
+    within a tarball and writes the cleaned contents to a new, untarred file.
 
     Args:
-        f_orig: the full path of the text file in the archive
-        base: the directory to which the new file should be written
-        stream: the byte datapipe of the contents of f_orig
+        outfile: the path to which the modified stream should be written
+        stream: the byte datapipe of the contents of the archived text file
 
     Returns: the path to the newly-written file
     """
-    f_txt = os.path.basename(file)
-    os.makedirs(base, exist_ok=True)
-    out_file = os.path.join(base, f_txt)
-    with open(out_file, 'w', encoding='utf-8') as f:
+    os.makedirs(os.path.dirname(outfile), exist_ok=True)
+    with open(outfile, 'w', encoding='utf-8') as f:
         for line in stream.readlines():
-            f.write(line.decode("utf-8"))
-    return out_file
+            f.write(line.decode("utf-8") + "\n")
+    return outfile
 
 
-def _clean_files(fname, base, stream):
+def _clean_files(outfile, fname, stream):
     if 'xml' in fname:
-        return _clean_inner_xml_file(fname, base, stream)
+        return _clean_inner_xml_file(outfile, stream)
     elif "tags" in fname:
-        return _clean_inner_tags_file(fname, base, stream)
-    return _rewrite_text_file(fname, base, stream)
+        return _clean_inner_tags_file(outfile, stream)
+    return _rewrite_text_file(outfile, stream)
 
 
 def _create_data_from_json(data_path):
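
Since `_clean_files` now takes the destination separately from the archive member name, the member name is used purely to pick a cleaning strategy. A small hypothetical call, again with an in-memory stand-in for the stream:

    import io

    # 'tags' in the member name routes this to _clean_inner_tags_file, which drops
    # metadata lines such as <url>...</url> and keeps only the sentence text.
    stream = io.BytesIO(b'<url>http://example.org</url>\nGuten Morgen\n')
    out = _clean_files('/tmp/.data/cleaned/train.de-en.de', 'train.tags.de-en.de', stream)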

torchtext/datasets/iwslt2016.py

Lines changed: 45 additions & 14 deletions
@@ -164,7 +164,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
     Examples:
         >>> from torchtext.datasets import IWSLT2016
         >>> train_iter, valid_iter, test_iter = IWSLT2016()
-        >>> src_sentence, tgt_sentence = next(train_iter)
+        >>> src_sentence, tgt_sentence = next(iter(train_iter))
 
     """
     if not is_module_available("torchdata"):
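
The docstring fix reflects that the returned datapipes are iterables rather than iterators, so `next()` needs an explicit `iter()`. A minimal illustration of the distinction:

    # Any iterable (a list here, standing in for the datapipe) must be wrapped
    # with iter() before next() can be applied to it.
    train_iter = ['Guten Morgen\n']
    src = next(iter(train_iter))   # works
    # next(train_iter)             # TypeError: 'list' object is not an iterator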
@@ -204,6 +204,17 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
     src_eval, tgt_eval = valid_filenames
     src_test, tgt_test = test_filenames
 
+    uncleaned_train_filenames = ('train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language),
+                                 'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language))
+    uncleaned_valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
+                                 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language))
+    uncleaned_test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
+                                'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language))
+
+    uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames
+    uncleaned_src_eval, uncleaned_tgt_eval = uncleaned_valid_filenames
+    uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames
+
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
         filepath_fn=lambda x: os.path.join(root, _PATH),
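
For reference, the on_disk_cache/end_caching pair used throughout this function caches a download to disk and skips the network on later runs. A minimal sketch of the idiom under the torchdata API of this era; the URL and target path are hypothetical:

    from torchdata.datapipes.iter import HttpReader, IterableWrapper

    url_dp = IterableWrapper(['https://example.org/archive.tgz'])  # hypothetical URL
    cache_dp = url_dp.on_disk_cache(
        filepath_fn=lambda url: '/tmp/.data/archive.tgz',  # on-disk cache location
    )
    cache_dp = HttpReader(cache_dp)  # only hits the network on a cache miss
    cache_dp = cache_dp.end_caching(mode='wb', same_filepath_fn=True)  # persist bytes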
@@ -215,14 +226,13 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
 
     languages = "-".join([src_language, tgt_language])
 
-    inner_iwslt_tar = os.path.join(
-        root, os.path.splitext(_PATH)[0], "texts", src_language, tgt_language, languages
-    ) + ".tgz"
+    # We create the whole filepath here, but only check for the literal filename in the filter
+    # because we're lazily extracting from the outer tarfile. Thus,
+    # /root/2016-01/texts/.../src-tgt.tgz will never be in /root/2016-01.tgz/texts/.../src-tgt.tgz
+    inner_iwslt_tar = os.path.join(root, os.path.splitext(_PATH)[0], "texts", src_language, tgt_language, languages) + ".tgz"
 
-    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
-        filepath_fn=lambda x: inner_iwslt_tar
-    )
-    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: inner_iwslt_tar in x[0])
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar)
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0])
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     file_path_by_lang_and_split = {
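
The new comment explains the filter change: paths reported when reading from the tar are rooted inside the archive, so they can never contain the full on-disk target path, only its final component. A small hypothetical illustration of the mismatch:

    import os

    inner_iwslt_tar = '/root/2016-01/texts/de/en/de-en.tgz'    # desired on-disk path
    member = './2016-01/texts/de/en/de-en.tgz'                 # as reported from the tar

    inner_iwslt_tar in member                    # False: the full path never matches
    os.path.basename(inner_iwslt_tar) in member  # True: the literal filename does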
@@ -238,28 +248,49 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
         }
     }
 
+    uncleaned_filenames = {
+        src_language: {
+            "train": uncleaned_src_train,
+            "valid": uncleaned_src_eval,
+            "test": uncleaned_src_test,
+        },
+        tgt_language: {
+            "train": uncleaned_tgt_train,
+            "valid": uncleaned_tgt_eval,
+            "test": uncleaned_tgt_test,
+        }
+    }
+
     src_filename = file_path_by_lang_and_split[src_language][split]
+    uncleaned_src_filename = uncleaned_filenames[src_language][split]
+
+    # We create the whole filepath here, but only check for the literal filename in the filter
+    # because we're lazily extracting from the outer tarfile.
     full_src_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename)
 
     cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_src_filepath)
     cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b").read_from_tar()
-    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1]))
-    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: full_src_filepath in x)
+    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: os.path.basename(uncleaned_src_filename) in x[0])
+    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(full_src_filepath, x[0], x[1]))
     cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b")
     cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     tgt_filename = file_path_by_lang_and_split[tgt_language][split]
+    uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split]
+
+    # We create the whole filepath here, but only check for the literal filename in the filter
+    # because we're lazily extracting from the outer tarfile.
     full_tgt_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename)
 
     cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_tgt_filepath)
     cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b").read_from_tar()
-    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1]))
-    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: full_tgt_filepath in x)
+    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: os.path.basename(uncleaned_tgt_filename) in x[0])
+    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(full_tgt_filepath, x[0], x[1]))
     cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b")
     cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
-    tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r")
-    src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r")
+    tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b")
+    src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b")
 
     src_lines = src_data_dp.readlines(return_path=False, strip_newline=False, decode=True)
     tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False, decode=True)
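
The diff stops just before the two line-level datapipes are combined into the (src, tgt) pairs that the docstring example unpacks. Presumably consumption looks roughly like this sketch (the pairing step itself is outside this diff):

    train_iter, valid_iter, test_iter = IWSLT2016(language_pair=('de', 'en'))
    for src_sentence, tgt_sentence in train_iter:
        print(src_sentence.strip(), '->', tgt_sentence.strip())
        break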
