diff --git a/torchtext/datasets/enwik9.py b/torchtext/datasets/enwik9.py
index 7ec4060573..4534c64138 100644
--- a/torchtext/datasets/enwik9.py
+++ b/torchtext/datasets/enwik9.py
@@ -1,34 +1,53 @@
-import logging
-from torchtext.utils import (
-    download_from_url,
-    extract_archive,
-)
+import os
+from typing import Tuple, Union
+
+from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _add_docstring_header,
     _create_dataset_directory,
-    _read_text_iterator,
 )
 
-URL = 'http://mattmahoney.net/dc/enwik9.zip'
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
+URL = "http://mattmahoney.net/dc/enwik9.zip"
+
+MD5 = "3e773f8a1577fda2e27f871ca17f31fd"
 
-MD5 = '3e773f8a1577fda2e27f871ca17f31fd'
+_PATH = "enwik9.zip"
 
-NUM_LINES = {
-    'train': 13147026
-}
+NUM_LINES = {"train": 13147026}
 
 DATASET_NAME = "EnWik9"
 
 
 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train',))
-def EnWik9(root, split):
-    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
-    extracted_files = extract_archive(dataset_tar)
-    path = extracted_files[0]
-    logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset(DATASET_NAME,
-                                   NUM_LINES[split], _read_text_iterator(path))
+@_wrap_split_argument(("train",))
+def EnWik9(root: str, split: Union[Tuple[str], str]):
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError(
+            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
+        )
+
+    url_dp = IterableWrapper([URL])
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _PATH),
+        hash_dict={os.path.join(root, _PATH): MD5},
+        hash_type="md5",
+    )
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(
+        mode="wb", same_filepath_fn=True
+    )
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0])
+    )
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip()
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(
+        mode="wb", same_filepath_fn=True
+    )
+
+    data_dp = FileOpener(cache_decompressed_dp, mode="b")
+    return data_dp.readlines(decode=True, return_path=False)
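
The patch replaces the eager download/extract helpers with a lazy torchdata pipeline: a two-level `on_disk_cache` first caches the MD5-verified `enwik9.zip` download, then caches the extracted file, so later calls skip both the HTTP fetch and the unzip. As a minimal usage sketch for reviewers (not part of the patch; it assumes `torchdata` is installed alongside this torchtext build, and `".data"` is an arbitrary writable cache directory):

```python
# Usage sketch (not part of this patch). Assumes torchdata is installed
# and ".data" is a writable cache directory.
from torchtext.datasets import EnWik9

# The first call downloads enwik9.zip, checks its MD5, extracts it, and
# caches both artifacts under root; subsequent calls hit the on-disk caches.
train_dp = EnWik9(root=".data", split="train")

# The returned datapipe yields one decoded line of raw wiki markup at a time.
for i, line in enumerate(train_dp):
    print(line)
    if i == 2:  # peek at the first three lines only
        break
```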