
Commit a88a99e

Refactors some of the caching logic and cleaners

1 parent b668624

2 files changed: 69 additions, 54 deletions

torchtext/data/datasets_utils.py

Lines changed: 29 additions & 40 deletions
@@ -33,27 +33,23 @@ def _clean_xml_file(f_xml):
                 fd_txt.write(e.text.strip() + '\n')
 
 
-def _clean_inner_xml_file(f_xml, base, stream):
-    """Accepts an XML filename within a tarball and a stream of the byte contents
-    within that file and writes the cleaned contents to a new, untarred file
-    found in the provided base directory.
+def _clean_inner_xml_file(outfile, stream):
+    """Accepts an output filename and a stream of the byte contents of an XML file
+    within a tarball and writes the cleaned contents to a new, untarred file.
 
     Args:
-        f_orig: the full path of the XML file in the archive
-        base: the directory to which the new file should be written
-        stream: the byte datapipe of the contents of f_orig
+        outfile: the path to which the modified stream should be written
+        stream: the byte datapipe of the contents of the archived XML file
 
     Returns: the path to the newly-written file
     """
-    f_txt = os.path.basename(os.path.splitext(f_xml)[0])
-    os.makedirs(base, exist_ok=True)
-    out_file = os.path.join(base, f_txt)
-    with codecs.open(out_file, mode='w', encoding='utf-8') as fd_txt:
+    os.makedirs(os.path.dirname(outfile), exist_ok=True)
+    with codecs.open(outfile, mode='w', encoding='utf-8') as fd_txt:
         root = ET.fromstring(stream.read().decode("utf-8"))[0]
         for doc in root.findall('doc'):
             for e in doc.findall('seg'):
                 fd_txt.write(e.text.strip() + '\n')
-    return os.path.join(base, f_txt)
+    return outfile
 
 
 def _clean_tags_file(f_orig):
@@ -73,62 +69,55 @@ def _clean_tags_file(f_orig):
                 fd_txt.write(line.strip() + '\n')
 
 
-def _clean_inner_tags_file(f_orig, base, stream):
-    """Accepts a tags filename within a tarball and a stream of the byte contents
-    within that file and writes the cleaned contents to a new, untarred file
-    found in the provided base directory.
+def _clean_inner_tags_file(outfile, stream):
+    """Accepts an output filename and a stream of the byte contents of a tags file
+    within a tarball and writes the cleaned contents to a new, untarred file.
 
     Args:
-        f_orig: the full path of the tags file in the archive
-        base: the directory to which the new file should be written
-        stream: the byte datapipe of the contents of f_orig
+        outfile: the path to which the modified stream should be written
+        stream: the byte datapipe of the contents of the archived tags file
 
     Returns: the path to the newly-written file
     """
     xml_tags = [
         '<url', '<keywords', '<talkid', '<description', '<reviewer',
         '<translator', '<title', '<speaker', '<doc', '</doc'
     ]
-    f_txt = os.path.join(base, os.path.basename(f_orig.replace('.tags', '')))
-    os.makedirs(base, exist_ok=True)
-    with codecs.open(f_txt, mode='w', encoding='utf-8') as fd_txt:
+    os.makedirs(os.path.dirname(outfile), exist_ok=True)
+    with codecs.open(outfile, mode='w', encoding='utf-8') as fd_txt:
         for line in stream.readlines():
             if not any(tag in line.decode("utf-8") for tag in xml_tags):
                 # TODO: Fix utf-8 next line mark
                 # fd_txt.write(l.strip() + '\n')
                 # fd_txt.write(l.strip() + u"\u0085")
                 # fd_txt.write(l.lstrip())
                 fd_txt.write(line.decode("utf-8").strip() + '\n')
-    return f_txt
+    return outfile
 
 
-def _rewrite_text_file(file, base, stream):
-    """Accepts a text filename within a tarball and a stream of the byte contents
-    within that file and writes the cleaned contents to a new, untarred file
-    found in the provided base directory.
+def _rewrite_text_file(outfile, stream):
+    """Accepts an output filename and a stream of the byte contents of a text file
+    within a tarball and writes the cleaned contents to a new, untarred file.
 
     Args:
-        f_orig: the full path of the text file in the archive
-        base: the directory to which the new file should be written
-        stream: the byte datapipe of the contents of f_orig
+        outfile: the path to which the modified stream should be written
+        stream: the byte datapipe of the contents of the archived text file
 
     Returns: the path to the newly-written file
     """
-    f_txt = os.path.basename(file)
-    os.makedirs(base, exist_ok=True)
-    out_file = os.path.join(base, f_txt)
-    with open(out_file, 'w', encoding='utf-8') as f:
+    os.makedirs(os.path.dirname(outfile), exist_ok=True)
+    with open(outfile, 'w', encoding='utf-8') as f:
         for line in stream.readlines():
-            f.write(line.decode("utf-8"))
-    return out_file
+            f.write(line.decode("utf-8") + "\n")
+    return outfile
 
 
-def _clean_files(fname, base, stream):
+def _clean_files(outfile, fname, stream):
     if 'xml' in fname:
-        return _clean_inner_xml_file(fname, base, stream)
+        return _clean_inner_xml_file(outfile, stream)
     elif "tags" in fname:
-        return _clean_inner_tags_file(fname, base, stream)
-    return _rewrite_text_file(fname, base, stream)
+        return _clean_inner_tags_file(outfile, stream)
+    return _rewrite_text_file(outfile, stream)
 
 
 def _create_data_from_json(data_path):
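
For context, here is a minimal sketch of how the refactored cleaner dispatch is called after this change. The member name, stream contents, and output path are hypothetical, and _clean_files is an internal helper rather than public API:

    import io

    from torchtext.data.datasets_utils import _clean_files

    # Hypothetical (member_name, stream) pair, as produced by read_from_tar().
    member_name = "texts/de/en/de-en/train.tags.de-en.de"
    stream = io.BytesIO(b"<url>http://example.com</url>\nHello world.\n")

    # The caller now supplies the explicit output path; the helper no longer
    # derives it from the archive member name and a base directory.
    cleaned = _clean_files("/tmp/iwslt/train.de-en.de", member_name, stream)
    print(cleaned)  # /tmp/iwslt/train.de-en.de

Since "tags" appears in the member name, this dispatches to _clean_inner_tags_file, which strips the XML-tag lines and writes only the segment text.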

torchtext/datasets/iwslt2016.py

Lines changed: 40 additions & 14 deletions
@@ -164,7 +164,7 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
     Examples:
         >>> from torchtext.datasets import IWSLT2016
         >>> train_iter, valid_iter, test_iter = IWSLT2016()
-        >>> src_sentence, tgt_sentence = next(train_iter)
+        >>> src_sentence, tgt_sentence = next(iter(train_iter))
 
     """
     if not is_module_available("torchdata"):
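
The doctest fix reflects that the returned object is an iterable datapipe, not an iterator, so it must be wrapped in iter() before next() is called. A short usage sketch (the language pair is illustrative):

    from torchtext.datasets import IWSLT2016

    train_iter, valid_iter, test_iter = IWSLT2016(language_pair=('de', 'en'))
    # Datapipes implement __iter__ but not __next__, hence the iter() wrapper.
    src_sentence, tgt_sentence = next(iter(train_iter))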
@@ -204,6 +204,17 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
     src_eval, tgt_eval = valid_filenames
     src_test, tgt_test = test_filenames
 
+    uncleaned_train_filenames = ('train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language),
+                                 'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language))
+    uncleaned_valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
+                                 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language))
+    uncleaned_test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
+                                'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language))
+
+    uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames
+    uncleaned_src_eval, uncleaned_tgt_eval = uncleaned_valid_filenames
+    uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames
+
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
         filepath_fn=lambda x: os.path.join(root, _PATH),
@@ -215,14 +226,12 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
 
     languages = "-".join([src_language, tgt_language])
 
-    inner_iwslt_tar = os.path.join(
-        root, os.path.splitext(_PATH)[0], "texts", src_language, tgt_language, languages
-    ) + ".tgz"
+    # Don't match on the full root-prefixed path in the filter below: we're inspecting
+    # a not-yet-extracted tgz, so /root/.../texts/... won't match /root/.../2016-01.tgz/texts/...
+    inner_iwslt_tar = os.path.join(root, os.path.splitext(_PATH)[0], "texts", src_language, tgt_language, languages) + ".tgz"
 
-    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
-        filepath_fn=lambda x: inner_iwslt_tar
-    )
-    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: inner_iwslt_tar in x[0])
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar)
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0])
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     file_path_by_lang_and_split = {
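
These two cache lines are the heart of the torchdata idiom this commit refactors: on_disk_cache skips the rest of the pipeline whenever the file produced by filepath_fn already exists, and end_caching otherwise writes the flowing bytes to that path. A self-contained sketch of the same pattern, assuming torchdata is installed (archive and member paths are made up):

    import os

    from torchdata.datapipes.iter import FileOpener, IterableWrapper

    outer_tar = ".data/2016-01.tgz"                    # assumed already downloaded
    inner_tar = ".data/2016-01/texts/de/en/de-en.tgz"  # extraction target

    dp = IterableWrapper([outer_tar]).on_disk_cache(filepath_fn=lambda x: inner_tar)
    # Runs only when inner_tar is missing: open the outer archive, keep the
    # member whose name matches, and write it back out to inner_tar.
    dp = FileOpener(dp, mode="b").read_from_tar()
    dp = dp.filter(lambda x: os.path.basename(inner_tar) in x[0])
    dp = dp.end_caching(mode="wb", same_filepath_fn=True)

Filtering on the basename is the substance of the fix: read_from_tar yields member paths rooted at the archive (e.g. .../2016-01.tgz/texts/...), so the old comparison against the full on-disk inner_iwslt_tar path could never match.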
@@ -238,28 +247,45 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
         }
     }
 
+    uncleaned_filenames = {
+        src_language: {
+            "train": uncleaned_src_train,
+            "valid": uncleaned_src_eval,
+            "test": uncleaned_src_test,
+        },
+        tgt_language: {
+            "train": uncleaned_tgt_train,
+            "valid": uncleaned_tgt_eval,
+            "test": uncleaned_tgt_test,
+        }
+    }
+
     src_filename = file_path_by_lang_and_split[src_language][split]
+    uncleaned_src_filename = uncleaned_filenames[src_language][split]
+
     full_src_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename)
 
     cache_inner_src_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_src_filepath)
     cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b").read_from_tar()
-    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1]))
-    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: full_src_filepath in x)
+    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.filter(lambda x: os.path.basename(uncleaned_src_filename) in x[0])
+    cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.map(lambda x: _clean_files(full_src_filepath, x[0], x[1]))
     cache_inner_src_decompressed_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b")
     cache_inner_src_decompressed_dp = cache_inner_src_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     tgt_filename = file_path_by_lang_and_split[tgt_language][split]
+    uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split]
+
     full_tgt_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename)
 
     cache_inner_tgt_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_tgt_filepath)
     cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b").read_from_tar()
-    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(x[0], os.path.splitext(os.path.dirname(os.path.dirname(x[0])))[0], x[1]))
-    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: full_tgt_filepath in x)
+    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.filter(lambda x: os.path.basename(uncleaned_tgt_filename) in x[0])
+    cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.map(lambda x: _clean_files(full_tgt_filepath, x[0], x[1]))
     cache_inner_tgt_decompressed_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b")
     cache_inner_tgt_decompressed_dp = cache_inner_tgt_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
-    tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r")
-    src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r")
+    tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="b")
+    src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="b")
 
     src_lines = src_data_dp.readlines(return_path=False, strip_newline=False, decode=True)
     tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False, decode=True)
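
Downstream, src_lines and tgt_lines are line-level datapipes over the cleaned files. A hedged sketch of consuming them as parallel pairs; the zip step is an assumption about the surrounding code, which this diff does not show:

    # Pair source and target sentences positionally and peek at the first pair.
    pairs = src_lines.zip(tgt_lines)
    for src_sentence, tgt_sentence in pairs:
        print(src_sentence.strip(), "->", tgt_sentence.strip())
        break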
