Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 6795977

Browse files
committed
add _EXTRACTED_FILES for consistency.
1 parent a3188ac commit 6795977

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchtext/datasets/dbpedia.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323

2424
_PATH = 'dbpedia_csv.tar.gz'
2525

26+
_EXTRACTED_FILES = {
27+
"train": os.path.join("dbpedia_csv", "train.csv"),
28+
"test": os.path.join("dbpedia_csv", "test.csv")
29+
}
30+
2631
DATASET_NAME = "DBpedia"
2732

2833

@@ -45,6 +50,6 @@ def DBpedia(root: str, split: Union[Tuple[str], str]):
4550

4651
extracted_files = cache_dp.read_from_tar()
4752

48-
filter_extracted_files = extracted_files.filter(lambda x: split + ".csv" in x[0])
53+
filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
4954

5055
return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))

0 commit comments

Comments
 (0)