From b93157843c47a75c1b904fbf8084ebef984d8101 Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Sun, 23 Jan 2022 08:58:59 -0500
Subject: [PATCH] migrate Multi30k to datapipes.

---
 torchtext/datasets/multi30k.py | 76 +++++++++++++++++-----------------
 1 file changed, 37 insertions(+), 39 deletions(-)

diff --git a/torchtext/datasets/multi30k.py b/torchtext/datasets/multi30k.py
index 6e3aa8a690..ba3f516cad 100644
--- a/torchtext/datasets/multi30k.py
+++ b/torchtext/datasets/multi30k.py
@@ -1,10 +1,13 @@
+from torchtext._internal.module_utils import is_module_available
+from typing import Union, Tuple
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
 import os
 from torchtext.data.datasets_utils import (
-    _download_extract_validate,
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _create_dataset_directory,
-    _read_text_iterator,
 )
 
 URL = {
@@ -19,28 +22,10 @@
     'test': '0681be16a532912288a91ddd573594fbdd57c0fbb81486eff7c55247e35326c2',
 }
 
-_EXTRACTED_FILES_INFO = {
-    'train': {
-        'file_prefix': 'train',
-        'md5': {
-            'de': '695df46f6fd14567e69970408a2c129a50e778a910ecb1585a92eb25b2c7accc',
-            'en': '4b4d37e774976ef44fecca1738cdeb3b3ba384851a59a755b9c5e6aa7d87b13c',
-        },
-    },
-    'valid': {
-        'file_prefix': 'val',
-        'md5': {
-            'de': 'fd0fc009db2446cfc12d96a382aff0d3122cb47577b352d0f7e0bb3a38e2e552',
-            'en': '40cd20974079d9afb0e3d27c659a8e059cc2fcf850b4bc23ede13fc36dd8a865',
-        },
-    },
-    'test': {
-        'file_prefix': 'test',
-        'md5': {
-            'de': 'c1d2f544471a7387e37d15f1adf075ff7d6fe57a30840bb969281ae102d24cb1',
-            'en': '399a4382932c1aadd3ceb9bef1008d388a64c76d4ae4e9d4728c6f4301cac182',
-        },
-    },
+_PREFIX = {
+    'train': 'train',
+    'valid': 'val',
+    'test': 'test',
 }
 
 NUM_LINES = {
@@ -53,8 +38,8 @@
 
 
 @_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train', 'valid', 'test'))
-def Multi30k(root, split, language_pair=('de', 'en')):
+@_wrap_split_argument(("train", "valid", "test"))
+def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] = ('de', 'en')):
     """Multi30k dataset
 
     Reference: http://www.statmt.org/wmt16/multimodal-task.html#task1
@@ -68,18 +53,31 @@ def Multi30k(root, split, language_pair=('de', 'en')):
     assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'
     assert (tuple(sorted(language_pair)) == ('de', 'en')), "language_pair must be either ('de','en') or ('en', 'de')"
 
-    downloaded_file = os.path.basename(URL[split])
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
+
+    url_dp = IterableWrapper([URL[split]])
+
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(URL[split])),
+        hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]},
+        hash_type="sha256"
+    )
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
+
+    src_cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[0]}"))
+    src_cache_decompressed_dp = FileOpener(src_cache_decompressed_dp, mode="b").read_from_tar().filter(
+        lambda x: f"{_PREFIX[split]}.{language_pair[0]}" in x[0])
+    src_cache_decompressed_dp = src_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
-    src_path = _download_extract_validate(root, URL[split], MD5[split],
-                                          os.path.join(root, downloaded_file),
-                                          os.path.join(root, _EXTRACTED_FILES_INFO[split]['file_prefix'] + '.' + language_pair[0]),
-                                          _EXTRACTED_FILES_INFO[split]['md5'][language_pair[0]])
-    trg_path = _download_extract_validate(root, URL[split], MD5[split],
-                                          os.path.join(root, downloaded_file),
-                                          os.path.join(root, _EXTRACTED_FILES_INFO[split]['file_prefix'] + '.' + language_pair[1]),
-                                          _EXTRACTED_FILES_INFO[split]['md5'][language_pair[1]])
+    tgt_cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[1]}"))
+    tgt_cache_decompressed_dp = FileOpener(tgt_cache_decompressed_dp, mode="b").read_from_tar().filter(
+        lambda x: f"{_PREFIX[split]}.{language_pair[1]}" in x[0])
+    tgt_cache_decompressed_dp = tgt_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
-    src_data_iter = _read_text_iterator(src_path)
-    trg_data_iter = _read_text_iterator(trg_path)
+    src_data_dp = FileOpener(src_cache_decompressed_dp, mode="b").readlines(decode=True, return_path=False, strip_newline=False)
+    tgt_data_dp = FileOpener(tgt_cache_decompressed_dp, mode="b").readlines(decode=True, return_path=False, strip_newline=False)
 
-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], zip(src_data_iter, trg_data_iter))
+    return src_data_dp.zip(tgt_data_dp)