From 970c79916d17ae991120a6dc899f200bfef60df0 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Wed, 12 Jan 2022 15:55:34 -0800 Subject: [PATCH 1/4] Migrating penntreebank dataset to use torchdata --- torchtext/datasets/penntreebank.py | 54 ++++++++++++++++++------------ 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/torchtext/datasets/penntreebank.py b/torchtext/datasets/penntreebank.py index 29d9666af1..5a7c56a130 100644 --- a/torchtext/datasets/penntreebank.py +++ b/torchtext/datasets/penntreebank.py @@ -1,29 +1,32 @@ -import logging -from torchtext.utils import download_from_url +import os +from typing import Union, Tuple + +from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _add_docstring_header, _create_dataset_directory, - _read_text_iterator, ) +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileLoader, HttpReader, IterableWrapper + URL = { - 'train': "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt", - 'test': "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt", - 'valid': "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt", + "train": "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt", + "test": "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt", + "valid": "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt", } MD5 = { - 'train': "f26c4b92c5fdc7b3f8c7cdcb991d8420", - 'valid': "aa0affc06ff7c36e977d7cd49e3839bf", - 'test': "8b80168b89c18661a38ef683c0dc3721", + "train": "f26c4b92c5fdc7b3f8c7cdcb991d8420", + "valid": "aa0affc06ff7c36e977d7cd49e3839bf", + "test": "8b80168b89c18661a38ef683c0dc3721", } NUM_LINES = { - 'train': 42068, - 'valid': 3370, - 'test': 3761, + "train": 42068, + "valid": 3370, + "test": 3761, } DATASET_NAME = "PennTreebank" @@ -31,12 +34,19 @@ @_add_docstring_header(num_lines=NUM_LINES) @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'valid', 'test')) -def PennTreebank(root, split): - path = download_from_url(URL[split], - root=root, hash_value=MD5[split], - hash_type='md5') - logging.info('Creating {} data'.format(split)) - return _RawTextIterableDataset(DATASET_NAME, - NUM_LINES[split], - _read_text_iterator(path)) +@_wrap_split_argument(("train", "valid", "test")) +def PennTreebank(root, split: Union[Tuple[str], str]): + if not is_module_available("torchdata"): + raise ModuleNotFoundError( + "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" + ) + + url_dp = IterableWrapper([URL[split]]) + cache_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), + hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + hash_type="md5", + ) + cache_dp = HttpReader(cache_dp).end_caching(mode="w", same_filepath_fn=True) + cache_dp = FileLoader(cache_dp, mode="r") + return cache_dp.readlines().map(lambda t: t[1][1:-1]) From 002ee283bac3b85de1f7ef82c09c9c86f26cd731 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Wed, 12 Jan 2022 16:26:06 -0800 Subject: [PATCH 2/4] Update FileLoader to FileOpener --- torchtext/datasets/penntreebank.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtext/datasets/penntreebank.py b/torchtext/datasets/penntreebank.py index 5a7c56a130..c894d1ebdb 100644 --- a/torchtext/datasets/penntreebank.py +++ b/torchtext/datasets/penntreebank.py @@ -9,7 +9,7 @@ ) if is_module_available("torchdata"): - from torchdata.datapipes.iter import FileLoader, HttpReader, IterableWrapper + from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper URL = { "train": "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt", @@ -48,5 +48,5 @@ def PennTreebank(root, split: Union[Tuple[str], str]): hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="w", same_filepath_fn=True) - cache_dp = FileLoader(cache_dp, mode="r") + cache_dp = FileOpener(cache_dp, mode="r") return cache_dp.readlines().map(lambda t: t[1][1:-1]) From 3a3b6837ab2ce3262f7fc6727c4cb929670fb9f1 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Wed, 19 Jan 2022 12:30:30 -0800 Subject: [PATCH 3/4] Resolved comments about return_path --- torchtext/datasets/penntreebank.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtext/datasets/penntreebank.py b/torchtext/datasets/penntreebank.py index c894d1ebdb..f46b8ddf16 100644 --- a/torchtext/datasets/penntreebank.py +++ b/torchtext/datasets/penntreebank.py @@ -48,5 +48,6 @@ def PennTreebank(root, split: Union[Tuple[str], str]): hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="w", same_filepath_fn=True) - cache_dp = FileOpener(cache_dp, mode="r") - return cache_dp.readlines().map(lambda t: t[1][1:-1]) + data_dp = FileOpener(cache_dp, mode="r") + # remove single leading and trailing space from the dataset + return data_dp.readlines(return_path=False).map(lambda t: t[1:-1]) From 51e6b3672380101b97ee0262b75ec290c5c90d17 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 20 Jan 2022 09:13:25 -0800 Subject: [PATCH 4/4] Using strip() to remove leading/trailing spaces --- torchtext/datasets/penntreebank.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/datasets/penntreebank.py b/torchtext/datasets/penntreebank.py index f46b8ddf16..87aa9e0a8c 100644 --- a/torchtext/datasets/penntreebank.py +++ b/torchtext/datasets/penntreebank.py @@ -50,4 +50,4 @@ def PennTreebank(root, split: Union[Tuple[str], str]): cache_dp = HttpReader(cache_dp).end_caching(mode="w", same_filepath_fn=True) data_dp = FileOpener(cache_dp, mode="r") # remove single leading and trailing space from the dataset - return data_dp.readlines(return_path=False).map(lambda t: t[1:-1]) + return data_dp.readlines(return_path=False).map(lambda t: t.strip())