|
1 | | -from torchtext.utils import ( |
2 | | - download_from_url, |
3 | | -) |
| 1 | +from torchtext._internal.module_utils import is_module_available |
| 2 | + |
| 3 | +if is_module_available("torchdata"): |
| 4 | + from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper |
| 5 | + |
4 | 6 | from torchtext.data.datasets_utils import ( |
5 | | - _RawTextIterableDataset, |
6 | 7 | _wrap_split_argument, |
7 | 8 | _add_docstring_header, |
8 | 9 | _create_dataset_directory, |
9 | | - _create_data_from_csv, |
10 | 10 | ) |
11 | 11 | import os |
12 | 12 |
|
|
30 | 30 |
|
31 | 31 | @_add_docstring_header(num_lines=NUM_LINES, num_classes=4) |
32 | 32 | @_create_dataset_directory(dataset_name=DATASET_NAME) |
33 | | -@_wrap_split_argument(('train', 'test')) |
| 33 | +@_wrap_split_argument(("train", "test")) |
34 | 34 | def AG_NEWS(root, split): |
35 | | - path = download_from_url(URL[split], root=root, |
36 | | - path=os.path.join(root, split + ".csv"), |
37 | | - hash_value=MD5[split], |
38 | | - hash_type='md5') |
39 | | - return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], |
40 | | - _create_data_from_csv(path)) |
| 35 | + url_dp = IterableWrapper([URL[split]]) |
| 36 | + http_dp = HttpReader(url_dp) |
| 37 | + cache_dp = http_dp.on_disk_cache( |
| 38 | + filepath_fn=lambda x: os.path.join(root, split + ".csv"), |
| 39 | + hash_dict={os.path.join(root, split + ".csv"): MD5[split]}, |
| 40 | + hash_type="md5" |
| 41 | + ).end_caching(mode="w", same_filepath_fn=True) |
| 42 | + |
| 43 | + cache_dp = FileOpener(cache_dp, mode="r") |
| 44 | + return cache_dp.parse_csv().map(fn=lambda t: (int(t[0]), t[1])) |
0 commit comments