diff --git a/torchtext/datasets/udpos.py b/torchtext/datasets/udpos.py index 2377448939..9f5e8c667e 100644 --- a/torchtext/datasets/udpos.py +++ b/torchtext/datasets/udpos.py @@ -1,13 +1,17 @@ -from torchtext.utils import download_from_url, extract_archive +from torchtext._internal.module_utils import is_module_available +from typing import Union, Tuple + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper + from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _add_docstring_header, - _find_match, _create_dataset_directory, - _create_data_from_iob, ) +import os + URL = 'https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip' MD5 = 'bdcac7c52d934656bae1699541424545' @@ -18,19 +22,36 @@ 'test': 2077, } +_EXTRACTED_FILES = { + "train": "train.txt", + "valid": "dev.txt", + "test": "test.txt" +} + DATASET_NAME = "UDPOS" @_add_docstring_header(num_lines=NUM_LINES) @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'valid', 'test')) -def UDPOS(root, split): - dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5') - extracted_files = extract_archive(dataset_tar) - if split == 'valid': - path = _find_match("dev.txt", extracted_files) - else: - path = _find_match(split + ".txt", extracted_files) - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], - _create_data_from_iob(path)) +@_wrap_split_argument(("train", "valid", "test")) +def UDPOS(root: str, split: Union[Tuple[str], str]): + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") + + url_dp = IterableWrapper([URL]) + cache_compressed_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, os.path.basename(URL)), + hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + hash_type="md5" + ) + cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter( + lambda x: _EXTRACTED_FILES[split] in x[0]) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + data_dp = FileOpener(cache_decompressed_dp, mode='b') + return data_dp.readlines(decode=True).read_iob()