From 77dbd9cd1ff48799416fb9637760f43b5515fb9e Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 24 Jun 2022 21:06:30 +0000 Subject: [PATCH 01/10] url list benchmark testing --- torchtext/datasets/cnndm.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 3950557f5d..4c7b85c7d1 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -2,6 +2,9 @@ import os from functools import partial from typing import Union, Tuple +from timeit import timeit +from benchmark.utils import Timer +from functools import lru_cache from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( @@ -69,6 +72,13 @@ def _hash_urls(s): return story_fname +@lru_cache() +def _get_split_list_cache(split: str): + url_dp = IterableWrapper([URL_LIST[split]]) + online_dp = OnlineReader(url_dp) + return online_dp.readlines().map(fn=_hash_urls) + + def _get_split_list(split: str): url_dp = IterableWrapper([URL_LIST[split]]) online_dp = OnlineReader(url_dp) @@ -88,8 +98,8 @@ def _load_stories(root: str, source: str): return cache_decompressed_dp -@_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(("train", "val", "test")) +#@_create_dataset_directory(dataset_name=DATASET_NAME) +#@_wrap_split_argument(("train", "val", "test")) def CNNDM(root: str, split: Union[Tuple[str], str]): """CNNDM Dataset @@ -127,3 +137,12 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): story_fnames = set(_get_split_list(split)) data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() + +if __name__ == '__main__': + + for split in ('train', 'val', 'test'): + with Timer(f"{split} w/out cache"): + _get_split_list(split) + + with Timer(f"{split} w/ cache"): + _get_split_list_cache(split) \ No newline at end of file From f4fc46f32925e881939e8348a9f9a8e1565f1ec5 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 28 Jun 2022 14:36:21 +0000 Subject: [PATCH 02/10] benchmarking original implementation vs storing fnames in global variable --- torchtext/datasets/cnndm.py | 31 +++---- torchtext/datasets/cnndm_v1.py | 149 +++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 20 deletions(-) create mode 100644 torchtext/datasets/cnndm_v1.py diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 4c7b85c7d1..9dfbbcc267 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -2,9 +2,7 @@ import os from functools import partial from typing import Union, Tuple -from timeit import timeit from benchmark.utils import Timer -from functools import lru_cache from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( @@ -55,7 +53,8 @@ def _extracted_filepath_fn(root: str, source: str): return os.path.join(root, _EXTRACTED_FOLDERS[source]) -def _filter_fn(story_fnames, x): +def _filter_fn(split, x): + story_fnames = set(_get_split_list(split)) return os.path.basename(x[0]) in story_fnames @@ -72,13 +71,6 @@ def _hash_urls(s): return story_fname -@lru_cache() -def _get_split_list_cache(split: str): - url_dp = IterableWrapper([URL_LIST[split]]) - online_dp = OnlineReader(url_dp) - return online_dp.readlines().map(fn=_hash_urls) - - def _get_split_list(split: str): url_dp = IterableWrapper([URL_LIST[split]]) online_dp = OnlineReader(url_dp) @@ -132,17 +124,16 @@ def CNNDM(root: str, split: 
Union[Tuple[str], str]): cnn_dp = _load_stories(root, "cnn") dailymail_dp = _load_stories(root, "dailymail") data_dp = cnn_dp.concat(dailymail_dp) - # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn - # of the on_disk_cache_dp which caches the files extracted from the tar - story_fnames = set(_get_split_list(split)) - data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) + data_dp = data_dp.filter(partial(_filter_fn, split)) return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() if __name__ == '__main__': - for split in ('train', 'val', 'test'): - with Timer(f"{split} w/out cache"): - _get_split_list(split) - - with Timer(f"{split} w/ cache"): - _get_split_list_cache(split) \ No newline at end of file + out = CNNDM(os.path.expanduser('~/.torchtext/cache'), 'train') + + with Timer(f"initialize dataset"): + ex = iter(out) + + for i in range(2): + with Timer(f"iteration: {i}"): + next(ex) \ No newline at end of file diff --git a/torchtext/datasets/cnndm_v1.py b/torchtext/datasets/cnndm_v1.py new file mode 100644 index 0000000000..912aeec097 --- /dev/null +++ b/torchtext/datasets/cnndm_v1.py @@ -0,0 +1,149 @@ +import hashlib +import os +from functools import partial +from typing import Union, Tuple +from benchmark.utils import Timer + +from torchtext._internal.module_utils import is_module_available +from torchtext.data.datasets_utils import ( + _wrap_split_argument, + _create_dataset_directory, +) + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import ( + FileOpener, + IterableWrapper, + OnlineReader, + GDriveReader, + ) + +DATASET_NAME = "CNNDM" + +URL_LIST = { + "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt", + "val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt", + "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt", +} + +STORIES_LIST = { + "cnn": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ", + "dailymail": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs", +} + +PATH_LIST = { + "cnn": "cnn_stories.tgz", + "dailymail": "dailymail_stories.tgz", +} + +STORIES_MD5 = {"cnn": "85ac23a1926a831e8f46a6b8eaf57263", "dailymail": "f9c5f565e8abe86c38bfa4ae8f96fd72"} + +_EXTRACTED_FOLDERS = { + "cnn": os.path.join("cnn", "stories"), + "daily_mail": os.path.join("dailymail", "stories"), +} + +story_fnames = None + + +def _filepath_fn(root: str, source: str, _=None): + return os.path.join(root, PATH_LIST[source]) + + +# this function will be used to cache the contents of the tar file +def _extracted_filepath_fn(root: str, source: str): + return os.path.join(root, _EXTRACTED_FOLDERS[source]) + + +def _filter_fn(split, x): + global story_fnames + if not story_fnames: + story_fnames = set(_get_split_list(split)) + return os.path.basename(x[0]) in story_fnames + + +def _hash_urls(s): + """ + Returns story filename as a heximal formated SHA1 hash of the input url string. 
+ Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py + """ + url = s[1] + h = hashlib.sha1() + h.update(url) + url_hash = h.hexdigest() + story_fname = url_hash + ".story" + return story_fname + + +def _get_split_list_cache(split: str): + url_dp = IterableWrapper([URL_LIST[split]]) + online_dp = OnlineReader(url_dp) + return online_dp.readlines().map(fn=_hash_urls) + + +def _get_split_list(split: str): + url_dp = IterableWrapper([URL_LIST[split]]) + online_dp = OnlineReader(url_dp) + return online_dp.readlines().map(fn=_hash_urls) + + +def _load_stories(root: str, source: str): + story_dp = IterableWrapper([STORIES_LIST[source]]) + cache_compressed_dp = story_dp.on_disk_cache( + filepath_fn=partial(_filepath_fn, root, source), + hash_dict={_filepath_fn(root, source): STORIES_MD5[source]}, + hash_type="md5", + ) + cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) + # TODO: cache the contents of the extracted tar file + cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar() + return cache_decompressed_dp + + +#@_create_dataset_directory(dataset_name=DATASET_NAME) +#@_wrap_split_argument(("train", "val", "test")) +def CNNDM(root: str, split: Union[Tuple[str], str]): + """CNNDM Dataset + + .. warning:: + + Using datapipes is still currently subject to a few caveats. If you wish + to use this dataset with shuffling, multi-processing, or distributed + learning, please see :ref:`this note ` for further + instructions. + + For additional details refer to https://arxiv.org/pdf/1704.04368.pdf + + Number of lines per split: + - train: 287,227 + - val: 13,368 + - test: 11,490 + + Args: + root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') + split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `val`, `test`) + + :returns: DataPipe that yields a tuple of texts containing an article and its abstract (i.e. (article, abstract)) + :rtype: (str, str) + """ + if not is_module_available("torchdata"): + raise ModuleNotFoundError( + "Package `torchdata` not found. 
Please install following instructions at https://github.com/pytorch/data" + ) + + cnn_dp = _load_stories(root, "cnn") + dailymail_dp = _load_stories(root, "dailymail") + data_dp = cnn_dp.concat(dailymail_dp) + data_dp = data_dp.filter(partial(_filter_fn, split)) + return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() + +if __name__ == '__main__': + + out = CNNDM(os.path.expanduser('~/.torchtext/cache'), 'train') + + with Timer(f"initialize dataset"): + ex = iter(out) + + for i in range(2): + with Timer(f"iteration: {i}"): + next(ex) \ No newline at end of file From 92e9216a328d3b4b9b1778fa67d378261a0f17c3 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 5 Jul 2022 22:09:42 +0000 Subject: [PATCH 03/10] cache tar extraction --- torchtext/datasets/cnndm.py | 53 +++++++----- torchtext/datasets/cnndm_v1.py | 149 --------------------------------- 2 files changed, 34 insertions(+), 168 deletions(-) delete mode 100644 torchtext/datasets/cnndm_v1.py diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 9dfbbcc267..1189de6590 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -1,5 +1,6 @@ import hashlib import os +from collections import defaultdict from functools import partial from typing import Union, Tuple from benchmark.utils import Timer @@ -21,9 +22,12 @@ DATASET_NAME = "CNNDM" URL_LIST = { - "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt", - "val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt", - "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt", + "cnn_train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_training_urls.txt", + "cnn_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_validation_urls.txt", + "cnn_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_test_urls.txt", + 'dailymail_train': "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_training_urls.txt", + "dailymail_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_validation_urls.txt", + "dailymail_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_test_urls.txt", } STORIES_LIST = { @@ -40,22 +44,30 @@ _EXTRACTED_FOLDERS = { "cnn": os.path.join("cnn", "stories"), - "daily_mail": os.path.join("dailymail", "stories"), + "dailymail": os.path.join("dailymail", "stories"), } +story_fnames = defaultdict(set) + def _filepath_fn(root: str, source: str, _=None): return os.path.join(root, PATH_LIST[source]) -# this function will be used to cache the contents of the tar file -def _extracted_filepath_fn(root: str, source: str): - return os.path.join(root, _EXTRACTED_FOLDERS[source]) +def _extracted_folder_fn(root: str, source: str, split: str, _=None): + global story_fnames + key = source + '_' + split + story_fnames[key] = set(_get_split_list(source, split)) + filepaths = [os.path.join(root, _EXTRACTED_FOLDERS[source], story) for story in story_fnames[key]] + return filepaths + + +def _extracted_filepath_fn(root: str, source: str, x): + return os.path.join(root, _EXTRACTED_FOLDERS[source], os.path.basename(x)) -def _filter_fn(split, x): - story_fnames = set(_get_split_list(split)) - return os.path.basename(x[0]) in story_fnames +def _filter_fn(source, split, x): + return 
os.path.basename(x[0]) in story_fnames[source + '_' + split] def _hash_urls(s): @@ -71,13 +83,13 @@ def _hash_urls(s): return story_fname -def _get_split_list(split: str): - url_dp = IterableWrapper([URL_LIST[split]]) +def _get_split_list(source:str, split: str): + url_dp = IterableWrapper([URL_LIST[source + '_' + split]]) online_dp = OnlineReader(url_dp) return online_dp.readlines().map(fn=_hash_urls) -def _load_stories(root: str, source: str): +def _load_stories(root: str, source: str, split: str): story_dp = IterableWrapper([STORIES_LIST[source]]) cache_compressed_dp = story_dp.on_disk_cache( filepath_fn=partial(_filepath_fn, root, source), @@ -85,8 +97,13 @@ def _load_stories(root: str, source: str): hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - # TODO: cache the contents of the extracted tar file - cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar() + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_folder_fn, root, source, split)) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, source, split)) + ) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source)) + data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") return cache_decompressed_dp @@ -121,16 +138,14 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data" ) - cnn_dp = _load_stories(root, "cnn") - dailymail_dp = _load_stories(root, "dailymail") + cnn_dp = _load_stories(root, "cnn", split) + dailymail_dp = _load_stories(root, "dailymail", split) data_dp = cnn_dp.concat(dailymail_dp) - data_dp = data_dp.filter(partial(_filter_fn, split)) return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() if __name__ == '__main__': out = CNNDM(os.path.expanduser('~/.torchtext/cache'), 'train') - with Timer(f"initialize dataset"): ex = iter(out) diff --git a/torchtext/datasets/cnndm_v1.py b/torchtext/datasets/cnndm_v1.py deleted file mode 100644 index 912aeec097..0000000000 --- a/torchtext/datasets/cnndm_v1.py +++ /dev/null @@ -1,149 +0,0 @@ -import hashlib -import os -from functools import partial -from typing import Union, Tuple -from benchmark.utils import Timer - -from torchtext._internal.module_utils import is_module_available -from torchtext.data.datasets_utils import ( - _wrap_split_argument, - _create_dataset_directory, -) - -if is_module_available("torchdata"): - from torchdata.datapipes.iter import ( - FileOpener, - IterableWrapper, - OnlineReader, - GDriveReader, - ) - -DATASET_NAME = "CNNDM" - -URL_LIST = { - "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt", - "val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt", - "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt", -} - -STORIES_LIST = { - "cnn": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ", - "dailymail": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs", -} - -PATH_LIST = { - "cnn": "cnn_stories.tgz", - "dailymail": "dailymail_stories.tgz", -} - -STORIES_MD5 = {"cnn": "85ac23a1926a831e8f46a6b8eaf57263", "dailymail": "f9c5f565e8abe86c38bfa4ae8f96fd72"} - 
-_EXTRACTED_FOLDERS = { - "cnn": os.path.join("cnn", "stories"), - "daily_mail": os.path.join("dailymail", "stories"), -} - -story_fnames = None - - -def _filepath_fn(root: str, source: str, _=None): - return os.path.join(root, PATH_LIST[source]) - - -# this function will be used to cache the contents of the tar file -def _extracted_filepath_fn(root: str, source: str): - return os.path.join(root, _EXTRACTED_FOLDERS[source]) - - -def _filter_fn(split, x): - global story_fnames - if not story_fnames: - story_fnames = set(_get_split_list(split)) - return os.path.basename(x[0]) in story_fnames - - -def _hash_urls(s): - """ - Returns story filename as a heximal formated SHA1 hash of the input url string. - Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py - """ - url = s[1] - h = hashlib.sha1() - h.update(url) - url_hash = h.hexdigest() - story_fname = url_hash + ".story" - return story_fname - - -def _get_split_list_cache(split: str): - url_dp = IterableWrapper([URL_LIST[split]]) - online_dp = OnlineReader(url_dp) - return online_dp.readlines().map(fn=_hash_urls) - - -def _get_split_list(split: str): - url_dp = IterableWrapper([URL_LIST[split]]) - online_dp = OnlineReader(url_dp) - return online_dp.readlines().map(fn=_hash_urls) - - -def _load_stories(root: str, source: str): - story_dp = IterableWrapper([STORIES_LIST[source]]) - cache_compressed_dp = story_dp.on_disk_cache( - filepath_fn=partial(_filepath_fn, root, source), - hash_dict={_filepath_fn(root, source): STORIES_MD5[source]}, - hash_type="md5", - ) - cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - # TODO: cache the contents of the extracted tar file - cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar() - return cache_decompressed_dp - - -#@_create_dataset_directory(dataset_name=DATASET_NAME) -#@_wrap_split_argument(("train", "val", "test")) -def CNNDM(root: str, split: Union[Tuple[str], str]): - """CNNDM Dataset - - .. warning:: - - Using datapipes is still currently subject to a few caveats. If you wish - to use this dataset with shuffling, multi-processing, or distributed - learning, please see :ref:`this note ` for further - instructions. - - For additional details refer to https://arxiv.org/pdf/1704.04368.pdf - - Number of lines per split: - - train: 287,227 - - val: 13,368 - - test: 11,490 - - Args: - root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') - split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `val`, `test`) - - :returns: DataPipe that yields a tuple of texts containing an article and its abstract (i.e. (article, abstract)) - :rtype: (str, str) - """ - if not is_module_available("torchdata"): - raise ModuleNotFoundError( - "Package `torchdata` not found. 
Please install following instructions at https://github.com/pytorch/data" - ) - - cnn_dp = _load_stories(root, "cnn") - dailymail_dp = _load_stories(root, "dailymail") - data_dp = cnn_dp.concat(dailymail_dp) - data_dp = data_dp.filter(partial(_filter_fn, split)) - return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() - -if __name__ == '__main__': - - out = CNNDM(os.path.expanduser('~/.torchtext/cache'), 'train') - - with Timer(f"initialize dataset"): - ex = iter(out) - - for i in range(2): - with Timer(f"iteration: {i}"): - next(ex) \ No newline at end of file From 020d953c7aafcaa2f5b81dbae1c6f3f5ef4b8ba9 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 6 Jul 2022 00:04:21 +0000 Subject: [PATCH 04/10] cleaning up --- torchtext/datasets/cnndm.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 1189de6590..ac5f314a30 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -53,7 +53,7 @@ def _filepath_fn(root: str, source: str, _=None): return os.path.join(root, PATH_LIST[source]) - +# called once per tar file, therefore no duplicate processing def _extracted_folder_fn(root: str, source: str, split: str, _=None): global story_fnames key = source + '_' + split @@ -103,12 +103,12 @@ def _load_stories(root: str, source: str, split: str): FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, source, split)) ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source)) - data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") - return cache_decompressed_dp + data_dp = FileOpener(cache_decompressed_dp, mode='b') + return data_dp -#@_create_dataset_directory(dataset_name=DATASET_NAME) -#@_wrap_split_argument(("train", "val", "test")) +@_create_dataset_directory(dataset_name=DATASET_NAME) +@_wrap_split_argument(("train", "val", "test")) def CNNDM(root: str, split: Union[Tuple[str], str]): """CNNDM Dataset @@ -142,13 +142,3 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): dailymail_dp = _load_stories(root, "dailymail", split) data_dp = cnn_dp.concat(dailymail_dp) return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() - -if __name__ == '__main__': - - out = CNNDM(os.path.expanduser('~/.torchtext/cache'), 'train') - with Timer(f"initialize dataset"): - ex = iter(out) - - for i in range(2): - with Timer(f"iteration: {i}"): - next(ex) \ No newline at end of file From 2680e1382f00e626ac68c8bf1b14777b12a50df7 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Thu, 7 Jul 2022 22:27:04 +0000 Subject: [PATCH 05/10] updating unittest --- test/datasets/test_cnndm.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py index 224329a376..669a616b9c 100644 --- a/test/datasets/test_cnndm.py +++ b/test/datasets/test_cnndm.py @@ -41,7 +41,7 @@ def _get_mock_dataset(root_dir): stories.append((txt_file, dataset_line)) seed += 2 - # append stories to correct dataset split, must be in legixographic order of filenames per dataset + # append stories to correct dataset split, must be in lexicographic order of filenames per dataset stories.sort(key=lambda x: x[0]) mocked_data[split] += [t[1] for t in stories] @@ -70,15 +70,14 @@ def tearDownClass(cls): cls.patcher.stop() super().tearDownClass() - def _mock_split_list(split): + def _mock_split_list(source, 
split): story_fnames = [] - for source in ["cnn", "dailymail"]: - for i in range(5): - url = "_".join([source, split, str(i)]) - h = hashlib.sha1() - h.update(url.encode()) - filename = h.hexdigest() + ".story" - story_fnames.append(filename) + for i in range(5): + url = "_".join([source, split, str(i)]) + h = hashlib.sha1() + h.update(url.encode()) + filename = h.hexdigest() + ".story" + story_fnames.append(filename) return story_fnames From 3b84bf0ee42a4f2b7fa241d19972d5a4d1a7028e Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 8 Jul 2022 18:07:44 +0000 Subject: [PATCH 06/10] patching in get_split_list for split testing --- test/datasets/test_cnndm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py index 669a616b9c..951cc5447e 100644 --- a/test/datasets/test_cnndm.py +++ b/test/datasets/test_cnndm.py @@ -91,6 +91,7 @@ def test_cnndm(self, split): self.assertEqual(sample, expected_sample) @parameterized.expand(["train", "val", "test"]) + @patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list) def test_cnndm_split_argument(self, split): dataset1 = CNNDM(root=self.root_dir, split=split) (dataset2,) = CNNDM(root=self.root_dir, split=(split,)) From d5c143f1e73f0533c6d176abbb132b9d5210ee79 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 8 Jul 2022 14:17:01 -0400 Subject: [PATCH 07/10] pre-commit --- torchtext/datasets/cnndm.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index ac5f314a30..5704f35840 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -3,7 +3,6 @@ from collections import defaultdict from functools import partial from typing import Union, Tuple -from benchmark.utils import Timer from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( @@ -25,7 +24,7 @@ "cnn_train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_training_urls.txt", "cnn_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_validation_urls.txt", "cnn_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_test_urls.txt", - 'dailymail_train': "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_training_urls.txt", + "dailymail_train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_training_urls.txt", "dailymail_val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_validation_urls.txt", "dailymail_test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_test_urls.txt", } @@ -53,10 +52,11 @@ def _filepath_fn(root: str, source: str, _=None): return os.path.join(root, PATH_LIST[source]) + # called once per tar file, therefore no duplicate processing def _extracted_folder_fn(root: str, source: str, split: str, _=None): global story_fnames - key = source + '_' + split + key = source + "_" + split story_fnames[key] = set(_get_split_list(source, split)) filepaths = [os.path.join(root, _EXTRACTED_FOLDERS[source], story) for story in story_fnames[key]] return filepaths @@ -67,7 +67,7 @@ def _extracted_filepath_fn(root: str, source: str, x): def _filter_fn(source, split, x): - return os.path.basename(x[0]) in story_fnames[source + '_' + split] + return os.path.basename(x[0]) in story_fnames[source + 
"_" + split] def _hash_urls(s): @@ -83,8 +83,8 @@ def _hash_urls(s): return story_fname -def _get_split_list(source:str, split: str): - url_dp = IterableWrapper([URL_LIST[source + '_' + split]]) +def _get_split_list(source: str, split: str): + url_dp = IterableWrapper([URL_LIST[source + "_" + split]]) online_dp = OnlineReader(url_dp) return online_dp.readlines().map(fn=_hash_urls) @@ -98,12 +98,16 @@ def _load_stories(root: str, source: str, split: str): ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_folder_fn, root, source, split)) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + filepath_fn=partial(_extracted_folder_fn, root, source, split) + ) cache_decompressed_dp = ( FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, source, split)) ) - cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source)) - data_dp = FileOpener(cache_decompressed_dp, mode='b') + cache_decompressed_dp = cache_decompressed_dp.end_caching( + mode="wb", filepath_fn=partial(_extracted_filepath_fn, root, source) + ) + data_dp = FileOpener(cache_decompressed_dp, mode="b") return data_dp From bd5e1b3d8b31a75915e8b86fd7e828842c9cbfc4 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 13 Jul 2022 22:34:06 +0000 Subject: [PATCH 08/10] typcasting input args --- torchtext/datasets/cnndm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 5704f35840..0f1cf6d680 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -62,15 +62,16 @@ def _extracted_folder_fn(root: str, source: str, split: str, _=None): return filepaths -def _extracted_filepath_fn(root: str, source: str, x): +def _extracted_filepath_fn(root: str, source: str, x: str): return os.path.join(root, _EXTRACTED_FOLDERS[source], os.path.basename(x)) -def _filter_fn(source, split, x): +def _filter_fn(source: str, split: str, x: str): + print(type(x)) return os.path.basename(x[0]) in story_fnames[source + "_" + split] -def _hash_urls(s): +def _hash_urls(s: str): """ Returns story filename as a heximal formated SHA1 hash of the input url string. 
Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py From 8cbeb8009ac325846b4b52d9e20f97851b8abe44 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 13 Jul 2022 22:35:51 +0000 Subject: [PATCH 09/10] removing print statement --- torchtext/datasets/cnndm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 0f1cf6d680..6c22da952a 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -67,7 +67,6 @@ def _extracted_filepath_fn(root: str, source: str, x: str): def _filter_fn(source: str, split: str, x: str): - print(type(x)) return os.path.basename(x[0]) in story_fnames[source + "_" + split] From 5889862f387648b87a049048c7ee1aefbda57e1d Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 13 Jul 2022 23:48:35 +0000 Subject: [PATCH 10/10] correcting types for input args --- torchtext/datasets/cnndm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 6c22da952a..cb638e51a5 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -66,11 +66,11 @@ def _extracted_filepath_fn(root: str, source: str, x: str): return os.path.join(root, _EXTRACTED_FOLDERS[source], os.path.basename(x)) -def _filter_fn(source: str, split: str, x: str): +def _filter_fn(source: str, split: str, x: tuple): return os.path.basename(x[0]) in story_fnames[source + "_" + split] -def _hash_urls(s: str): +def _hash_urls(s: tuple): """ Returns story filename as a heximal formated SHA1 hash of the input url string. Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py
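
Note on the series: across these patches, split membership moves from an eagerly
built set passed through partial (patch 01) to a module-level story_fnames dict
populated inside _extracted_folder_fn, so the URL list is fetched and hashed at
most once per (source, split) pair when the extraction cache is cold. The sketch
below is a minimal, self-contained illustration of the filename-derivation step
the patches rely on, using only the Python standard library. The timed helper is
a stand-in for the repo-internal benchmark.utils.Timer that the interim
benchmarking code used (removed in patch 04), and the example.com URLs are
placeholders; neither is part of the torchtext API.

    import hashlib
    import time
    from contextlib import contextmanager

    @contextmanager
    def timed(label: str):
        # minimal stand-in for benchmark.utils.Timer: prints elapsed
        # wall-clock time for the enclosed block
        start = time.perf_counter()
        try:
            yield
        finally:
            print(f"{label}: {time.perf_counter() - start:.4f}s")

    def story_fname(url: bytes) -> str:
        # mirrors _hash_urls above: the story file for an article URL is the
        # hex-encoded SHA1 digest of the raw URL bytes plus a ".story" suffix
        h = hashlib.sha1()
        h.update(url)
        return h.hexdigest() + ".story"

    if __name__ == "__main__":
        with timed("hash 10k urls"):
            names = {story_fname(b"https://example.com/article/%d" % i) for i in range(10_000)}
        print(len(names), "unique story filenames")

Because _filter_fn keys into story_fnames[source + "_" + split], each tar member
is matched only against the set for its own source and split, which is what lets
the on_disk_cache/end_caching pair in _load_stories persist just the extracted
.story files belonging to the requested split.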