
Commit aa6af76

migrate IWSLT2016 to datapipes.
1 parent 042f12f commit aa6af76

1 file changed: +71 -45 lines changed

torchtext/datasets/iwslt2016.py

@@ -1,3 +1,9 @@
+from torchtext._internal.module_utils import is_module_available
+from typing import Union, Tuple
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper, FileLister
+
 import os
 from torchtext.utils import (download_from_url, extract_archive)
 from torchtext.data.datasets_utils import (
@@ -9,11 +15,14 @@
 )
 from torchtext.data.datasets_utils import _create_dataset_directory

+URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8'
+
+_PATH = '2016-01.tgz'
+
+MD5 = 'c393ed3fc2a1b0f004b3331043f615ae'

 SUPPORTED_DATASETS = {
-    'URL': 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8',
-    '_PATH': '2016-01.tgz',
-    'MD5': 'c393ed3fc2a1b0f004b3331043f615ae',
+
     'valid_test': ['dev2010', 'tst2010', 'tst2011', 'tst2012', 'tst2013', 'tst2014'],
     'language_pair': {
         'en': ['ar', 'de', 'fr', 'cs'],
@@ -26,9 +35,6 @@

 }

-URL = SUPPORTED_DATASETS['URL']
-MD5 = SUPPORTED_DATASETS['MD5']
-
 NUM_LINES = {
     'train': {
         'train': {
@@ -133,21 +139,28 @@ def _construct_filenames(filename, languages):
     return filenames


+def _construct_filepath(path, src_filename, tgt_filename):
+    src_path = None
+    tgt_path = None
+    src_path = path if src_filename in path else src_path
+    tgt_path = path if tgt_filename in path else tgt_path
+    return src_path, tgt_path
+
+
 def _construct_filepaths(paths, src_filename, tgt_filename):
     src_path = None
     tgt_path = None
     for p in paths:
-        src_path = p if src_filename in p else src_path
-        tgt_path = p if tgt_filename in p else tgt_path
-    return (src_path, tgt_path)
+        src_path, tgt_path = _construct_filepath(p, src_filename, tgt_filename)
+    return src_path, tgt_path


 DATASET_NAME = "IWSLT2016"


 @_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train', 'valid', 'test'))
-def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'):
+@_wrap_split_argument(("train", "valid", "test"))
+def IWSLT2016(root = '.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'):
     """IWSLT2016 dataset

     The available datasets include following:
@@ -191,6 +204,9 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'):
         'test': test_set
     }

+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
+
     if not isinstance(language_pair, list) and not isinstance(language_pair, tuple):
         raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair)))

@@ -225,50 +241,60 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'):
     src_eval, tgt_eval = valid_filenames
     src_test, tgt_test = test_filenames

-    extracted_files = []  # list of paths to the extracted files
-    dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root, hash_value=SUPPORTED_DATASETS['MD5'],
-                                    path=os.path.join(root, SUPPORTED_DATASETS['_PATH']), hash_type='md5')
-    extracted_dataset_tar = extract_archive(dataset_tar)
-    # IWSLT dataset's url downloads a multilingual tgz.
-    # We need to take an extra step to pick out the specific language pair from it.
+    url_dp = IterableWrapper([URL])
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _PATH),
+        hash_dict={os.path.join(root, _PATH): MD5},
+        hash_type="md5"
+    )
+    cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b")
     src_language = train_filenames[0].split(".")[-1]
     tgt_language = train_filenames[1].split(".")[-1]
     languages = "-".join([src_language, tgt_language])

-    iwslt_tar = '{}/{}/texts/{}/{}/{}.tgz'
-    iwslt_tar = iwslt_tar.format(
-        root, SUPPORTED_DATASETS['_PATH'].split(".")[0], src_language, tgt_language, languages)
-    extracted_dataset_tar = extract_archive(iwslt_tar)
-    extracted_files.extend(extracted_dataset_tar)
+    iwslt_tar = os.path.join(
+        "texts", src_language, tgt_language, languages
+    )
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(os.path.splitext(x[0])[0], iwslt_tar)
+    )
+    cache_decompressed_dp = cache_decompressed_dp.read_from_tar()
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")

-    # Clean the xml and tag file in the archives
-    file_archives = []
-    for fname in extracted_files:
+    def clean_files(fname):
         if 'xml' in fname:
             _clean_xml_file(fname)
-            file_archives.append(os.path.splitext(fname)[0])
+            return os.path.splitext(fname)[0]
         elif "tags" in fname:
             _clean_tags_file(fname)
-            file_archives.append(fname.replace('.tags', ''))
-        else:
-            file_archives.append(fname)
-
-    data_filenames = {
-        "train": _construct_filepaths(file_archives, src_train, tgt_train),
-        "valid": _construct_filepaths(file_archives, src_eval, tgt_eval),
-        "test": _construct_filepaths(file_archives, src_test, tgt_test)
-    }
+            return fname.replace('.tags', '')
+        return fname
+
+    cache_decompressed_dp = cache_decompressed_dp.flatmap(FileLister)
+
+    def get_filepath(f):
+        src_file, tgt_file = {
+            "train": _construct_filepath(f, src_train, tgt_train),
+            "valid": _construct_filepath(f, src_eval, tgt_eval),
+            "test": _construct_filepath(f, src_test, tgt_test)
+        }[split]
+
+        return src_file, tgt_file
+
+    cleaned_cache_decompressed_dp = cache_decompressed_dp.map(clean_files).map(get_filepath)

-    for key in data_filenames.keys():
-        if len(data_filenames[key]) == 0 or data_filenames[key] is None:
-            raise FileNotFoundError(
-                "Files are not found for data type {}".format(key))
+    # pairs of filenames are either both None or one of src/tgt is None.
+    # filter out both None since they're not relevant
+    cleaned_cache_decompressed_dp = cleaned_cache_decompressed_dp.filter(lambda x: x != (None, None))

-    src_data_iter = _read_text_iterator(data_filenames[split][0])
-    tgt_data_iter = _read_text_iterator(data_filenames[split][1])
+    # (None, tgt) => 1, (src, None) => 0
+    tgt_data_dp, src_data_dp = cleaned_cache_decompressed_dp.demux(2, lambda x: x.index(None))

-    def _iter(src_data_iter, tgt_data_iter):
-        for item in zip(src_data_iter, tgt_data_iter):
-            yield item
+    # Pull out the non-None element (i.e., filename) from the tuple
+    tgt_data_dp = FileOpener(tgt_data_dp.map(lambda x: x[1]), mode="r")
+    src_data_dp = FileOpener(src_data_dp.map(lambda x: x[0]), mode="r")

-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))], _iter(src_data_iter, tgt_data_iter))
+    src_lines = src_data_dp.readlines(return_path=False, strip_newline=False)
+    tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False)
+    return src_lines.zip(tgt_lines)
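
The backbone of the new implementation is torchdata's on-disk caching idiom: `on_disk_cache` declares where a download should land (optionally with a checksum), and `end_caching` writes the bytes only on a cache miss. Below is a minimal sketch of that pattern, with a placeholder URL and `HttpReader` standing in for the `GDriveReader` used above:

import os
from torchdata.datapipes.iter import HttpReader, IterableWrapper

url = "https://example.com/archive.tgz"  # placeholder, not the IWSLT mirror
target = os.path.join(".data", "archive.tgz")

# Declare where the download should land; downstream steps run only on a cache miss.
cache_dp = IterableWrapper([url]).on_disk_cache(filepath_fn=lambda _: target)
# Fetch and finalize the cache; later runs yield the local path directly.
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)

for path in cache_dp:
    print(path)  # local path to the cached archive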

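The `demux` step near the end of the diff is the least obvious move: each cleaned item is a (src, tgt) filename pair in which exactly one slot is `None`, and `x.index(None)` is the classifier that routes each pair to one of two child datapipes. A small sketch with hypothetical filenames:

from torchdata.datapipes.iter import IterableWrapper

# Each pair has exactly one filled slot, as produced by _construct_filepath.
pairs = IterableWrapper([
    ("train.de-en.de", None),  # source file matched; None sits at index 1
    (None, "train.de-en.en"),  # target file matched; None sits at index 0
])

# Items whose classifier returns 0, i.e. (None, tgt), go to the first child;
# (src, None) items go to the second, matching `tgt_data_dp, src_data_dp` above.
tgt_dp, src_dp = pairs.demux(2, lambda x: x.index(None))

print(list(tgt_dp))  # [(None, 'train.de-en.en')]
print(list(src_dp))  # [('train.de-en.de', None)]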
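
End to end, `IWSLT2016` now returns a datapipe of (source line, target line) string pairs instead of a `_RawTextIterableDataset`. A minimal usage sketch, assuming `torchdata` is installed and the Google Drive download succeeds:

from torchtext.datasets import IWSLT2016

train_dp = IWSLT2016(split='train', language_pair=('de', 'en'))
# Each item is a (src_line, tgt_line) pair from the zipped datapipes.
for src_line, tgt_line in train_dp:
    print(src_line.strip(), '->', tgt_line.strip())
    break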