|
1 | | -from torchtext.utils import download_from_url |
| 1 | +from torchtext._internal.module_utils import is_module_available |
| 2 | +from typing import Union, Tuple |
| 3 | + |
| 4 | +if is_module_available("torchdata"): |
| 5 | + from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper |
| 6 | + |
2 | 7 | from torchtext.data.datasets_utils import ( |
3 | | - _RawTextIterableDataset, |
4 | 8 | _wrap_split_argument, |
5 | 9 | _add_docstring_header, |
6 | 10 | _create_dataset_directory, |
7 | | - _create_data_from_json, |
8 | | -) |
| 11 | +)
| 12 | + |
| 13 | +import os |
| 14 | + |
9 | 15 | URL = { |
10 | 16 | 'train': "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", |
11 | 17 | 'dev': "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", |
|
@_add_docstring_header(num_lines=NUM_LINES)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'dev'))
def SQuAD2(root: str, split: Union[Tuple[str], str]):
    """Build the SQuAD 2.0 datapipe for ``split``, caching the download under ``root``."""
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

    def _local_path(url):
        # Mirror the remote filename under the dataset root directory.
        return os.path.join(root, os.path.basename(url))

    source_dp = IterableWrapper([URL[split]])
    # Cache the download on disk, validating the cached file against its md5.
    squad_dp = source_dp.on_disk_cache(
        filepath_fn=_local_path,
        hash_dict={_local_path(URL[split]): MD5[split]},
        hash_type="md5",
    )
    squad_dp = HttpReader(squad_dp).end_caching(mode="wb", same_filepath_fn=True)
    squad_dp = FileOpener(squad_dp, mode="b")
    # Parse the raw JSON and emit (context, question, answers, answer_starts) rows.
    return squad_dp.parse_json_files().read_squad()
0 commit comments