|
| 1 | +from torchtext._internal.module_utils import is_module_available |
| 2 | +from typing import Union, Tuple |
| 3 | + |
| 4 | +if is_module_available("torchdata"): |
| 5 | + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper, FileLister |
| 6 | + |
1 | 7 | import os |
2 | 8 | from torchtext.utils import (download_from_url, extract_archive) |
3 | 9 | from torchtext.data.datasets_utils import ( |
|
9 | 15 | ) |
10 | 16 | from torchtext.data.datasets_utils import _create_dataset_directory |
11 | 17 |
|
| 18 | +URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8' |
| 19 | + |
| 20 | +_PATH = '2016-01.tgz' |
| 21 | + |
| 22 | +MD5 = 'c393ed3fc2a1b0f004b3331043f615ae' |
12 | 23 |
|
13 | 24 | SUPPORTED_DATASETS = { |
14 | | - 'URL': 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8', |
15 | | - '_PATH': '2016-01.tgz', |
16 | | - 'MD5': 'c393ed3fc2a1b0f004b3331043f615ae', |
| 25 | + |
17 | 26 | 'valid_test': ['dev2010', 'tst2010', 'tst2011', 'tst2012', 'tst2013', 'tst2014'], |
18 | 27 | 'language_pair': { |
19 | 28 | 'en': ['ar', 'de', 'fr', 'cs'], |
|
26 | 35 |
|
27 | 36 | } |
28 | 37 |
|
29 | | -URL = SUPPORTED_DATASETS['URL'] |
30 | | -MD5 = SUPPORTED_DATASETS['MD5'] |
31 | | - |
32 | 38 | NUM_LINES = { |
33 | 39 | 'train': { |
34 | 40 | 'train': { |
@@ -133,21 +139,28 @@ def _construct_filenames(filename, languages): |
133 | 139 | return filenames |
134 | 140 |
|
135 | 141 |
|
| 142 | +def _construct_filepath(path, src_filename, tgt_filename): |
| 143 | + src_path = path if src_filename in path else None
| 144 | + tgt_path = path if tgt_filename in path else None
| 145 | + return src_path, tgt_path
| 148 | + |
| 149 | + |
136 | 150 | def _construct_filepaths(paths, src_filename, tgt_filename): |
137 | 151 | src_path = None |
138 | 152 | tgt_path = None |
139 | 153 | for p in paths: |
140 | | - src_path = p if src_filename in p else src_path |
141 | | - tgt_path = p if tgt_filename in p else tgt_path |
142 | | - return (src_path, tgt_path) |
| 154 | + src, tgt = _construct_filepath(p, src_filename, tgt_filename)
| 155 | + # keep an earlier match rather than overwriting it with None
| 156 | + src_path = src or src_path
| 157 | + tgt_path = tgt or tgt_path
| 158 | + return src_path, tgt_path
143 | 156 |
|
144 | 157 |
|
145 | 158 | DATASET_NAME = "IWSLT2016" |
146 | 159 |
|
147 | 160 |
|
148 | 161 | @_create_dataset_directory(dataset_name=DATASET_NAME) |
149 | | -@_wrap_split_argument(('train', 'valid', 'test')) |
150 | | -def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'): |
| 162 | +@_wrap_split_argument(('train', 'valid', 'test'))
| 163 | +def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'):
151 | 164 | """IWSLT2016 dataset |
152 | 165 |
|
153 | 166 | The available datasets include the following:
@@ -191,6 +204,9 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de |
191 | 204 | 'test': test_set |
192 | 205 | } |
193 | 206 |
|
| 207 | + if not is_module_available("torchdata"): |
| 208 | + raise ModuleNotFoundError("Package `torchdata` not found. Please install it by following the instructions at `https://github.com/pytorch/data`")
| 209 | + |
194 | 210 | if not isinstance(language_pair, list) and not isinstance(language_pair, tuple): |
195 | 211 | raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair))) |
196 | 212 |
|
@@ -225,50 +241,60 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de |
225 | 241 | src_eval, tgt_eval = valid_filenames |
226 | 242 | src_test, tgt_test = test_filenames |
227 | 243 |
|
228 | | - extracted_files = [] # list of paths to the extracted files |
229 | | - dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root, hash_value=SUPPORTED_DATASETS['MD5'], |
230 | | - path=os.path.join(root, SUPPORTED_DATASETS['_PATH']), hash_type='md5') |
231 | | - extracted_dataset_tar = extract_archive(dataset_tar) |
232 | | - # IWSLT dataset's url downloads a multilingual tgz. |
233 | | - # We need to take an extra step to pick out the specific language pair from it. |
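| | + # The GDrive URL serves a single multilingual archive (2016-01.tgz);
| | + # download it once, cache it under `root`, and verify it against MD5.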
| 244 | + url_dp = IterableWrapper([URL]) |
| 245 | + cache_compressed_dp = url_dp.on_disk_cache( |
| 246 | + filepath_fn=lambda x: os.path.join(root, _PATH), |
| 247 | + hash_dict={os.path.join(root, _PATH): MD5}, |
| 248 | + hash_type="md5" |
| 249 | + ) |
| 250 | + cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) |
| 251 | + cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b") |
234 | 252 | src_language = train_filenames[0].split(".")[-1] |
235 | 253 | tgt_language = train_filenames[1].split(".")[-1] |
236 | 254 | languages = "-".join([src_language, tgt_language]) |
237 | 255 |
|
238 | | - iwslt_tar = '{}/{}/texts/{}/{}/{}.tgz' |
239 | | - iwslt_tar = iwslt_tar.format( |
240 | | - root, SUPPORTED_DATASETS['_PATH'].split(".")[0], src_language, tgt_language, languages) |
241 | | - extracted_dataset_tar = extract_archive(iwslt_tar) |
242 | | - extracted_files.extend(extracted_dataset_tar) |
| 256 | + iwslt_tar = os.path.join( |
| 257 | + "texts", src_language, tgt_language, languages |
| 258 | + ) |
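| | + # Cache the language-pair subtree pulled out of the outer archive, so
| | + # the tar only has to be read on the first run.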
| 259 | + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( |
| 260 | + filepath_fn=lambda x: os.path.join(os.path.splitext(x[0])[0], iwslt_tar) |
| 261 | + ) |
| 262 | + cache_decompressed_dp = cache_decompressed_dp.read_from_tar() |
| 263 | + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb") |
243 | 264 |
|
244 | | - # Clean the xml and tag file in the archives |
245 | | - file_archives = [] |
246 | | - for fname in extracted_files: |
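| | + # Clean each extracted file: strip XML markup and metadata tag lines,
| | + # returning the path to the cleaned copy.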
| 265 | + def clean_files(fname): |
247 | 266 | if 'xml' in fname: |
248 | 267 | _clean_xml_file(fname) |
249 | | - file_archives.append(os.path.splitext(fname)[0]) |
| 268 | + return os.path.splitext(fname)[0] |
250 | 269 | elif "tags" in fname: |
251 | 270 | _clean_tags_file(fname) |
252 | | - file_archives.append(fname.replace('.tags', '')) |
253 | | - else: |
254 | | - file_archives.append(fname) |
255 | | - |
256 | | - data_filenames = { |
257 | | - "train": _construct_filepaths(file_archives, src_train, tgt_train), |
258 | | - "valid": _construct_filepaths(file_archives, src_eval, tgt_eval), |
259 | | - "test": _construct_filepaths(file_archives, src_test, tgt_test) |
260 | | - } |
| 271 | + return fname.replace('.tags', '') |
| 272 | + return fname |
| 273 | + |
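| | + # Each cached item is a path; FileLister expands it into the files
| | + # underneath, flattened into a single stream.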
| 274 | + cache_decompressed_dp = cache_decompressed_dp.flatmap(FileLister) |
| 275 | + |
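| | + # Assign each file to a (src, tgt) slot by matching it against the
| | + # filenames expected for the requested split.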
| 276 | + def get_filepath(f): |
| 277 | + src_file, tgt_file = { |
| 278 | + "train": _construct_filepath(f, src_train, tgt_train), |
| 279 | + "valid": _construct_filepath(f, src_eval, tgt_eval), |
| 280 | + "test": _construct_filepath(f, src_test, tgt_test) |
| 281 | + }[split] |
| 282 | + |
| 283 | + return src_file, tgt_file |
| 284 | + |
| 285 | + cleaned_cache_decompressed_dp = cache_decompressed_dp.map(clean_files).map(get_filepath) |
261 | 286 |
|
262 | | - for key in data_filenames.keys(): |
263 | | - if len(data_filenames[key]) == 0 or data_filenames[key] is None: |
264 | | - raise FileNotFoundError( |
265 | | - "Files are not found for data type {}".format(key)) |
| 287 | + # each (src, tgt) pair is either (None, None) or has exactly one filename set.
| 288 | + # drop the (None, None) pairs since those files are not part of the requested split
| 289 | + cleaned_cache_decompressed_dp = cleaned_cache_decompressed_dp.filter(lambda x: x != (None, None)) |
266 | 290 |
|
267 | | - src_data_iter = _read_text_iterator(data_filenames[split][0]) |
268 | | - tgt_data_iter = _read_text_iterator(data_filenames[split][1]) |
| 291 | + # demux on x.index(None): (None, tgt) pairs are classified 0 and feed
| 292 | + # tgt_data_dp; (src, None) pairs are classified 1 and feed src_data_dp
| 292 | + tgt_data_dp, src_data_dp = cleaned_cache_decompressed_dp.demux(2, lambda x: x.index(None)) |
269 | 293 |
|
270 | | - def _iter(src_data_iter, tgt_data_iter): |
271 | | - for item in zip(src_data_iter, tgt_data_iter): |
272 | | - yield item |
| 294 | + # Pull out the non-None element (i.e., filename) from the tuple |
| 295 | + tgt_data_dp = FileOpener(tgt_data_dp.map(lambda x: x[1]), mode="r") |
| 296 | + src_data_dp = FileOpener(src_data_dp.map(lambda x: x[0]), mode="r") |
273 | 297 |
|
274 | | - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))], _iter(src_data_iter, tgt_data_iter)) |
| 298 | + src_lines = src_data_dp.readlines(return_path=False, strip_newline=False) |
| 299 | + tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False) |
| 300 | + return src_lines.zip(tgt_lines) |
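
For reference, a minimal usage sketch of the datapipe built above (assuming `torchdata` is installed). When called with a single split name, the wrapped function returns one datapipe whose elements are `(src, tgt)` pairs of raw lines, newlines included since `strip_newline=False`:

```python
from torchtext.datasets import IWSLT2016

# Build the German-to-English training pipe; the archive is downloaded,
# hash-checked, and extracted lazily on first iteration, then served from cache.
train_dp = IWSLT2016(split='train', language_pair=('de', 'en'))

for src_line, tgt_line in train_dp:
    print(src_line.strip(), '=>', tgt_line.strip())
    break  # peek at the first parallel pair only
```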