From 9c56cf63dbfa04878dfde1419d00343f3883d7e9 Mon Sep 17 00:00:00 2001
From: Abhinav Arora
Date: Fri, 14 Jan 2022 02:05:06 -0800
Subject: [PATCH 1/2] Migrate WikiText2 to datapipes

---
 torchtext/datasets/wikitext2.py | 41 +++++++++++++++++++++++----------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/torchtext/datasets/wikitext2.py b/torchtext/datasets/wikitext2.py
index b4c20fc880..c0547ce40f 100644
--- a/torchtext/datasets/wikitext2.py
+++ b/torchtext/datasets/wikitext2.py
@@ -1,13 +1,17 @@
-import logging
-from torchtext.utils import download_from_url, extract_archive
+from torchtext._internal.module_utils import is_module_available
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
+import os
+import functools
 from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _add_docstring_header,
-    _find_match,
     _create_dataset_directory,
-    _read_text_iterator,
 )
+from typing import Union, Tuple
+from pathlib import Path
 
 URL = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
 
@@ -25,10 +29,23 @@
 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'valid', 'test'))
-def WikiText2(root, split):
-    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
-    extracted_files = extract_archive(dataset_tar)
-    path = _find_match(split, extracted_files)
-    logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset(DATASET_NAME,
-                                   NUM_LINES[split], _read_text_iterator(path))
+def WikiText2(root: str, split: Union[Tuple[str], str]):
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
+    url_dp = IterableWrapper([URL])
+    # cache data on-disk
+    filepath_fn = functools.partial(lambda x: os.path.join(root, os.path.basename(x)))
+    cache_dp = url_dp.on_disk_cache(
+        filepath_fn=filepath_fn,
+        hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
+        hash_type="md5",
+    )
+    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_dp = FileOpener(cache_dp, mode="b")
+    # stack Zip extractor on top of load files data pipe
+    extracted_files = cache_dp.read_from_zip()
+    # filter the files as applicable to create dataset for given split (train or test)
+    filter_fn = functools.partial(lambda x: split in Path(x[0]).parts[-1])
+    filter_extracted_files = extracted_files.filter(filter_fn)
+    extract_text_fn = functools.partial(lambda t: t[1].decode())
+    return filter_extracted_files.readlines(strip_newline=False).map(extract_text_fn)
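Note on what the first patch returns: WikiText2 now yields a composed torchdata datapipe rather than a _RawTextIterableDataset, so callers simply iterate it to get the raw text lines of the requested split. A minimal consumption sketch (the root path and the three-line peek are illustrative assumptions, not part of the patch):

# Illustrative sketch only: iterating the datapipe returned by the patched
# WikiText2(). The root path and the 3-line peek below are assumptions.
from torchtext.datasets import WikiText2

train_dp = WikiText2(root=".data", split="train")
for i, line in enumerate(train_dp):
    print(repr(line))  # raw text line; newline kept because strip_newline=False
    if i == 2:
        break
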
From 3b4d8c370ab8903a2e325bc14cc14980b322cc13 Mon Sep 17 00:00:00 2001
From: Abhinav Arora
Date: Fri, 21 Jan 2022 00:08:03 -0800
Subject: [PATCH 2/2] Address code review comments and add double caching

---
 torchtext/datasets/wikitext2.py | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/torchtext/datasets/wikitext2.py b/torchtext/datasets/wikitext2.py
index c0547ce40f..8a43b4c7b4 100644
--- a/torchtext/datasets/wikitext2.py
+++ b/torchtext/datasets/wikitext2.py
@@ -4,14 +4,12 @@
     from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
 
 import os
-import functools
 from torchtext.data.datasets_utils import (
     _wrap_split_argument,
     _add_docstring_header,
     _create_dataset_directory,
 )
 from typing import Union, Tuple
-from pathlib import Path
 
 URL = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
 
@@ -25,6 +23,12 @@
 
 DATASET_NAME = "WikiText2"
 
+_EXTRACTED_FILES = {
+    'train': os.path.join('wikitext-2', 'wiki.train.tokens'),
+    'test': os.path.join('wikitext-2', 'wiki.test.tokens'),
+    'valid': os.path.join('wikitext-2', 'wiki.valid.tokens'),
+}
+
 
 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
@@ -34,18 +38,15 @@ def WikiText2(root: str, split: Union[Tuple[str], str]):
         raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
     url_dp = IterableWrapper([URL])
     # cache data on-disk
-    filepath_fn = functools.partial(lambda x: os.path.join(root, os.path.basename(x)))
-    cache_dp = url_dp.on_disk_cache(
-        filepath_fn=filepath_fn,
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
         hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
         hash_type="md5",
     )
-    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
-    cache_dp = FileOpener(cache_dp, mode="b")
-    # stack Zip extractor on top of load files data pipe
-    extracted_files = cache_dp.read_from_zip()
-    # filter the files as applicable to create dataset for given split (train or test)
-    filter_fn = functools.partial(lambda x: split in Path(x[0]).parts[-1])
-    filter_extracted_files = extracted_files.filter(filter_fn)
-    extract_text_fn = functools.partial(lambda t: t[1].decode())
-    return filter_extracted_files.readlines(strip_newline=False).map(extract_text_fn)
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]))
+    # Extract zip and filter the appropriate split file
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+    data_dp = FileOpener(cache_decompressed_dp, mode='b')
+    return data_dp.readlines(strip_newline=False, decode=True, return_path=False)
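The "double caching" added in the second patch chains two on_disk_cache levels: the first caches the downloaded wikitext-2-v1.zip under root, the second caches the extracted wiki.<split>.tokens file, so later calls skip both the download and the zip extraction and read the split file straight from disk. A rough downstream-usage sketch, assuming the basic_english tokenizer and a batch size of 32 purely for illustration:

# Illustrative sketch only: consuming the doubly-cached WikiText2 datapipe.
# The tokenizer choice and batch size are assumptions, not part of the patches.
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import WikiText2
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer("basic_english")

# The first pass downloads the zip and extracts wiki.train.tokens under root;
# later passes hit both on-disk caches and read only the extracted file.
train_dp = WikiText2(root=".data", split="train")
vocab = build_vocab_from_iterator(map(tokenizer, train_dp), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# The datapipe yields raw text lines, so it can be handed to a DataLoader as-is.
train_loader = DataLoader(WikiText2(root=".data", split="train"), batch_size=32)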