@@ -1,31 +1,32 @@
-from torchtext._internal.module_utils import is_module_available
 from typing import Union, Tuple
 
+from torchtext._internal.module_utils import is_module_available
+
 if is_module_available("torchdata"):
     from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper
 
+import os
+
 from torchtext.data.datasets_utils import (
     _wrap_split_argument,
     _add_docstring_header,
     _create_dataset_directory,
 )
 
-import os
-
-URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k'
+URL = "https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k"
 
-MD5 = 'dca7b1ae12b1091090db52aa7ec5ca64'
+MD5 = "dca7b1ae12b1091090db52aa7ec5ca64"
 
 NUM_LINES = {
-    'train': 560000,
-    'test': 70000,
+    "train": 560000,
+    "test": 70000,
 }
 
-_PATH = 'dbpedia_csv.tar.gz'
+_PATH = "dbpedia_csv.tar.gz"
 
 _EXTRACTED_FILES = {
     "train": os.path.join("dbpedia_csv", "train.csv"),
-    "test": os.path.join("dbpedia_csv", "test.csv")
+    "test": os.path.join("dbpedia_csv", "test.csv"),
 }
 
 DATASET_NAME = "DBpedia"
@@ -37,19 +38,31 @@
 def DBpedia(root: str, split: Union[Tuple[str], str]):
     # TODO Remove this after removing conditional dependency
     if not is_module_available("torchdata"):
-        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
+        raise ModuleNotFoundError(
+            "Package `torchdata` not found. Please install it following the instructions at `https://github.com/pytorch/data`"
+        )
 
     url_dp = IterableWrapper([URL])
-
-    cache_dp = url_dp.on_disk_cache(
+    cache_compressed_dp = url_dp.on_disk_cache(
         filepath_fn=lambda x: os.path.join(root, _PATH),
-        hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
+        hash_dict={os.path.join(root, _PATH): MD5},
+        hash_type="md5",
+    )
+    cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(
+        mode="wb", same_filepath_fn=True
     )
-    cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
-    cache_dp = FileOpener(cache_dp, mode="b")
-
-    extracted_files = cache_dp.read_from_tar()
 
-    filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
+    )
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b")
+        .read_from_tar()
+        .filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    )
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(
+        mode="wb", same_filepath_fn=True
+    )
 
-    return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
+    data_dp = FileOpener(cache_decompressed_dp, mode="b")
+    return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
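The net effect of the second on_disk_cache stage is that a warm cache skips both the Google Drive download and the tar extraction, whereas the old pipeline re-read the archive on every instantiation. A minimal usage sketch follows; it is not part of this commit, and it assumes `torchdata` is installed and that `DBpedia` is re-exported from `torchtext.datasets` as in released versions of the library:

    from torchtext.datasets import DBpedia

    # The first call downloads the archive and extracts train.csv under root;
    # both on_disk_cache stages then let later calls hit the local files.
    train_dp = DBpedia(root=".data", split="train")

    # Each element is (label, text): int(t[0]) plus the remaining CSV columns
    # joined with a space, per the map() that closes the pipeline.
    label, text = next(iter(train_dp))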