from torch.utils.data import IterDataPipe, functional_datapipe


@functional_datapipe("read_squad")
class _ParseSQuADQAData(IterDataPipe):
    r"""Iterable DataPipe that flattens parsed SQuAD QA JSON into samples.

    Used in SQuAD1 and SQuAD2.

    Args:
        source_datapipe: iterable of 2-item tuples. The first item is
            ignored; the second is the decoded JSON object of one file and
            must provide a ``"data"`` list of entries, each holding
            ``"paragraphs"`` -> ``"qas"`` -> ``"answers"``.

    Yields:
        ``(context, question, answers, answer_starts)`` once per question.
        ``answers`` and ``answer_starts`` are parallel lists taken from the
        ``"text"`` and ``"answer_start"`` fields of each answer. A question
        with an empty answer list yields ``[""]`` and ``[-1]`` instead.
    """

    def __init__(self, source_datapipe) -> None:
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for _, parsed_json in self.source_datapipe:
            for article in parsed_json["data"]:
                for paragraph in article["paragraphs"]:
                    for qa in paragraph["qas"]:
                        context, question = paragraph["context"], qa["question"]
                        answers = qa["answers"]
                        if answers:
                            texts = [answer["text"] for answer in answers]
                            starts = [answer["answer_start"] for answer in answers]
                        else:
                            # No answers listed: emit fixed placeholders so
                            # every sample keeps the same 4-field shape.
                            texts, starts = [""], [-1]
                        yield context, question, texts, starts
@_add_docstring_header(num_lines=NUM_LINES)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev"))
def SQuAD1(root: str, split: Union[Tuple[str], str]):
    # The DataPipe classes used below come from torchdata; fail early with an
    # actionable message when that package is missing.
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

    split_url = URL[split]
    # Location of the cached download under `root`; it keys the checksum table.
    local_file = os.path.join(root, os.path.basename(split_url))

    # Stage 1: download-and-cache on disk, with an MD5 entry for the cached file.
    cache_holder = IterableWrapper([split_url]).on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
        hash_dict={local_file: MD5[split]},
        hash_type="md5",
    )
    downloaded = HttpReader(cache_holder).end_caching(mode="wb", same_filepath_fn=True)

    # Stage 2: open the cached file in binary mode, parse it as JSON, then
    # flatten the parsed content with the `read_squad` DataPipe.
    opened = FileOpener(downloaded, mode="b")
    return opened.parse_json_files().read_squad()