From dff81cb29f07322f6f7a97aef472b80b8b7041c1 Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Thu, 13 Jan 2022 08:44:16 -0500
Subject: [PATCH 1/3] add initial pass at migrating SQUAD1 to datapipes.

---
 torchtext/data/datasets_utils.py | 20 ++++++++++++++++++
 torchtext/datasets/squad1.py     | 36 +++++++++++++++++++++++++-------
 2 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py
index 903d45c97e..b022a19af6 100644
--- a/torchtext/data/datasets_utils.py
+++ b/torchtext/data/datasets_utils.py
@@ -10,6 +10,7 @@
     extract_archive,
     unicode_csv_reader,
 )
+from torch.utils.data import IterDataPipe
 import codecs
 try:
     import defusedxml.ElementTree as ET
@@ -318,3 +319,22 @@ def pos(self):
 
     def __str__(self):
         return self.description
+
+
+class _ParseSQuADQAData(IterDataPipe):
+    def __init__(self, source_datapipe) -> None:
+        self.source_datapipe = source_datapipe
+
+    def __iter__(self):
+        for _, stream in self.source_datapipe:
+            raw_json_data = stream["data"]
+            for layer1 in raw_json_data:
+                for layer2 in layer1["paragraphs"]:
+                    for layer3 in layer2["qas"]:
+                        _context, _question = layer2["context"], layer3["question"]
+                        _answers = [item["text"] for item in layer3["answers"]]
+                        _answer_start = [item["answer_start"] for item in layer3["answers"]]
+                        if len(_answers) == 0:
+                            _answers = [""]
+                            _answer_start = [-1]
+                        yield _context, _question, _answers, _answer_start
\ No newline at end of file
diff --git a/torchtext/datasets/squad1.py b/torchtext/datasets/squad1.py
index 2c9fdc3ff4..c026b16955 100644
--- a/torchtext/datasets/squad1.py
+++ b/torchtext/datasets/squad1.py
@@ -1,11 +1,18 @@
-from torchtext.utils import download_from_url
+from torchtext._internal.module_utils import is_module_available
+from typing import Union, Tuple
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
 from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _add_docstring_header,
     _create_dataset_directory,
-    _create_data_from_json,
+    _ParseSQuADQAData
 )
+
+import os
+
 URL = {
     'train': "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json",
     'dev': "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
@@ -27,8 +34,21 @@
 
 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train', 'dev'))
-def SQuAD1(root, split):
-    extracted_files = download_from_url(URL[split], root=root, hash_value=MD5[split], hash_type='md5')
-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
-                                   _create_data_from_json(extracted_files))
+@_wrap_split_argument(("train", "dev"))
+def SQuAD1(root: str, split: Union[Tuple[str], str]):
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
+
+    url_dp = IterableWrapper([URL[split]])
+    # cache data on-disk with sanity check
+    cache_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
+        hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]},
+        hash_type="md5",
+    )
+    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
+
+    cache_dp = FileOpener(cache_dp, mode="b")
+
+    # stack custom data pipe on top of JSON reader to orchestrate data samples for Q&A dataset
+    return _ParseSQuADQAData(cache_dp.parse_json_files())
\ No newline at end of file

From 2ab27b3fc74df7bf239aefa060e2d76a057d6214 Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Thu, 13 Jan 2022 08:49:08 -0500
Subject: [PATCH 2/3] fix style.

---
 torchtext/data/datasets_utils.py | 2 +-
 torchtext/datasets/squad1.py     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py
index b022a19af6..2081d81e79 100644
--- a/torchtext/data/datasets_utils.py
+++ b/torchtext/data/datasets_utils.py
@@ -337,4 +337,4 @@ def __iter__(self):
                         if len(_answers) == 0:
                             _answers = [""]
                             _answer_start = [-1]
-                        yield _context, _question, _answers, _answer_start
\ No newline at end of file
+                        yield _context, _question, _answers, _answer_start
diff --git a/torchtext/datasets/squad1.py b/torchtext/datasets/squad1.py
index c026b16955..feb4e25b59 100644
--- a/torchtext/datasets/squad1.py
+++ b/torchtext/datasets/squad1.py
@@ -51,4 +51,4 @@ def SQuAD1(root: str, split: Union[Tuple[str], str]):
     cache_dp = FileOpener(cache_dp, mode="b")
 
     # stack custom data pipe on top of JSON reader to orchestrate data samples for Q&A dataset
-    return _ParseSQuADQAData(cache_dp.parse_json_files())
\ No newline at end of file
+    return _ParseSQuADQAData(cache_dp.parse_json_files())

From 8999c2c9ba46b991c502b92b9056c48a512cde07 Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Sat, 15 Jan 2022 08:43:53 -0500
Subject: [PATCH 3/3] add some comments and a decorator to use functionally.

---
 torchtext/data/datasets_utils.py | 6 +++++-
 torchtext/datasets/squad1.py     | 6 +-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py
index 2081d81e79..8f238177ed 100644
--- a/torchtext/data/datasets_utils.py
+++ b/torchtext/data/datasets_utils.py
@@ -10,7 +10,7 @@
     extract_archive,
     unicode_csv_reader,
 )
-from torch.utils.data import IterDataPipe
+from torch.utils.data import IterDataPipe, functional_datapipe
 import codecs
 try:
     import defusedxml.ElementTree as ET
@@ -321,7 +321,11 @@ def __str__(self):
         return self.description
 
 
+@functional_datapipe("read_squad")
 class _ParseSQuADQAData(IterDataPipe):
+    r"""Iterable DataPipe to parse the contents of a stream of JSON objects
+    as provided by SQuAD QA. Used in SQuAD1 and SQuAD2.
+    """
     def __init__(self, source_datapipe) -> None:
         self.source_datapipe = source_datapipe
 
diff --git a/torchtext/datasets/squad1.py b/torchtext/datasets/squad1.py
index feb4e25b59..b8452e37e1 100644
--- a/torchtext/datasets/squad1.py
+++ b/torchtext/datasets/squad1.py
@@ -8,7 +8,6 @@
     _wrap_split_argument,
     _add_docstring_header,
     _create_dataset_directory,
-    _ParseSQuADQAData
 )
 
 import os
@@ -47,8 +46,5 @@ def SQuAD1(root: str, split: Union[Tuple[str], str]):
         hash_type="md5",
     )
     cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
-
     cache_dp = FileOpener(cache_dp, mode="b")
-
-    # stack custom data pipe on top of JSON reader to orchestrate data samples for Q&A dataset
-    return cache_dp.parse_json_files().read_squad()
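Note for reviewers: a minimal sketch of how the migrated dataset is expected to be consumed. It assumes `torchdata` is installed and that `SQuAD1` remains re-exported from `torchtext.datasets` (that re-export is not part of this patch); the sample layout follows the `yield` in `_ParseSQuADQAData` above.

    # Hypothetical usage sketch, not part of the patches.
    from torchtext.datasets import SQuAD1

    # split is "train" or "dev"; each sample is
    # (context, question, answers, answer_start_positions)
    train_dp = SQuAD1(split="train")
    for context, question, answers, answer_starts in train_dp:
        print(question)
        print(answers[0], answer_starts[0])
        break

Because patch 3 registers the parser with `@functional_datapipe("read_squad")`, the final return can chain `.read_squad()` directly off `parse_json_files()` instead of wrapping the pipe in `_ParseSQuADQAData(...)`.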