From d6a985dab45005df29b2f6b4fcdf19b1bfc5685b Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 14 Jun 2022 15:08:42 +0000 Subject: [PATCH 01/20] Add CNN-DM dataset --- torchtext/datasets/cnndm.py | 71 +++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 torchtext/datasets/cnndm.py diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py new file mode 100644 index 0000000000..925cbc71b7 --- /dev/null +++ b/torchtext/datasets/cnndm.py @@ -0,0 +1,71 @@ +import os +from functools import partial +from typing import Union, Tuple + +from torchtext._internal.module_utils import is_module_available +from torchtext.data.datasets_utils import ( + _wrap_split_argument, + _create_dataset_directory, +) + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, IterableWrapper + from torchtext._download_hooks import GDriveReader + +URL = 'https://drive.google.com/u/0/uc?id=0BzQ6rtO2VN95a0c3TlZCWkl3aU0&export=download' + +MD5 = "3514e4ab21ab99708ef746581762f71b" + +_PATH = "finished_files.zip" + +_EXTRACTED_FILES = { + "train": os.path.join("finished_files", "train.bin"), + "train": os.path.join("finished_files", "val.bin"), + "test": os.path.join("finished_files", "test.bin"), +} + +_EXTRACTED_FILES_MD5 = { + "train": "2b5389df76cba2757e2d70627269dbfe", + "val": "8efa7ac46fc61395d23131ec56c3d9ba", + "test": "c9b01159cdbb9ff81268c7a3d2278705", +} + +DATASET_NAME = "CNNDM" + + +def _filepath_fn(root: str, _=None): + return os.path.join(root, _PATH) + +def _extracted_filepath_fn(root: str, split: str, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + +def _filter_fn(split: str, x): + return _EXTRACTED_FILES[split] in x[0] + + +def CNNDM(root: str, split: Union[Tuple[str], str]): + + url_dp = IterableWrapper([URL]) + + cache_compressed_dp = url_dp.on_disk_cache( + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, + hash_type="md5", + ) + + cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) + + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split)) + ) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) + + data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") + + return data_dp.routed_decode() + +if __name__ == '__main__': + + data = list(CNNDM(os.path.expanduser('~/.torchtext/cache'), 'val')) + print(type(data)) From 56f7ac464bd2d5c5d4b34e9506ec719e66d86913 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Wed, 15 Jun 2022 22:25:08 +0000 Subject: [PATCH 02/20] Load url_list and stories --- torchtext/datasets/cnndm.py | 199 ++++++++++++++++++++++++++++++------ 1 file changed, 167 insertions(+), 32 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 925cbc71b7..352a097a37 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -1,4 +1,9 @@ +import sys import os +import hashlib +import struct +import subprocess +import collections from functools import partial from typing import Union, Tuple @@ -9,63 +14,193 @@ ) if is_module_available("torchdata"): - from torchdata.datapipes.iter import FileOpener, IterableWrapper - from torchtext._download_hooks import GDriveReader + from torchdata.datapipes.iter import FileOpener, IterableWrapper, StreamReader, 
OnlineReader, FileLister, GDriveReader + from torchtext._download_hooks import HttpReader -URL = 'https://drive.google.com/u/0/uc?id=0BzQ6rtO2VN95a0c3TlZCWkl3aU0&export=download' -MD5 = "3514e4ab21ab99708ef746581762f71b" +dm_single_close_quote = u'\u2019' # unicode +dm_double_close_quote = u'\u201d' +END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence +SENTENCE_START = '' +SENTENCE_END = '' -_PATH = "finished_files.zip" +URL_LIST = { + 'train': 'https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt', + 'val': 'https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt', + 'test': 'https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt', +} -_EXTRACTED_FILES = { - "train": os.path.join("finished_files", "train.bin"), - "train": os.path.join("finished_files", "val.bin"), - "test": os.path.join("finished_files", "test.bin"), +URL_LIST_MD5 = { + 'train': 'c8ca98cfcb6cf3f99a404552568490bc', + 'val': '83a3c483b3ed38b1392285bed668bfee', + 'test': '4f3ac04669934dbc746b7061e68a0258', +} + +STORIES_LIST = { + 'cnn': 'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ', + 'dailymail': 'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs', +} + +PATH_LIST = { + 'cnn': "cnn_stories.tgz", + 'dailymail': "dailymail_stories.tgz", } -_EXTRACTED_FILES_MD5 = { - "train": "2b5389df76cba2757e2d70627269dbfe", - "val": "8efa7ac46fc61395d23131ec56c3d9ba", - "test": "c9b01159cdbb9ff81268c7a3d2278705", +STORIES_MD5 = { + 'cnn': '85ac23a1926a831e8f46a6b8eaf57263', + 'dailymail': 'f9c5f565e8abe86c38bfa4ae8f96fd72' } -DATASET_NAME = "CNNDM" +_EXTRACTED_FILES = { + "cnn": os.path.join("cnn_stories.tgz", "cnn", "stories"), + "daily_mail": os.path.join("dailymail_stories.tgz", "dailymail", "stories"), +} +def _filepath_fn(root: str, source: str, _=None): + return os.path.join(root, PATH_LIST[source]) -def _filepath_fn(root: str, _=None): - return os.path.join(root, _PATH) -def _extracted_filepath_fn(root: str, split: str, _=None): - return os.path.join(root, _EXTRACTED_FILES[split]) +def _extracted_filepath_fn(root: str, source: str, _=None): + return os.path.join(root, _EXTRACTED_FILES[source]) -def _filter_fn(split: str, x): - return _EXTRACTED_FILES[split] in x[0] +def _modify_res(t): + return t[1] + + +def _get_url_list(split:str): + + url_dp = IterableWrapper([URL_LIST[split]]) + online_dp = OnlineReader(url_dp) + return online_dp.readlines().map(_modify_res) -def CNNDM(root: str, split: Union[Tuple[str], str]): - url_dp = IterableWrapper([URL]) +def _get_stories(root:str, source: str): + + story_dp = IterableWrapper([STORIES_LIST[source]]) - cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=partial(_filepath_fn, root), - hash_dict={_filepath_fn(root): MD5}, + cache_compressed_dp = story_dp.on_disk_cache( + filepath_fn=partial(_filepath_fn, root, source), + hash_dict={_filepath_fn(root, source): STORIES_MD5[source]}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, source)) cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split)) 
+ FileOpener(cache_decompressed_dp, mode="b").load_from_tar() ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") + stories = FileLister(cache_compressed_dp) + stories = FileOpener(stories, mode="b") + stories = stories.load_from_tar() + + stories_dict = {} + + for filename, stream in stories: + stories_dict[filename] = stream + + return stories_dict + + +def _hashhex(s): + """Returns a heximal formated SHA1 hash of the input string.""" + h = hashlib.sha1() + h.update(s) + return h.hexdigest() + + +def _get_url_hashes(url_list): + return [_hashhex(url) for url in url_list] + + +def _read_text_file(text_file): + + lines = [] + with open(text_file, "r") as f: + for line in f: + lines.append(line.strip()) - return data_dp.routed_decode() + return lines + + +def _fix_missing_period(line): + """Adds a period to a line that is missing a period""" + if "@highlight" in line: return line + if line=="": return line + if line[-1] in END_TOKENS: return line + # print line[-1] + return line + " ." + + +def _get_art_abs(story_file): + #lines = _read_text_file(story_file) + lines = story_file.readlines() + # Lowercase everything + lines = [line.decode().lower() for line in lines] + + # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences) + lines = [_fix_missing_period(line) for line in lines] + + # Separate out article and abstract sentences + article_lines = [] + highlights = [] + next_is_highlight = False + for idx,line in enumerate(lines): + if line == "": + continue # empty line + elif line.startswith("@highlight"): + next_is_highlight = True + elif next_is_highlight: + highlights.append(line) + else: + article_lines.append(line) + + # Make article into a single string + article = ' '.join(article_lines) + + # Make abstract into a signle string, putting and tags around the sentences + abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) + + return article, abstract + + +def CNNDM(root: str, split: Union[Tuple[str], str]): + + urls = list(_get_url_list(split)) + + cnn_stories = _get_stories(root, 'cnn') + dm_stories = _get_stories(root, 'dailymail') + + url_hashes = _get_url_hashes(urls) + story_fnames = [s+".story" for s in url_hashes] + num_stories = len(story_fnames) + + for story in story_fnames: + + if os.path.join(root, _EXTRACTED_FILES['cnn'], story) in cnn_stories: + story_file = cnn_stories[os.path.join(root, _EXTRACTED_FILES['cnn'], story)] + elif os.path.join(root, _EXTRACTED_FILES['dailymail'], story) in cnn_stories: + story_file = dm_stories[os.path.join(root, _EXTRACTED_FILES['cnn'], story)] + else: + print(f"Error: Couldn't find story file {story} in either cnn or dailymail directories. 
Was there an error when loading the files?") + + article, abstract = _get_art_abs(story_file) + print(f"article: {article}\n\n") + print(f"abstract: {abstract}") + break + + return cnn_stories if __name__ == '__main__': - data = list(CNNDM(os.path.expanduser('~/.torchtext/cache'), 'val')) - print(type(data)) + print("start") + out = CNNDM(os.path.expanduser('~/.torchtext/cache'), 'train') + #print(out['/data/home/pabbo/.torchtext/cache/dailymail_stories.tgz/dailymail/stories/70afd5d444a63b17ae663b1ff5b13d4fe1d507f6.story'].read().decode()) + #print(out['/data/home/pabbo/.torchtext/cache/cnn_stories.tgz/cnn/stories/ee8871b15c50d0db17b0179a6d2beab35065f1e9.story'].read().decode()) + #print(out[:10]) + + + + From 9f53689384b7bc182a92d4413b5aaaa2e6213c89 Mon Sep 17 00:00:00 2001 From: Nayef Ahmed Date: Thu, 16 Jun 2022 11:32:28 -0700 Subject: [PATCH 03/20] Convert cnndm output to datapipes --- torchtext/datasets/cnndm.py | 194 +++++++++++++++++++----------------- 1 file changed, 102 insertions(+), 92 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 352a097a37..aafedfb241 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -1,9 +1,9 @@ -import sys -import os +import collections import hashlib +import os import struct import subprocess -import collections +import sys from functools import partial from typing import Union, Tuple @@ -14,71 +14,88 @@ ) if is_module_available("torchdata"): - from torchdata.datapipes.iter import FileOpener, IterableWrapper, StreamReader, OnlineReader, FileLister, GDriveReader + from torchdata.datapipes.iter import ( + FileOpener, + IterableWrapper, + StreamReader, + OnlineReader, + FileLister, + GDriveReader, + ) from torchtext._download_hooks import HttpReader -dm_single_close_quote = u'\u2019' # unicode -dm_double_close_quote = u'\u201d' -END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence -SENTENCE_START = '' -SENTENCE_END = '' +dm_single_close_quote = "\u2019" # unicode +dm_double_close_quote = "\u201d" +END_TOKENS = [ + ".", + "!", + "?", + "...", + "'", + "`", + '"', + dm_single_close_quote, + dm_double_close_quote, + ")", +] # acceptable ways to end a sentence +SENTENCE_START = "" +SENTENCE_END = "" + +DATASET_NAME = "CNNDM" URL_LIST = { - 'train': 'https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt', - 'val': 'https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt', - 'test': 'https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt', + "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt", + "val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt", + "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt", } URL_LIST_MD5 = { - 'train': 'c8ca98cfcb6cf3f99a404552568490bc', - 'val': '83a3c483b3ed38b1392285bed668bfee', - 'test': '4f3ac04669934dbc746b7061e68a0258', + "train": "c8ca98cfcb6cf3f99a404552568490bc", + "val": "83a3c483b3ed38b1392285bed668bfee", + "test": "4f3ac04669934dbc746b7061e68a0258", } STORIES_LIST = { - 'cnn': 'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ', - 'dailymail': 'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs', + "cnn": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ", + "dailymail": 
"https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs", } PATH_LIST = { - 'cnn': "cnn_stories.tgz", - 'dailymail': "dailymail_stories.tgz", + "cnn": "cnn_stories.tgz", + "dailymail": "dailymail_stories.tgz", } -STORIES_MD5 = { - 'cnn': '85ac23a1926a831e8f46a6b8eaf57263', - 'dailymail': 'f9c5f565e8abe86c38bfa4ae8f96fd72' -} +STORIES_MD5 = {"cnn": "85ac23a1926a831e8f46a6b8eaf57263", "dailymail": "f9c5f565e8abe86c38bfa4ae8f96fd72"} -_EXTRACTED_FILES = { - "cnn": os.path.join("cnn_stories.tgz", "cnn", "stories"), - "daily_mail": os.path.join("dailymail_stories.tgz", "dailymail", "stories"), +_EXTRACTED_FOLDERS = { + "cnn": os.path.join("cnn", "stories"), + "daily_mail": os.path.join("dailymail", "stories"), } + def _filepath_fn(root: str, source: str, _=None): return os.path.join(root, PATH_LIST[source]) -def _extracted_filepath_fn(root: str, source: str, _=None): - return os.path.join(root, _EXTRACTED_FILES[source]) +def _extracted_file_path_fn(root: str, source: str, t): + return os.path.join(root, _EXTRACTED_FOLDERS[source]) def _modify_res(t): return t[1] -def _get_url_list(split:str): +def _get_url_list(split: str): url_dp = IterableWrapper([URL_LIST[split]]) online_dp = OnlineReader(url_dp) return online_dp.readlines().map(_modify_res) -def _get_stories(root:str, source: str): - +def _get_stories(root: str, source: str): story_dp = IterableWrapper([STORIES_LIST[source]]) - + cache_compressed_dp = story_dp.on_disk_cache( filepath_fn=partial(_filepath_fn, root, source), hash_dict={_filepath_fn(root, source): STORIES_MD5[source]}, @@ -86,22 +103,13 @@ def _get_stories(root:str, source: str): ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, source)) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar() - ) - cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - - stories = FileLister(cache_compressed_dp) - stories = FileOpener(stories, mode="b") - stories = stories.load_from_tar() - - stories_dict = {} - - for filename, stream in stories: - stories_dict[filename] = stream - - return stories_dict + # cache_decompressed_dp = cache_compressed_dp.on_disk_cache() + cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar() + # .map(lambda t: (os.path.join(root, source, "stories", os.path.basename(t[0])), t[1])) + # cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=False) + stories = cache_decompressed_dp + stories = stories.map(lambda t: _get_art_abs(t[1])) + return stories def _hashhex(s): @@ -115,27 +123,19 @@ def _get_url_hashes(url_list): return [_hashhex(url) for url in url_list] -def _read_text_file(text_file): - - lines = [] - with open(text_file, "r") as f: - for line in f: - lines.append(line.strip()) - - return lines - - def _fix_missing_period(line): """Adds a period to a line that is missing a period""" - if "@highlight" in line: return line - if line=="": return line - if line[-1] in END_TOKENS: return line + if "@highlight" in line: + return line + if line == "": + return line + if line[-1] in END_TOKENS: + return line # print line[-1] return line + " ." 
def _get_art_abs(story_file): - #lines = _read_text_file(story_file) lines = story_file.readlines() # Lowercase everything lines = [line.decode().lower() for line in lines] @@ -147,9 +147,9 @@ def _get_art_abs(story_file): article_lines = [] highlights = [] next_is_highlight = False - for idx,line in enumerate(lines): + for idx, line in enumerate(lines): if line == "": - continue # empty line + continue # empty line elif line.startswith("@highlight"): next_is_highlight = True elif next_is_highlight: @@ -158,49 +158,59 @@ def _get_art_abs(story_file): article_lines.append(line) # Make article into a single string - article = ' '.join(article_lines) + article = " ".join(article_lines) # Make abstract into a signle string, putting and tags around the sentences - abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) + abstract = " ".join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) return article, abstract -def CNNDM(root: str, split: Union[Tuple[str], str]): +def _get_story_files(url_hash): + return url_hash + ".story" - urls = list(_get_url_list(split)) - - cnn_stories = _get_stories(root, 'cnn') - dm_stories = _get_stories(root, 'dailymail') - - url_hashes = _get_url_hashes(urls) - story_fnames = [s+".story" for s in url_hashes] - num_stories = len(story_fnames) - for story in story_fnames: +@_create_dataset_directory(dataset_name=DATASET_NAME) +@_wrap_split_argument(("train", "test")) +def CNNDM(root: str, split: Union[Tuple[str], str]): + urls = list(_get_url_list(split)) - if os.path.join(root, _EXTRACTED_FILES['cnn'], story) in cnn_stories: - story_file = cnn_stories[os.path.join(root, _EXTRACTED_FILES['cnn'], story)] - elif os.path.join(root, _EXTRACTED_FILES['dailymail'], story) in cnn_stories: - story_file = dm_stories[os.path.join(root, _EXTRACTED_FILES['cnn'], story)] - else: - print(f"Error: Couldn't find story file {story} in either cnn or dailymail directories. Was there an error when loading the files?") + cnn_stories = _get_stories(root, "cnn") + # dm_stories = _get_stories(root, 'dailymail') - article, abstract = _get_art_abs(story_file) - print(f"article: {article}\n\n") - print(f"abstract: {abstract}") - break + # Things to figure out + # * combine the contents of cnn/dm tars to get all (filepaths, streams) + # * filter files based on which split we're working with + # * gow to split up these helper functions + # * store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn of the on_disk_cache_dp which caches the files extracted from the tar + # * how to cache the contents of the extracted tar file return cnn_stories -if __name__ == '__main__': + url_hashes = _get_url_hashes(urls) + story_fnames = url_hashes.map(_get_story_files) + # story_fnames = [s+".story" for s in url_hashes] + + # for story in story_fnames: + def _parse_story_file(story): + if os.path.join(root, _EXTRACTED_FILES["cnn"], story) in cnn_stories: + story_file = cnn_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] + elif os.path.join(root, _EXTRACTED_FILES["dailymail"], story) in cnn_stories: + story_file = dm_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] + else: + print( + f"Error: Couldn't find story file {story} in either cnn or dailymail directories. Was there an error when loading the files?" 
+ ) - print("start") - out = CNNDM(os.path.expanduser('~/.torchtext/cache'), 'train') - #print(out['/data/home/pabbo/.torchtext/cache/dailymail_stories.tgz/dailymail/stories/70afd5d444a63b17ae663b1ff5b13d4fe1d507f6.story'].read().decode()) - #print(out['/data/home/pabbo/.torchtext/cache/cnn_stories.tgz/cnn/stories/ee8871b15c50d0db17b0179a6d2beab35065f1e9.story'].read().decode()) - #print(out[:10]) + return _get_art_abs(story_file) + return story_fnames.map(_parse_story_file) +if __name__ == "__main__": + print("start") + out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "train") + # print(out['/data/home/pabbo/.torchtext/cache/dailymail_stories.tgz/dailymail/stories/70afd5d444a63b17ae663b1ff5b13d4fe1d507f6.story'].read().decode()) + # print(out['/data/home/pabbo/.torchtext/cache/cnn_stories.tgz/cnn/stories/ee8871b15c50d0db17b0179a6d2beab35065f1e9.story'].read().decode()) + # print(out[:10]) From b9c816e1958d3f36a0ed92e5cab49bc5e898c386 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Thu, 16 Jun 2022 15:21:50 -0400 Subject: [PATCH 04/20] run pre-commit --- torchtext/datasets/cnndm.py | 136 +++++++++++++++++++----------------- 1 file changed, 71 insertions(+), 65 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 352a097a37..cc51ec7bd9 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -1,61 +1,67 @@ -import sys -import os import hashlib -import struct -import subprocess -import collections +import os from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available -from torchtext.data.datasets_utils import ( - _wrap_split_argument, - _create_dataset_directory, -) if is_module_available("torchdata"): - from torchdata.datapipes.iter import FileOpener, IterableWrapper, StreamReader, OnlineReader, FileLister, GDriveReader - from torchtext._download_hooks import HttpReader + from torchdata.datapipes.iter import ( + FileOpener, + IterableWrapper, + OnlineReader, + FileLister, + GDriveReader, + ) -dm_single_close_quote = u'\u2019' # unicode -dm_double_close_quote = u'\u201d' -END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence -SENTENCE_START = '' -SENTENCE_END = '' +dm_single_close_quote = "\u2019" # unicode +dm_double_close_quote = "\u201d" +END_TOKENS = [ + ".", + "!", + "?", + "...", + "'", + "`", + '"', + dm_single_close_quote, + dm_double_close_quote, + ")", +] # acceptable ways to end a sentence +SENTENCE_START = "" +SENTENCE_END = "" URL_LIST = { - 'train': 'https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt', - 'val': 'https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt', - 'test': 'https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt', + "train": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt", + "val": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt", + "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt", } URL_LIST_MD5 = { - 'train': 'c8ca98cfcb6cf3f99a404552568490bc', - 'val': '83a3c483b3ed38b1392285bed668bfee', - 'test': '4f3ac04669934dbc746b7061e68a0258', + "train": "c8ca98cfcb6cf3f99a404552568490bc", + "val": "83a3c483b3ed38b1392285bed668bfee", + "test": "4f3ac04669934dbc746b7061e68a0258", } STORIES_LIST = { - 'cnn': 
'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ', - 'dailymail': 'https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs', + "cnn": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ", + "dailymail": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs", } PATH_LIST = { - 'cnn': "cnn_stories.tgz", - 'dailymail': "dailymail_stories.tgz", + "cnn": "cnn_stories.tgz", + "dailymail": "dailymail_stories.tgz", } -STORIES_MD5 = { - 'cnn': '85ac23a1926a831e8f46a6b8eaf57263', - 'dailymail': 'f9c5f565e8abe86c38bfa4ae8f96fd72' -} +STORIES_MD5 = {"cnn": "85ac23a1926a831e8f46a6b8eaf57263", "dailymail": "f9c5f565e8abe86c38bfa4ae8f96fd72"} _EXTRACTED_FILES = { "cnn": os.path.join("cnn_stories.tgz", "cnn", "stories"), "daily_mail": os.path.join("dailymail_stories.tgz", "dailymail", "stories"), } + def _filepath_fn(root: str, source: str, _=None): return os.path.join(root, PATH_LIST[source]) @@ -68,17 +74,17 @@ def _modify_res(t): return t[1] -def _get_url_list(split:str): +def _get_url_list(split: str): url_dp = IterableWrapper([URL_LIST[split]]) online_dp = OnlineReader(url_dp) return online_dp.readlines().map(_modify_res) -def _get_stories(root:str, source: str): +def _get_stories(root: str, source: str): story_dp = IterableWrapper([STORIES_LIST[source]]) - + cache_compressed_dp = story_dp.on_disk_cache( filepath_fn=partial(_filepath_fn, root, source), hash_dict={_filepath_fn(root, source): STORIES_MD5[source]}, @@ -87,9 +93,7 @@ def _get_stories(root:str, source: str): cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, source)) - cache_decompressed_dp = ( - FileOpener(cache_decompressed_dp, mode="b").load_from_tar() - ) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) stories = FileLister(cache_compressed_dp) @@ -100,7 +104,7 @@ def _get_stories(root:str, source: str): for filename, stream in stories: stories_dict[filename] = stream - + return stories_dict @@ -116,26 +120,29 @@ def _get_url_hashes(url_list): def _read_text_file(text_file): - + lines = [] with open(text_file, "r") as f: for line in f: lines.append(line.strip()) - + return lines def _fix_missing_period(line): """Adds a period to a line that is missing a period""" - if "@highlight" in line: return line - if line=="": return line - if line[-1] in END_TOKENS: return line + if "@highlight" in line: + return line + if line == "": + return line + if line[-1] in END_TOKENS: + return line # print line[-1] return line + " ." 
def _get_art_abs(story_file): - #lines = _read_text_file(story_file) + # lines = _read_text_file(story_file) lines = story_file.readlines() # Lowercase everything lines = [line.decode().lower() for line in lines] @@ -147,9 +154,9 @@ def _get_art_abs(story_file): article_lines = [] highlights = [] next_is_highlight = False - for idx,line in enumerate(lines): + for idx, line in enumerate(lines): if line == "": - continue # empty line + continue # empty line elif line.startswith("@highlight"): next_is_highlight = True elif next_is_highlight: @@ -158,10 +165,10 @@ def _get_art_abs(story_file): article_lines.append(line) # Make article into a single string - article = ' '.join(article_lines) + article = " ".join(article_lines) # Make abstract into a signle string, putting and tags around the sentences - abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) + abstract = " ".join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) return article, abstract @@ -170,21 +177,23 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): urls = list(_get_url_list(split)) - cnn_stories = _get_stories(root, 'cnn') - dm_stories = _get_stories(root, 'dailymail') + cnn_stories = _get_stories(root, "cnn") + dm_stories = _get_stories(root, "dailymail") url_hashes = _get_url_hashes(urls) - story_fnames = [s+".story" for s in url_hashes] - num_stories = len(story_fnames) + story_fnames = [s + ".story" for s in url_hashes] + # num_stories = len(story_fnames) for story in story_fnames: - if os.path.join(root, _EXTRACTED_FILES['cnn'], story) in cnn_stories: - story_file = cnn_stories[os.path.join(root, _EXTRACTED_FILES['cnn'], story)] - elif os.path.join(root, _EXTRACTED_FILES['dailymail'], story) in cnn_stories: - story_file = dm_stories[os.path.join(root, _EXTRACTED_FILES['cnn'], story)] + if os.path.join(root, _EXTRACTED_FILES["cnn"], story) in cnn_stories: + story_file = cnn_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] + elif os.path.join(root, _EXTRACTED_FILES["dailymail"], story) in cnn_stories: + story_file = dm_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] else: - print(f"Error: Couldn't find story file {story} in either cnn or dailymail directories. Was there an error when loading the files?") + print( + f"Error: Couldn't find story file {story} in either cnn or dailymail directories. Was there an error when loading the files?" 
+ ) article, abstract = _get_art_abs(story_file) print(f"article: {article}\n\n") @@ -193,14 +202,11 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): return cnn_stories -if __name__ == '__main__': - - print("start") - out = CNNDM(os.path.expanduser('~/.torchtext/cache'), 'train') - #print(out['/data/home/pabbo/.torchtext/cache/dailymail_stories.tgz/dailymail/stories/70afd5d444a63b17ae663b1ff5b13d4fe1d507f6.story'].read().decode()) - #print(out['/data/home/pabbo/.torchtext/cache/cnn_stories.tgz/cnn/stories/ee8871b15c50d0db17b0179a6d2beab35065f1e9.story'].read().decode()) - #print(out[:10]) - - +if __name__ == "__main__": + print("start") + out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "train") + # print(out['/data/home/pabbo/.torchtext/cache/dailymail_stories.tgz/dailymail/stories/70afd5d444a63b17ae663b1ff5b13d4fe1d507f6.story'].read().decode()) + # print(out['/data/home/pabbo/.torchtext/cache/cnn_stories.tgz/cnn/stories/ee8871b15c50d0db17b0179a6d2beab35065f1e9.story'].read().decode()) + # print(out[:10]) From a46b28c1fa3740e75a4861b9c97571d10eda9ce1 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Thu, 16 Jun 2022 22:44:51 +0000 Subject: [PATCH 05/20] Load and filter tar files --- torchtext/datasets/cnndm.py | 117 ++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 52 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index ffe9e0ab05..47d0c0e0bc 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -8,6 +8,10 @@ from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available +from torchtext.data.datasets_utils import ( + _wrap_split_argument, + _create_dataset_directory, +) if is_module_available("torchdata"): from torchdata.datapipes.iter import ( @@ -33,6 +37,7 @@ dm_single_close_quote, dm_double_close_quote, ")", + "\n", ] # acceptable ways to end a sentence SENTENCE_START = "" SENTENCE_END = "" @@ -80,6 +85,10 @@ def _modify_res(t): return t[1] +def _filter_fn(story_fnames, x): + return os.path.basename(x[0]) in story_fnames + + def _get_url_list(split: str): url_dp = IterableWrapper([URL_LIST[split]]) @@ -87,25 +96,6 @@ def _get_url_list(split: str): return online_dp.readlines().map(_modify_res) -def _get_stories(root: str, source: str): - story_dp = IterableWrapper([STORIES_LIST[source]]) - - cache_compressed_dp = story_dp.on_disk_cache( - filepath_fn=partial(_filepath_fn, root, source), - hash_dict={_filepath_fn(root, source): STORIES_MD5[source]}, - hash_type="md5", - ) - - cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - # cache_decompressed_dp = cache_compressed_dp.on_disk_cache() - cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar() - # .map(lambda t: (os.path.join(root, source, "stories", os.path.basename(t[0])), t[1])) - # cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=False) - stories = cache_decompressed_dp - stories = stories.map(lambda t: _get_art_abs(t[1])) - return stories - - def _hashhex(s): """Returns a heximal formated SHA1 hash of the input string.""" h = hashlib.sha1() @@ -134,7 +124,8 @@ def _get_art_abs(story_file): # Lowercase everything lines = [line.decode().lower() for line in lines] - # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences) + # Put periods on the 
ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; + # consequently they end up in the body of the article as run-on sentences) lines = [_fix_missing_period(line) for line in lines] # Separate out article and abstract sentences @@ -154,57 +145,79 @@ def _get_art_abs(story_file): # Make article into a single string article = " ".join(article_lines) - # Make abstract into a signle string, putting and tags around the sentences + # Make abstract into a single string, putting and tags around the sentences abstract = " ".join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) return article, abstract + +def _load_stories(root: str, source:str): + + story_dp = IterableWrapper([STORIES_LIST[source]]) -def _get_story_files(url_hash): - return url_hash + ".story" + cache_compressed_dp = story_dp.on_disk_cache( + filepath_fn=partial(_filepath_fn, root, source), + hash_dict={_filepath_fn(root, source): STORIES_MD5[source]}, + hash_type="md5", + ) + + cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) + # TODO: cache the extraction + cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar() + + return cache_decompressed_dp + +# commented out because currently not being used +# def _get_story_files(url_hash): +# return url_hash + ".story" @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def CNNDM(root: str, split: Union[Tuple[str], str]): - urls = list(_get_url_list(split)) - - cnn_stories = _get_stories(root, "cnn") - # dm_stories = _get_stories(root, 'dailymail') - # Things to figure out - # * combine the contents of cnn/dm tars to get all (filepaths, streams) - # * filter files based on which split we're working with - # * gow to split up these helper functions # * store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn of the on_disk_cache_dp which caches the files extracted from the tar # * how to cache the contents of the extracted tar file - return cnn_stories + # commented out because currently not being used + # url_hashes = _get_url_hashes(urls) + # story_fnames = url_hashes.map(_get_story_files) + + # def _parse_story_file(story): + # if os.path.join(root, _EXTRACTED_FILES["cnn"], story) in cnn_stories: + # story_file = cnn_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] + # elif os.path.join(root, _EXTRACTED_FILES["dailymail"], story) in cnn_stories: + # story_file = dm_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] + # else: + # print( + # f"Error: Couldn't find story file {story} in either cnn or dailymail directories. Was there an error when loading the files?" 
+ # ) + + # return _get_art_abs(story_file) + # return story_fnames.map(_parse_story_file) + + # TODO: store story_fnames on disk + urls = list(_get_url_list(split)) url_hashes = _get_url_hashes(urls) - story_fnames = url_hashes.map(_get_story_files) - # story_fnames = [s+".story" for s in url_hashes] - - # for story in story_fnames: - def _parse_story_file(story): - if os.path.join(root, _EXTRACTED_FILES["cnn"], story) in cnn_stories: - story_file = cnn_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] - elif os.path.join(root, _EXTRACTED_FILES["dailymail"], story) in cnn_stories: - story_file = dm_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] - else: - print( - f"Error: Couldn't find story file {story} in either cnn or dailymail directories. Was there an error when loading the files?" - ) + story_fnames = set(s+".story" for s in url_hashes) - return _get_art_abs(story_file) - return story_fnames.map(_parse_story_file) + cnn_dp = _load_stories(root, 'cnn') + dailymail_dp = _load_stories(root, 'dailymail') + data_dp = cnn_dp.concat(dailymail_dp) + + data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) + data_dp = data_dp.map(lambda t: _get_art_abs(t[1])) + + return data_dp.shuffle().set_shuffle(False).sharding_filter() if __name__ == "__main__": - print("start") - out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "train") - # print(out['/data/home/pabbo/.torchtext/cache/dailymail_stories.tgz/dailymail/stories/70afd5d444a63b17ae663b1ff5b13d4fe1d507f6.story'].read().decode()) - # print(out['/data/home/pabbo/.torchtext/cache/cnn_stories.tgz/cnn/stories/ee8871b15c50d0db17b0179a6d2beab35065f1e9.story'].read().decode()) - # print(out[:10]) + #out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "val") + #ex = iter(out) + #ex = next(ex) + + #print(f"article:\n{ex[0]}") + #print(f"abstract:\n{ex[1]}") From 5588c165b3cc95c14f7c87bb6493d8e9f9d716ce Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Thu, 16 Jun 2022 18:58:32 -0400 Subject: [PATCH 06/20] pre-commit --- torchtext/datasets/cnndm.py | 40 ++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 47d0c0e0bc..939112af7d 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -1,9 +1,5 @@ -import collections import hashlib import os -import struct -import subprocess -import sys from functools import partial from typing import Union, Tuple @@ -18,10 +14,8 @@ FileOpener, IterableWrapper, OnlineReader, - FileLister, GDriveReader, ) - from torchtext._download_hooks import HttpReader dm_single_close_quote = "\u2019" # unicode @@ -124,7 +118,7 @@ def _get_art_abs(story_file): # Lowercase everything lines = [line.decode().lower() for line in lines] - # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; + # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; # consequently they end up in the body of the article as run-on sentences) lines = [_fix_missing_period(line) for line in lines] @@ -149,10 +143,10 @@ def _get_art_abs(story_file): abstract = " ".join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) return article, abstract - -def _load_stories(root: str, source:str): - + +def _load_stories(root: str, source: str): + story_dp = IterableWrapper([STORIES_LIST[source]]) cache_compressed_dp = 
story_dp.on_disk_cache( @@ -167,6 +161,7 @@ def _load_stories(root: str, source:str): return cache_decompressed_dp + # commented out because currently not being used # def _get_story_files(url_hash): # return url_hash + ".story" @@ -196,28 +191,27 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): # return _get_art_abs(story_file) # return story_fnames.map(_parse_story_file) - + # TODO: store story_fnames on disk urls = list(_get_url_list(split)) url_hashes = _get_url_hashes(urls) - story_fnames = set(s+".story" for s in url_hashes) + story_fnames = set(s + ".story" for s in url_hashes) - - cnn_dp = _load_stories(root, 'cnn') - dailymail_dp = _load_stories(root, 'dailymail') + cnn_dp = _load_stories(root, "cnn") + dailymail_dp = _load_stories(root, "dailymail") data_dp = cnn_dp.concat(dailymail_dp) - + data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) data_dp = data_dp.map(lambda t: _get_art_abs(t[1])) - + return data_dp.shuffle().set_shuffle(False).sharding_filter() if __name__ == "__main__": + print("start") + # out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "val") + # ex = iter(out) + # ex = next(ex) - #out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "val") - #ex = iter(out) - #ex = next(ex) - - #print(f"article:\n{ex[0]}") - #print(f"abstract:\n{ex[1]}") + # print(f"article:\n{ex[0]}") + # print(f"abstract:\n{ex[1]}") From be5a2449dc9e794e3ea85973b964783f9a5428d6 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 17 Jun 2022 14:55:24 +0000 Subject: [PATCH 07/20] Create datapipe to parse out article and abstract --- torchtext/data/datasets_utils.py | 70 ++++++++++++++++++ torchtext/datasets/cnndm.py | 120 +++---------------------------- 2 files changed, 81 insertions(+), 109 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 0c39751176..65893a08d0 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -310,3 +310,73 @@ def __iter__(self): columns[i].append(column) if len(columns) > 0: yield columns + +@functional_datapipe("parse_cnndm") +class _ParseCNNDMData(IterDataPipe): + """Iterable DataPipe to parse the article and abstract from a stream""" + + dm_single_close_quote = "\u2019" # unicode + dm_double_close_quote = "\u201d" + END_TOKENS = [ + ".", + "!", + "?", + "...", + "'", + "`", + '"', + dm_single_close_quote, + dm_double_close_quote, + ")", + "\n", + ] # acceptable ways to end a sentence + SENTENCE_START = "" + SENTENCE_END = "" + + def __init__(self, source_datapipe) -> None: + self.source_datapipe = source_datapipe + + def _fix_missing_period(self, line): + """Adds a period to a line that is missing a period""" + if "@highlight" in line: + return line + if line == "": + return line + if line[-1] in self.END_TOKENS: + return line + # print line[-1] + return line + " ." 
+ + def __iter__(self): + + for _, stream in self.source_datapipe: + + lines = stream.readlines() + # Lowercase everything + lines = [line.decode().lower() for line in lines] + + # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; + # consequently they end up in the body of the article as run-on sentences) + lines = [self._fix_missing_period(line) for line in lines] + + # Separate out article and abstract sentences + article_lines = [] + highlights = [] + next_is_highlight = False + for idx, line in enumerate(lines): + if line == "": + continue # empty line + elif line.startswith("@highlight"): + next_is_highlight = True + elif next_is_highlight: + highlights.append(line) + else: + article_lines.append(line) + + # Make article into a single string + article = " ".join(article_lines) + + # Make abstract into a single string, putting and tags around the sentences + abstract = " ".join(["%s %s %s" % (self.SENTENCE_START, sent, self.SENTENCE_END) for sent in highlights]) + + yield article, abstract \ No newline at end of file diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 939112af7d..0757141258 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -17,24 +17,6 @@ GDriveReader, ) - -dm_single_close_quote = "\u2019" # unicode -dm_double_close_quote = "\u201d" -END_TOKENS = [ - ".", - "!", - "?", - "...", - "'", - "`", - '"', - dm_single_close_quote, - dm_double_close_quote, - ")", - "\n", -] # acceptable ways to end a sentence -SENTENCE_START = "" -SENTENCE_END = "" DATASET_NAME = "CNNDM" URL_LIST = { @@ -66,85 +48,33 @@ "daily_mail": os.path.join("dailymail", "stories"), } - def _filepath_fn(root: str, source: str, _=None): return os.path.join(root, PATH_LIST[source]) - -def _extracted_file_path_fn(root: str, source: str, t): +def _extracted_filepath_fn(root: str, source: str, t): return os.path.join(root, _EXTRACTED_FOLDERS[source]) - def _modify_res(t): return t[1] - def _filter_fn(story_fnames, x): return os.path.basename(x[0]) in story_fnames - def _get_url_list(split: str): url_dp = IterableWrapper([URL_LIST[split]]) online_dp = OnlineReader(url_dp) return online_dp.readlines().map(_modify_res) - def _hashhex(s): """Returns a heximal formated SHA1 hash of the input string.""" h = hashlib.sha1() h.update(s) return h.hexdigest() - def _get_url_hashes(url_list): return [_hashhex(url) for url in url_list] - -def _fix_missing_period(line): - """Adds a period to a line that is missing a period""" - if "@highlight" in line: - return line - if line == "": - return line - if line[-1] in END_TOKENS: - return line - # print line[-1] - return line + " ." 
- - -def _get_art_abs(story_file): - lines = story_file.readlines() - # Lowercase everything - lines = [line.decode().lower() for line in lines] - - # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; - # consequently they end up in the body of the article as run-on sentences) - lines = [_fix_missing_period(line) for line in lines] - - # Separate out article and abstract sentences - article_lines = [] - highlights = [] - next_is_highlight = False - for idx, line in enumerate(lines): - if line == "": - continue # empty line - elif line.startswith("@highlight"): - next_is_highlight = True - elif next_is_highlight: - highlights.append(line) - else: - article_lines.append(line) - - # Make article into a single string - article = " ".join(article_lines) - - # Make abstract into a single string, putting and tags around the sentences - abstract = " ".join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) - - return article, abstract - - def _load_stories(root: str, source: str): story_dp = IterableWrapper([STORIES_LIST[source]]) @@ -156,43 +86,16 @@ def _load_stories(root: str, source: str): ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - # TODO: cache the extraction + # TODO: cache the contents of the extracted tar file cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar() return cache_decompressed_dp - -# commented out because currently not being used -# def _get_story_files(url_hash): -# return url_hash + ".story" - - @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(("train", "test")) +@_wrap_split_argument(("train", "val", "test")) def CNNDM(root: str, split: Union[Tuple[str], str]): - # Things to figure out - # * store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn of the on_disk_cache_dp which caches the files extracted from the tar - # * how to cache the contents of the extracted tar file - - # commented out because currently not being used - # url_hashes = _get_url_hashes(urls) - # story_fnames = url_hashes.map(_get_story_files) - - # def _parse_story_file(story): - # if os.path.join(root, _EXTRACTED_FILES["cnn"], story) in cnn_stories: - # story_file = cnn_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] - # elif os.path.join(root, _EXTRACTED_FILES["dailymail"], story) in cnn_stories: - # story_file = dm_stories[os.path.join(root, _EXTRACTED_FILES["cnn"], story)] - # else: - # print( - # f"Error: Couldn't find story file {story} in either cnn or dailymail directories. Was there an error when loading the files?" 
- # ) - - # return _get_art_abs(story_file) - - # return story_fnames.map(_parse_story_file) - - # TODO: store story_fnames on disk + # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn + # of the on_disk_cache_dp which caches the files extracted from the tar urls = list(_get_url_list(split)) url_hashes = _get_url_hashes(urls) story_fnames = set(s + ".story" for s in url_hashes) @@ -202,16 +105,15 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): data_dp = cnn_dp.concat(dailymail_dp) data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) - data_dp = data_dp.map(lambda t: _get_art_abs(t[1])) - return data_dp.shuffle().set_shuffle(False).sharding_filter() + return data_dp.parse_cnndm().shuffle().set_shuffle(False).sharding_filter() if __name__ == "__main__": print("start") - # out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "val") - # ex = iter(out) - # ex = next(ex) + #out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "val") + #ex = iter(out) + #ex = next(ex) - # print(f"article:\n{ex[0]}") - # print(f"abstract:\n{ex[1]}") + #print(f"article:\n{ex[0]}") + #print(f"abstract:\n{ex[1]}") From 2fa90186d96baf3cecc95a75efaad89247e9a792 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 17 Jun 2022 11:03:18 -0400 Subject: [PATCH 08/20] pre-commit run --- torchtext/data/datasets_utils.py | 9 +++++---- torchtext/datasets/cnndm.py | 21 +++++++++++++++------ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 65893a08d0..e3a44605a5 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -311,6 +311,7 @@ def __iter__(self): if len(columns) > 0: yield columns + @functional_datapipe("parse_cnndm") class _ParseCNNDMData(IterDataPipe): """Iterable DataPipe to parse the article and abstract from a stream""" @@ -348,14 +349,14 @@ def _fix_missing_period(self, line): return line + " ." 
def __iter__(self): - + for _, stream in self.source_datapipe: - + lines = stream.readlines() # Lowercase everything lines = [line.decode().lower() for line in lines] - # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; + # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; # consequently they end up in the body of the article as run-on sentences) lines = [self._fix_missing_period(line) for line in lines] @@ -379,4 +380,4 @@ def __iter__(self): # Make abstract into a single string, putting and tags around the sentences abstract = " ".join(["%s %s %s" % (self.SENTENCE_START, sent, self.SENTENCE_END) for sent in highlights]) - yield article, abstract \ No newline at end of file + yield article, abstract diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 0757141258..62159f2e65 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -48,33 +48,41 @@ "daily_mail": os.path.join("dailymail", "stories"), } + def _filepath_fn(root: str, source: str, _=None): return os.path.join(root, PATH_LIST[source]) + def _extracted_filepath_fn(root: str, source: str, t): return os.path.join(root, _EXTRACTED_FOLDERS[source]) + def _modify_res(t): return t[1] + def _filter_fn(story_fnames, x): return os.path.basename(x[0]) in story_fnames + def _get_url_list(split: str): url_dp = IterableWrapper([URL_LIST[split]]) online_dp = OnlineReader(url_dp) return online_dp.readlines().map(_modify_res) + def _hashhex(s): """Returns a heximal formated SHA1 hash of the input string.""" h = hashlib.sha1() h.update(s) return h.hexdigest() + def _get_url_hashes(url_list): return [_hashhex(url) for url in url_list] + def _load_stories(root: str, source: str): story_dp = IterableWrapper([STORIES_LIST[source]]) @@ -91,10 +99,11 @@ def _load_stories(root: str, source: str): return cache_decompressed_dp + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "val", "test")) def CNNDM(root: str, split: Union[Tuple[str], str]): - # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn + # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn # of the on_disk_cache_dp which caches the files extracted from the tar urls = list(_get_url_list(split)) url_hashes = _get_url_hashes(urls) @@ -111,9 +120,9 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): if __name__ == "__main__": print("start") - #out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "val") - #ex = iter(out) - #ex = next(ex) + # out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "val") + # ex = iter(out) + # ex = next(ex) - #print(f"article:\n{ex[0]}") - #print(f"abstract:\n{ex[1]}") + # print(f"article:\n{ex[0]}") + # print(f"abstract:\n{ex[1]}") From e04fe4596e47ecfd01cd3800ca64e619b33c1da8 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 17 Jun 2022 15:35:56 +0000 Subject: [PATCH 09/20] Turn url list helper functions into datapipes --- torchtext/data/datasets_utils.py | 28 +++++++++++++++++++++++-- torchtext/datasets/cnndm.py | 36 ++++---------------------------- 2 files changed, 30 insertions(+), 34 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index e3a44605a5..83a4e52480 100644 --- a/torchtext/data/datasets_utils.py +++ 
b/torchtext/data/datasets_utils.py @@ -1,6 +1,7 @@ import codecs import functools import inspect +import hashlib import os from torch.utils.data import functional_datapipe, IterDataPipe @@ -312,9 +313,9 @@ def __iter__(self): yield columns -@functional_datapipe("parse_cnndm") +@functional_datapipe("parse_cnndm_data") class _ParseCNNDMData(IterDataPipe): - """Iterable DataPipe to parse the article and abstract from a stream""" + """Iterable DataPipe to parse the article and abstract from a CNNDM data stream""" dm_single_close_quote = "\u2019" # unicode dm_double_close_quote = "\u201d" @@ -381,3 +382,26 @@ def __iter__(self): abstract = " ".join(["%s %s %s" % (self.SENTENCE_START, sent, self.SENTENCE_END) for sent in highlights]) yield article, abstract + + +@functional_datapipe("parse_cnndm_split") +class _ParseCNNDMSplit(IterDataPipe): + """Iterable DataPipe to parse the url list for files in the target split""" + + def __init__(self, source_datapipe) -> None: + self.source_datapipe = source_datapipe + + def _hashhex(self, s): + """Returns a heximal formated SHA1 hash of the input string.""" + h = hashlib.sha1() + h.update(s) + return h.hexdigest() + + def __iter__(self): + + for _, url in self.source_datapipe: + + url_hash = self._hashhex(url) + story_fname = url_hash + ".story" + + yield story_fname \ No newline at end of file diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 62159f2e65..8fa3c209c5 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -1,4 +1,3 @@ -import hashlib import os from functools import partial from typing import Union, Tuple @@ -57,30 +56,15 @@ def _extracted_filepath_fn(root: str, source: str, t): return os.path.join(root, _EXTRACTED_FOLDERS[source]) -def _modify_res(t): - return t[1] - - def _filter_fn(story_fnames, x): return os.path.basename(x[0]) in story_fnames -def _get_url_list(split: str): +def _get_split_list(split: str): url_dp = IterableWrapper([URL_LIST[split]]) online_dp = OnlineReader(url_dp) - return online_dp.readlines().map(_modify_res) - - -def _hashhex(s): - """Returns a heximal formated SHA1 hash of the input string.""" - h = hashlib.sha1() - h.update(s) - return h.hexdigest() - - -def _get_url_hashes(url_list): - return [_hashhex(url) for url in url_list] + return online_dp.readlines().parse_cnndm_split() def _load_stories(root: str, source: str): @@ -105,9 +89,7 @@ def _load_stories(root: str, source: str): def CNNDM(root: str, split: Union[Tuple[str], str]): # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn # of the on_disk_cache_dp which caches the files extracted from the tar - urls = list(_get_url_list(split)) - url_hashes = _get_url_hashes(urls) - story_fnames = set(s + ".story" for s in url_hashes) + story_fnames = set(_get_split_list(split)) cnn_dp = _load_stories(root, "cnn") dailymail_dp = _load_stories(root, "dailymail") @@ -115,14 +97,4 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) - return data_dp.parse_cnndm().shuffle().set_shuffle(False).sharding_filter() - - -if __name__ == "__main__": - print("start") - # out = CNNDM(os.path.expanduser("~/.torchtext/cache"), "val") - # ex = iter(out) - # ex = next(ex) - - # print(f"article:\n{ex[0]}") - # print(f"abstract:\n{ex[1]}") + return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() \ No newline at end of file From c9e7ffc4e896ac8447e3336c3670ba0bd1d2d979 Mon Sep 17 00:00:00 
2001 From: pmabbo13 Date: Fri, 17 Jun 2022 11:38:16 -0400 Subject: [PATCH 10/20] pre-commit --- torchtext/data/datasets_utils.py | 4 ++-- torchtext/datasets/cnndm.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 83a4e52480..4f5d20c106 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -1,7 +1,7 @@ import codecs import functools -import inspect import hashlib +import inspect import os from torch.utils.data import functional_datapipe, IterDataPipe @@ -404,4 +404,4 @@ def __iter__(self): url_hash = self._hashhex(url) story_fname = url_hash + ".story" - yield story_fname \ No newline at end of file + yield story_fname diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 8fa3c209c5..c918034777 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -97,4 +97,4 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) - return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() \ No newline at end of file + return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() From b495353d2a57d7f5f72a1eba434ad026ec479c8c Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 17 Jun 2022 20:36:21 +0000 Subject: [PATCH 11/20] Add dataset documentation and nit corrections --- torchtext/data/datasets_utils.py | 35 +++++++++------------- torchtext/datasets/cnndm.py | 50 +++++++++++++++++++++----------- 2 files changed, 46 insertions(+), 39 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 4f5d20c106..fa027c01f8 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -315,10 +315,12 @@ def __iter__(self): @functional_datapipe("parse_cnndm_data") class _ParseCNNDMData(IterDataPipe): - """Iterable DataPipe to parse the article and abstract from a CNNDM data stream""" + """Iterable DataPipe to parse the article and abstract from a CNNDM data stream. + Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py""" dm_single_close_quote = "\u2019" # unicode dm_double_close_quote = "\u201d" + # acceptable ways to end a sentence END_TOKENS = [ ".", "!", @@ -330,10 +332,8 @@ class _ParseCNNDMData(IterDataPipe): dm_single_close_quote, dm_double_close_quote, ")", - "\n", - ] # acceptable ways to end a sentence - SENTENCE_START = "" - SENTENCE_END = "" + "\n" + ] def __init__(self, source_datapipe) -> None: self.source_datapipe = source_datapipe @@ -346,19 +346,16 @@ def _fix_missing_period(self, line): return line if line[-1] in self.END_TOKENS: return line - # print line[-1] return line + " ." 
def __iter__(self): - for _, stream in self.source_datapipe: - lines = stream.readlines() - # Lowercase everything - lines = [line.decode().lower() for line in lines] + lines = [line.decode("utf-8") for line in lines] - # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; - # consequently they end up in the body of the article as run-on sentences) + # put periods on the ends of lines that are missing them + # this is a problem in the dataset because many image captions don't end in periods + # consequently they end up in the body of the article as run-on sentences lines = [self._fix_missing_period(line) for line in lines] # Separate out article and abstract sentences @@ -375,18 +372,15 @@ def __iter__(self): else: article_lines.append(line) - # Make article into a single string article = " ".join(article_lines) - - # Make abstract into a single string, putting and tags around the sentences - abstract = " ".join(["%s %s %s" % (self.SENTENCE_START, sent, self.SENTENCE_END) for sent in highlights]) - + abstract = " ".join(highlights) yield article, abstract @functional_datapipe("parse_cnndm_split") class _ParseCNNDMSplit(IterDataPipe): - """Iterable DataPipe to parse the url list for files in the target split""" + """Iterable DataPipe to parse the url list for files in the target split. + Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py""" def __init__(self, source_datapipe) -> None: self.source_datapipe = source_datapipe @@ -398,10 +392,7 @@ def _hashhex(self, s): return h.hexdigest() def __iter__(self): - for _, url in self.source_datapipe: - url_hash = self._hashhex(url) story_fname = url_hash + ".story" - - yield story_fname + yield story_fname \ No newline at end of file diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index c918034777..8f207bad4e 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -24,12 +24,6 @@ "test": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt", } -URL_LIST_MD5 = { - "train": "c8ca98cfcb6cf3f99a404552568490bc", - "val": "83a3c483b3ed38b1392285bed668bfee", - "test": "4f3ac04669934dbc746b7061e68a0258", -} - STORIES_LIST = { "cnn": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ", "dailymail": "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs", @@ -52,6 +46,7 @@ def _filepath_fn(root: str, source: str, _=None): return os.path.join(root, PATH_LIST[source]) +# this function will be used to cache the contents of the tar file def _extracted_filepath_fn(root: str, source: str, t): return os.path.join(root, _EXTRACTED_FOLDERS[source]) @@ -61,40 +56,61 @@ def _filter_fn(story_fnames, x): def _get_split_list(split: str): - url_dp = IterableWrapper([URL_LIST[split]]) online_dp = OnlineReader(url_dp) return online_dp.readlines().parse_cnndm_split() def _load_stories(root: str, source: str): - story_dp = IterableWrapper([STORIES_LIST[source]]) - cache_compressed_dp = story_dp.on_disk_cache( filepath_fn=partial(_filepath_fn, root, source), hash_dict={_filepath_fn(root, source): STORIES_MD5[source]}, hash_type="md5", ) - cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) # TODO: cache the contents of the extracted tar file cache_decompressed_dp = FileOpener(cache_compressed_dp, mode="b").load_from_tar() - return cache_decompressed_dp 
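# Sketch of what the pipe built by _load_stories yields (an annotation, not part of this
# patch): load_from_tar produces one (member_path, stream) tuple per file in the downloaded
# archive, and nothing is downloaded or extracted until the pipe is iterated. The cache
# path used below is an assumption for illustration only.
#
#     dp = _load_stories(os.path.expanduser("~/.torchtext/cache/CNNDM"), "cnn")
#     fname, stream = next(iter(dp))
#     # fname ends with "cnn/stories/<sha1>.story"; _filter_fn keys on its basename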
@_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "val", "test")) def CNNDM(root: str, split: Union[Tuple[str], str]): - # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn - # of the on_disk_cache_dp which caches the files extracted from the tar - story_fnames = set(_get_split_list(split)) + """CNNDM Dataset + + .. warning:: + + Using datapipes is still currently subject to a few caveats. If you wish + to use this dataset with shuffling, multi-processing, or distributed + learning, please see :ref:`this note ` for further + instructions. + + For additional details refer to https://arxiv.org/pdf/1704.04368.pdf + + Number of lines per split: + - train: 287,227 + - val: 13,368 + - test: 11,490 + + Args: + root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') + split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `val`, `test`) + + :returns: DataPipe that yields a tuple of texts containing an article and its abstract (i.e. (article, abstract)) + :rtype: (str, str) + """ + if not is_module_available("torchdata"): + raise ModuleNotFoundError( + "Package `torchdata` not found. Please install following instructions at https://github.com/pytorch/data" + ) cnn_dp = _load_stories(root, "cnn") dailymail_dp = _load_stories(root, "dailymail") data_dp = cnn_dp.concat(dailymail_dp) - + # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn + # of the on_disk_cache_dp which caches the files extracted from the tar + story_fnames = set(_get_split_list(split)) + print(len(story_fnames)) data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) - - return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() + return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() \ No newline at end of file From 0af32439ef983efa260e9feda0203f399b95d97b Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 17 Jun 2022 20:36:47 +0000 Subject: [PATCH 12/20] Implement testing --- test/datasets/test_cnndm.py | 102 ++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 test/datasets/test_cnndm.py diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py new file mode 100644 index 0000000000..01a7410b4a --- /dev/null +++ b/test/datasets/test_cnndm.py @@ -0,0 +1,102 @@ +import os +import tarfile +import hashlib +from collections import defaultdict +from unittest.mock import patch + +from parameterized import parameterized +from torchtext.datasets.cnndm import CNNDM + +from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode +from ..common.parameterized_utils import nested_params +from ..common.torchtext_test_case import TorchtextTestCase + +def _get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + + base_dir = os.path.join(root_dir, 'CNNDM') + temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") + os.makedirs(temp_dataset_dir, exist_ok=True) + seed = 1 + mocked_data = defaultdict(list) + + for source in ['cnn', 'dailymail']: + source_dir = os.path.join(temp_dataset_dir, source, 'stories') + os.makedirs(source_dir, exist_ok=True) + + for split in ['train', 'val', 'test']: + url = source + '_' + split + h = hashlib.sha1() + h.update(url.encode()) + filename = h.hexdigest() + '.story' + txt_file = os.path.join(source_dir, filename) + + with 
open(txt_file, "w", encoding=("utf-8")) as f: + article = get_random_unicode(seed) + '.' + abstract = get_random_unicode(seed+1) + '.' + dataset_line = (article + '\n', abstract) + f.writelines([article, "\n@highlight\n", abstract]) + # append line to correct dataset split + mocked_data[split].append(dataset_line) + + seed += 2 + + compressed_dataset_path = os.path.join(base_dir, f"{source}_stories.tgz") + # create zip file from dataset folder + with tarfile.open(compressed_dataset_path, "w:gz") as tar: + tar.add(os.path.join(temp_dataset_dir, source), arcname=source) + + return mocked_data + +class TestCNNDM(TempDirMixin, TorchtextTestCase): + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.root_dir = cls.get_base_temp_dir() + cls.samples = _get_mock_dataset(os.path.join(cls.root_dir, "datasets")) + cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) + cls.patcher.start() + + @classmethod + def tearDownClass(cls): + cls.patcher.stop() + super().tearDownClass() + + def _mock_split_list(split): + story_fnames = [] + for source in ['cnn', 'dailymail']: + url = source + '_' + split + h = hashlib.sha1() + h.update(url.encode()) + filename = h.hexdigest() + '.story' + story_fnames.append(filename) + + return story_fnames + + @parameterized.expand(["train", "val", "test"]) + @patch("torchtext.datasets.cnndm._get_split_list", _mock_split_list) + def test_cnndm(self, split): + dataset = CNNDM(root=self.root_dir, split=split) + samples = list(dataset) + expected_samples = self.samples[split] + print(samples) + print(expected_samples) + for sample, expected_sample in zip_equal(samples, expected_samples): + self.assertEqual(sample, expected_sample) + + + + + + + + + + + + From 4342a4eb3a69c1cefc33673035ecfff700800d6d Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 17 Jun 2022 20:40:14 +0000 Subject: [PATCH 13/20] Remove empty lines --- test/datasets/test_cnndm.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py index 01a7410b4a..dd04d10907 100644 --- a/test/datasets/test_cnndm.py +++ b/test/datasets/test_cnndm.py @@ -87,16 +87,4 @@ def test_cnndm(self, split): print(samples) print(expected_samples) for sample, expected_sample in zip_equal(samples, expected_samples): - self.assertEqual(sample, expected_sample) - - - - - - - - - - - - + self.assertEqual(sample, expected_sample) \ No newline at end of file From e9381f9bd7b5fd4484b254c003c4a39c2a8063c3 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 17 Jun 2022 20:42:52 +0000 Subject: [PATCH 14/20] Test split argument --- test/datasets/test_cnndm.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py index dd04d10907..cf176c867f 100644 --- a/test/datasets/test_cnndm.py +++ b/test/datasets/test_cnndm.py @@ -87,4 +87,12 @@ def test_cnndm(self, split): print(samples) print(expected_samples) for sample, expected_sample in zip_equal(samples, expected_samples): - self.assertEqual(sample, expected_sample) \ No newline at end of file + self.assertEqual(sample, expected_sample) + + @parameterized.expand(["train", "val", "test"]) + def test_cnndm_split_argument(self, split): + dataset1 = CNNDM(root=self.root_dir, split=split) + (dataset2,) = CNNDM(root=self.root_dir, split=(split,)) + + for d1, d2 in zip_equal(dataset1, dataset2): + self.assertEqual(d1, d2) \ No newline at end of file From 
4415bb4b0349f2049d555ffde5548b7a0703a933 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 17 Jun 2022 17:11:41 -0400 Subject: [PATCH 15/20] pre-commit --- test/datasets/test_cnndm.py | 41 ++++++++++++++++---------------- torchtext/data/datasets_utils.py | 16 ++----------- torchtext/datasets/cnndm.py | 2 +- 3 files changed, 24 insertions(+), 35 deletions(-) diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py index cf176c867f..e5d90a3412 100644 --- a/test/datasets/test_cnndm.py +++ b/test/datasets/test_cnndm.py @@ -1,6 +1,6 @@ +import hashlib import os import tarfile -import hashlib from collections import defaultdict from unittest.mock import patch @@ -8,48 +8,49 @@ from torchtext.datasets.cnndm import CNNDM from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode -from ..common.parameterized_utils import nested_params from ..common.torchtext_test_case import TorchtextTestCase + def _get_mock_dataset(root_dir): """ root_dir: directory to the mocked dataset """ - base_dir = os.path.join(root_dir, 'CNNDM') + base_dir = os.path.join(root_dir, "CNNDM") temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") os.makedirs(temp_dataset_dir, exist_ok=True) seed = 1 mocked_data = defaultdict(list) - for source in ['cnn', 'dailymail']: - source_dir = os.path.join(temp_dataset_dir, source, 'stories') + for source in ["cnn", "dailymail"]: + source_dir = os.path.join(temp_dataset_dir, source, "stories") os.makedirs(source_dir, exist_ok=True) - - for split in ['train', 'val', 'test']: - url = source + '_' + split + + for split in ["train", "val", "test"]: + url = source + "_" + split h = hashlib.sha1() h.update(url.encode()) - filename = h.hexdigest() + '.story' + filename = h.hexdigest() + ".story" txt_file = os.path.join(source_dir, filename) with open(txt_file, "w", encoding=("utf-8")) as f: - article = get_random_unicode(seed) + '.' - abstract = get_random_unicode(seed+1) + '.' - dataset_line = (article + '\n', abstract) + article = get_random_unicode(seed) + "." + abstract = get_random_unicode(seed + 1) + "." 
+ dataset_line = (article + "\n", abstract) f.writelines([article, "\n@highlight\n", abstract]) # append line to correct dataset split mocked_data[split].append(dataset_line) seed += 2 - + compressed_dataset_path = os.path.join(base_dir, f"{source}_stories.tgz") # create zip file from dataset folder with tarfile.open(compressed_dataset_path, "w:gz") as tar: tar.add(os.path.join(temp_dataset_dir, source), arcname=source) - + return mocked_data + class TestCNNDM(TempDirMixin, TorchtextTestCase): root_dir = None samples = [] @@ -69,13 +70,13 @@ def tearDownClass(cls): def _mock_split_list(split): story_fnames = [] - for source in ['cnn', 'dailymail']: - url = source + '_' + split + for source in ["cnn", "dailymail"]: + url = source + "_" + split h = hashlib.sha1() h.update(url.encode()) - filename = h.hexdigest() + '.story' + filename = h.hexdigest() + ".story" story_fnames.append(filename) - + return story_fnames @parameterized.expand(["train", "val", "test"]) @@ -88,11 +89,11 @@ def test_cnndm(self, split): print(expected_samples) for sample, expected_sample in zip_equal(samples, expected_samples): self.assertEqual(sample, expected_sample) - + @parameterized.expand(["train", "val", "test"]) def test_cnndm_split_argument(self, split): dataset1 = CNNDM(root=self.root_dir, split=split) (dataset2,) = CNNDM(root=self.root_dir, split=(split,)) for d1, d2 in zip_equal(dataset1, dataset2): - self.assertEqual(d1, d2) \ No newline at end of file + self.assertEqual(d1, d2) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index fa027c01f8..906a4cb77e 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -321,19 +321,7 @@ class _ParseCNNDMData(IterDataPipe): dm_single_close_quote = "\u2019" # unicode dm_double_close_quote = "\u201d" # acceptable ways to end a sentence - END_TOKENS = [ - ".", - "!", - "?", - "...", - "'", - "`", - '"', - dm_single_close_quote, - dm_double_close_quote, - ")", - "\n" - ] + END_TOKENS = [".", "!", "?", "...", "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")", "\n"] def __init__(self, source_datapipe) -> None: self.source_datapipe = source_datapipe @@ -395,4 +383,4 @@ def __iter__(self): for _, url in self.source_datapipe: url_hash = self._hashhex(url) story_fname = url_hash + ".story" - yield story_fname \ No newline at end of file + yield story_fname diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 8f207bad4e..e3571ff15f 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -113,4 +113,4 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): story_fnames = set(_get_split_list(split)) print(len(story_fnames)) data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) - return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() \ No newline at end of file + return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() From 9687dae0cdf2cdaa7b0fe8f2ed00cc4af27de9e6 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Fri, 17 Jun 2022 21:47:58 +0000 Subject: [PATCH 16/20] Strip newlines to see if passes windows unittest --- test/datasets/test_cnndm.py | 2 +- torchtext/data/datasets_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py index e5d90a3412..ceec9a2234 100644 --- a/test/datasets/test_cnndm.py +++ b/test/datasets/test_cnndm.py @@ -36,7 +36,7 @@ def _get_mock_dataset(root_dir): with open(txt_file, "w", encoding=("utf-8")) 
as f: article = get_random_unicode(seed) + "." abstract = get_random_unicode(seed + 1) + "." - dataset_line = (article + "\n", abstract) + dataset_line = (article, abstract) f.writelines([article, "\n@highlight\n", abstract]) # append line to correct dataset split mocked_data[split].append(dataset_line) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 906a4cb77e..8f8c3cdb96 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -339,7 +339,7 @@ def _fix_missing_period(self, line): def __iter__(self): for _, stream in self.source_datapipe: lines = stream.readlines() - lines = [line.decode("utf-8") for line in lines] + lines = [line.decode("utf-8").strip() for line in lines] # put periods on the ends of lines that are missing them # this is a problem in the dataset because many image captions don't end in periods From 3eeb81ce2b4b7e536d3bd872ef51accce7864739 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 20 Jun 2022 14:26:59 +0000 Subject: [PATCH 17/20] Remove print statements --- test/datasets/test_cnndm.py | 2 -- torchtext/datasets/cnndm.py | 1 - 2 files changed, 3 deletions(-) diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py index ceec9a2234..58ac958c50 100644 --- a/test/datasets/test_cnndm.py +++ b/test/datasets/test_cnndm.py @@ -85,8 +85,6 @@ def test_cnndm(self, split): dataset = CNNDM(root=self.root_dir, split=split) samples = list(dataset) expected_samples = self.samples[split] - print(samples) - print(expected_samples) for sample, expected_sample in zip_equal(samples, expected_samples): self.assertEqual(sample, expected_sample) diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index e3571ff15f..6216cc9f64 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -111,6 +111,5 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): # TODO: store the .story filenames corresponding to each split on disk so we can pass that into the filepath_fn # of the on_disk_cache_dp which caches the files extracted from the tar story_fnames = set(_get_split_list(split)) - print(len(story_fnames)) data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() From 8b1fd91a67e4d5a7d5809140df1e6ddadc6334a3 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 21 Jun 2022 21:25:41 +0000 Subject: [PATCH 18/20] Add more mock examples --- test/datasets/test_cnndm.py | 45 ++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py index 58ac958c50..dcfcab94d0 100644 --- a/test/datasets/test_cnndm.py +++ b/test/datasets/test_cnndm.py @@ -25,23 +25,25 @@ def _get_mock_dataset(root_dir): for source in ["cnn", "dailymail"]: source_dir = os.path.join(temp_dataset_dir, source, "stories") os.makedirs(source_dir, exist_ok=True) - for split in ["train", "val", "test"]: - url = source + "_" + split - h = hashlib.sha1() - h.update(url.encode()) - filename = h.hexdigest() + ".story" - txt_file = os.path.join(source_dir, filename) - - with open(txt_file, "w", encoding=("utf-8")) as f: - article = get_random_unicode(seed) + "." - abstract = get_random_unicode(seed + 1) + "." 
- dataset_line = (article, abstract) - f.writelines([article, "\n@highlight\n", abstract]) - # append line to correct dataset split - mocked_data[split].append(dataset_line) - - seed += 2 + stories = [] + for i in range(5): + url = '_'.join([source, split, str(i)]) + h = hashlib.sha1() + h.update(url.encode()) + filename = h.hexdigest() + ".story" + txt_file = os.path.join(source_dir, filename) + with open(txt_file, "w", encoding=("utf-8")) as f: + article = get_random_unicode(seed) + "." + abstract = get_random_unicode(seed + 1) + "." + dataset_line = (article, abstract) + f.writelines([article, "\n@highlight\n", abstract]) + stories.append((txt_file, dataset_line)) + seed += 2 + + # append stories to correct dataset split, must be in legixographic order of filenames per dataset + stories.sort(key=lambda x: x[0]) + mocked_data[split] += [t[1] for t in stories] compressed_dataset_path = os.path.join(base_dir, f"{source}_stories.tgz") # create zip file from dataset folder @@ -71,11 +73,12 @@ def tearDownClass(cls): def _mock_split_list(split): story_fnames = [] for source in ["cnn", "dailymail"]: - url = source + "_" + split - h = hashlib.sha1() - h.update(url.encode()) - filename = h.hexdigest() + ".story" - story_fnames.append(filename) + for i in range(5): + url = '_'.join([source, split, str(i)]) + h = hashlib.sha1() + h.update(url.encode()) + filename = h.hexdigest() + ".story" + story_fnames.append(filename) return story_fnames From da45057363674c18d003cc0205a2bd2246875563 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 21 Jun 2022 21:26:28 +0000 Subject: [PATCH 19/20] Turn parse_cnndm_split datapipe into hiddent fxn applied using map --- torchtext/data/datasets_utils.py | 24 +----------------------- torchtext/datasets/cnndm.py | 20 +++++++++++++++++--- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 8f8c3cdb96..efc59352ac 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -1,6 +1,5 @@ import codecs import functools -import hashlib import inspect import os @@ -362,25 +361,4 @@ def __iter__(self): article = " ".join(article_lines) abstract = " ".join(highlights) - yield article, abstract - - -@functional_datapipe("parse_cnndm_split") -class _ParseCNNDMSplit(IterDataPipe): - """Iterable DataPipe to parse the url list for files in the target split. 
- Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py""" - - def __init__(self, source_datapipe) -> None: - self.source_datapipe = source_datapipe - - def _hashhex(self, s): - """Returns a heximal formated SHA1 hash of the input string.""" - h = hashlib.sha1() - h.update(s) - return h.hexdigest() - - def __iter__(self): - for _, url in self.source_datapipe: - url_hash = self._hashhex(url) - story_fname = url_hash + ".story" - yield story_fname + yield article, abstract \ No newline at end of file diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index 6216cc9f64..cfca268f6c 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -1,3 +1,4 @@ +import hashlib import os from functools import partial from typing import Union, Tuple @@ -47,7 +48,7 @@ def _filepath_fn(root: str, source: str, _=None): # this function will be used to cache the contents of the tar file -def _extracted_filepath_fn(root: str, source: str, t): +def _extracted_filepath_fn(root: str, source: str): return os.path.join(root, _EXTRACTED_FOLDERS[source]) @@ -55,10 +56,23 @@ def _filter_fn(story_fnames, x): return os.path.basename(x[0]) in story_fnames +def _hash_urls(s): + """ + Returns story filename as a heximal formated SHA1 hash of the input url string. + Code is inspired from https://github.com/abisee/cnn-dailymail/blob/master/make_datafiles.py + """ + url = s[1] + h = hashlib.sha1() + h.update(url) + url_hash = h.hexdigest() + story_fname = url_hash + ".story" + return story_fname + + def _get_split_list(split: str): url_dp = IterableWrapper([URL_LIST[split]]) online_dp = OnlineReader(url_dp) - return online_dp.readlines().parse_cnndm_split() + return online_dp.readlines().map(fn=_hash_urls) def _load_stories(root: str, source: str): @@ -112,4 +126,4 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): # of the on_disk_cache_dp which caches the files extracted from the tar story_fnames = set(_get_split_list(split)) data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) - return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() + return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() \ No newline at end of file From a5a6a9cc77f9d9207c04888973a73c15c96f95a0 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 21 Jun 2022 17:49:11 -0400 Subject: [PATCH 20/20] pre-commit --- test/datasets/test_cnndm.py | 6 +++--- torchtext/data/datasets_utils.py | 2 +- torchtext/datasets/cnndm.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/datasets/test_cnndm.py b/test/datasets/test_cnndm.py index dcfcab94d0..224329a376 100644 --- a/test/datasets/test_cnndm.py +++ b/test/datasets/test_cnndm.py @@ -28,7 +28,7 @@ def _get_mock_dataset(root_dir): for split in ["train", "val", "test"]: stories = [] for i in range(5): - url = '_'.join([source, split, str(i)]) + url = "_".join([source, split, str(i)]) h = hashlib.sha1() h.update(url.encode()) filename = h.hexdigest() + ".story" @@ -40,7 +40,7 @@ def _get_mock_dataset(root_dir): f.writelines([article, "\n@highlight\n", abstract]) stories.append((txt_file, dataset_line)) seed += 2 - + # append stories to correct dataset split, must be in legixographic order of filenames per dataset stories.sort(key=lambda x: x[0]) mocked_data[split] += [t[1] for t in stories] @@ -74,7 +74,7 @@ def _mock_split_list(split): story_fnames = [] for source in ["cnn", "dailymail"]: for i in range(5): - url = '_'.join([source, split, str(i)]) + url = 
"_".join([source, split, str(i)]) h = hashlib.sha1() h.update(url.encode()) filename = h.hexdigest() + ".story" diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index efc59352ac..2f02a6cb6c 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -361,4 +361,4 @@ def __iter__(self): article = " ".join(article_lines) abstract = " ".join(highlights) - yield article, abstract \ No newline at end of file + yield article, abstract diff --git a/torchtext/datasets/cnndm.py b/torchtext/datasets/cnndm.py index cfca268f6c..3950557f5d 100644 --- a/torchtext/datasets/cnndm.py +++ b/torchtext/datasets/cnndm.py @@ -126,4 +126,4 @@ def CNNDM(root: str, split: Union[Tuple[str], str]): # of the on_disk_cache_dp which caches the files extracted from the tar story_fnames = set(_get_split_list(split)) data_dp = data_dp.filter(partial(_filter_fn, story_fnames)) - return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter() \ No newline at end of file + return data_dp.parse_cnndm_data().shuffle().set_shuffle(False).sharding_filter()