24 changes: 24 additions & 0 deletions torchtext/data/datasets_utils.py
@@ -10,6 +10,7 @@
     extract_archive,
     unicode_csv_reader,
 )
+from torch.utils.data import IterDataPipe, functional_datapipe
 import codecs
 try:
     import defusedxml.ElementTree as ET
@@ -318,3 +319,26 @@ def pos(self):
 
     def __str__(self):
         return self.description
+
+
+@functional_datapipe("read_squad")
+class _ParseSQuADQAData(IterDataPipe):
+    r"""Iterable DataPipe to parse the contents of a stream of JSON objects
+    as provided by SQuAD QA. Used in SQuAD1 and SQuAD2.
+    """
+    def __init__(self, source_datapipe) -> None:
+        self.source_datapipe = source_datapipe
+
+    def __iter__(self):
+        for _, stream in self.source_datapipe:
+            raw_json_data = stream["data"]
+            for layer1 in raw_json_data:
+                for layer2 in layer1["paragraphs"]:
+                    for layer3 in layer2["qas"]:
+                        _context, _question = layer2["context"], layer3["question"]
+                        _answers = [item["text"] for item in layer3["answers"]]
+                        _answer_start = [item["answer_start"] for item in layer3["answers"]]
+                        if len(_answers) == 0:
+                            _answers = [""]
+                            _answer_start = [-1]
+                        yield _context, _question, _answers, _answer_start
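
Because _ParseSQuADQAData is registered via @functional_datapipe("read_squad"), any IterDataPipe that yields (filename, parsed_json) tuples, such as the output of parse_json_files(), gains a .read_squad() method. A minimal sketch of how the pipe flattens the nested SQuAD layout, using a hypothetical in-memory sample rather than the real dataset and assuming a PyTorch build that ships the prototype datapipes:

import torchtext.data.datasets_utils  # noqa: F401 -- importing registers the "read_squad" functional
from torch.utils.data.datapipes.iter import IterableWrapper

# Hypothetical sample mimicking the nested SQuAD JSON layout.
sample = {
    "data": [{
        "paragraphs": [{
            "context": "TorchText provides text datasets.",
            "qas": [{
                "question": "What does TorchText provide?",
                "answers": [{"text": "text datasets", "answer_start": 19}],
            }],
        }],
    }],
}

# The pipe consumes (filename, parsed_json) tuples, as produced by parse_json_files().
dp = IterableWrapper([("train-v1.1.json", sample)]).read_squad()
for context, question, answers, answer_start in dp:
    print(question, answers, answer_start)
# -> What does TorchText provide? ['text datasets'] [19]
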
32 changes: 24 additions & 8 deletions torchtext/datasets/squad1.py
@@ -1,11 +1,17 @@
-from torchtext.utils import download_from_url
+from torchtext._internal.module_utils import is_module_available
+from typing import Union, Tuple
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
 from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _add_docstring_header,
     _create_dataset_directory,
-    _create_data_from_json,
 )
+
 import os
 
 URL = {
     'train': "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json",
     'dev': "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
@@ -27,8 +33,18 @@
 
 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train', 'dev'))
-def SQuAD1(root, split):
-    extracted_files = download_from_url(URL[split], root=root, hash_value=MD5[split], hash_type='md5')
-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
-                                   _create_data_from_json(extracted_files))
+@_wrap_split_argument(("train", "dev"))
+def SQuAD1(root: str, split: Union[Tuple[str], str]):
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
+
+    url_dp = IterableWrapper([URL[split]])
+    # cache data on-disk with sanity check
+    cache_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
+        hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]},
+        hash_type="md5",
+    )
+    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_dp = FileOpener(cache_dp, mode="b")
+    return cache_dp.parse_json_files().read_squad()
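
With this pipeline, SQuAD1 returns a datapipe instead of a _RawTextIterableDataset: on_disk_cache checks hash_dict, so the HttpReader download only runs when no file in root already matches the expected MD5, and each element is the (context, question, answers, answer_start) tuple yielded by read_squad. A minimal usage sketch, assuming torchdata is installed and the first call can reach rajpurkar.github.io:

from torchtext.datasets import SQuAD1

# First call downloads dev-v1.1.json into root and verifies its MD5;
# subsequent calls reuse the cached file.
dev_dp = SQuAD1(root=".data", split="dev")
context, question, answers, answer_start = next(iter(dev_dp))
print(question)
print(answers, answer_start)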