From 18538c61fe4da662d0cb825b68276154008ec425 Mon Sep 17 00:00:00 2001 From: Elijah Rippeth Date: Sat, 15 Jan 2022 19:34:26 -0500 Subject: [PATCH] add initial pass at migrating SQUAD2 to datapipes. --- torchtext/datasets/squad2.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/torchtext/datasets/squad2.py b/torchtext/datasets/squad2.py index a889f9900f..75beda497d 100644 --- a/torchtext/datasets/squad2.py +++ b/torchtext/datasets/squad2.py @@ -1,11 +1,17 @@ -from torchtext.utils import download_from_url +from torchtext._internal.module_utils import is_module_available +from typing import Union, Tuple + +if is_module_available("torchdata"): + from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper + from torchtext.data.datasets_utils import ( - _RawTextIterableDataset, _wrap_split_argument, _add_docstring_header, _create_dataset_directory, - _create_data_from_json, ) + +import os + URL = { 'train': "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", 'dev': "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", @@ -28,7 +34,17 @@ @_add_docstring_header(num_lines=NUM_LINES) @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(('train', 'dev')) -def SQuAD2(root, split): - extracted_files = download_from_url(URL[split], root=root, hash_value=MD5[split], hash_type='md5') - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], - _create_data_from_json(extracted_files)) +def SQuAD2(root: str, split: Union[Tuple[str], str]): + if not is_module_available("torchdata"): + raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") + + url_dp = IterableWrapper([URL[split]]) + # cache data on-disk with sanity check + cache_dp = url_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), + hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + hash_type="md5", + ) + cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) + cache_dp = FileOpener(cache_dp, mode="b") + return cache_dp.parse_json_files().read_squad()