Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 99eb1f8

Browse files
Nayef211nayef211
andauthored
Added caching to extracted files in debpedia (#1571)
Co-authored-by: nayef211 <[email protected]>
1 parent fed25fe commit 99eb1f8

File tree

1 file changed

+32
-19
lines changed

1 file changed

+32
-19
lines changed

torchtext/datasets/dbpedia.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,32 @@
1-
from torchtext._internal.module_utils import is_module_available
21
from typing import Union, Tuple
32

3+
from torchtext._internal.module_utils import is_module_available
4+
45
if is_module_available("torchdata"):
56
from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper
67

8+
import os
9+
710
from torchtext.data.datasets_utils import (
811
_wrap_split_argument,
912
_add_docstring_header,
1013
_create_dataset_directory,
1114
)
1215

13-
import os
14-
15-
URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k'
16+
URL = "https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k"
1617

17-
MD5 = 'dca7b1ae12b1091090db52aa7ec5ca64'
18+
MD5 = "dca7b1ae12b1091090db52aa7ec5ca64"
1819

1920
NUM_LINES = {
20-
'train': 560000,
21-
'test': 70000,
21+
"train": 560000,
22+
"test": 70000,
2223
}
2324

24-
_PATH = 'dbpedia_csv.tar.gz'
25+
_PATH = "dbpedia_csv.tar.gz"
2526

2627
_EXTRACTED_FILES = {
2728
"train": os.path.join("dbpedia_csv", "train.csv"),
28-
"test": os.path.join("dbpedia_csv", "test.csv")
29+
"test": os.path.join("dbpedia_csv", "test.csv"),
2930
}
3031

3132
DATASET_NAME = "DBpedia"
@@ -37,19 +38,31 @@
3738
def DBpedia(root: str, split: Union[Tuple[str], str]):
3839
# TODO Remove this after removing conditional dependency
3940
if not is_module_available("torchdata"):
40-
raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
41+
raise ModuleNotFoundError(
42+
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
43+
)
4144

4245
url_dp = IterableWrapper([URL])
43-
44-
cache_dp = url_dp.on_disk_cache(
46+
cache_compressed_dp = url_dp.on_disk_cache(
4547
filepath_fn=lambda x: os.path.join(root, _PATH),
46-
hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
48+
hash_dict={os.path.join(root, _PATH): MD5},
49+
hash_type="md5",
50+
)
51+
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(
52+
mode="wb", same_filepath_fn=True
4753
)
48-
cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
49-
cache_dp = FileOpener(cache_dp, mode="b")
50-
51-
extracted_files = cache_dp.read_from_tar()
5254

53-
filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
55+
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
56+
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
57+
)
58+
cache_decompressed_dp = (
59+
FileOpener(cache_decompressed_dp, mode="b")
60+
.read_from_tar()
61+
.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
62+
)
63+
cache_decompressed_dp = cache_decompressed_dp.end_caching(
64+
mode="wb", same_filepath_fn=True
65+
)
5466

55-
return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
67+
data_dp = FileOpener(cache_decompressed_dp, mode="b")
68+
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))

0 commit comments

Comments
 (0)