11import hashlib
22import os
3- from collections import defaultdict
43from functools import partial
5- from typing import Union , Tuple
4+ from typing import Union , Set , Tuple
65
76from torchtext ._internal .module_utils import is_module_available
87from torchtext .data .datasets_utils import (
5251 "test" : 11490 ,
5352}
5453
55- story_fnames = defaultdict (set )
56-
5754
5855def _filepath_fn (root : str , source : str , _ = None ):
5956 return os .path .join (root , PATH_LIST [source ])
6057
6158
6259# called once per tar file, therefore no duplicate processing
6360def _extracted_folder_fn (root : str , source : str , split : str , _ = None ):
64- global story_fnames
6561 key = source + "_" + split
66- story_fnames [key ] = set (_get_split_list (source , split ))
67- filepaths = [os .path .join (root , _EXTRACTED_FOLDERS [source ], story ) for story in story_fnames [key ]]
68- return filepaths
62+ filepath = os .path .join (root , key )
63+ return filepath
6964
7065
7166def _extracted_filepath_fn (root : str , source : str , x : str ):
7267 return os .path .join (root , _EXTRACTED_FOLDERS [source ], os .path .basename (x ))
7368
7469
75- def _filter_fn (source : str , split : str , x : tuple ):
76- return os .path .basename (x [0 ]) in story_fnames [ source + "_" + split ]
70+ def _filter_fn (split_list : Set [ str ] , x : tuple ):
71+ return os .path .basename (x [0 ]) in split_list
7772
7873
7974def _hash_urls (s : tuple ):
@@ -96,6 +91,9 @@ def _get_split_list(source: str, split: str):
9691
9792
9893def _load_stories (root : str , source : str , split : str ):
94+
95+ split_list = set (_get_split_list (source , split ))
96+
9997 story_dp = IterableWrapper ([URL [source ]])
10098 cache_compressed_dp = story_dp .on_disk_cache (
10199 filepath_fn = partial (_filepath_fn , root , source ),
@@ -108,7 +106,7 @@ def _load_stories(root: str, source: str, split: str):
108106 filepath_fn = partial (_extracted_folder_fn , root , source , split )
109107 )
110108 cache_decompressed_dp = (
111- FileOpener (cache_decompressed_dp , mode = "b" ).load_from_tar ().filter (partial (_filter_fn , source , split ))
109+ FileOpener (cache_decompressed_dp , mode = "b" ).load_from_tar ().filter (partial (_filter_fn , split_list ))
112110 )
113111 cache_decompressed_dp = cache_decompressed_dp .end_caching (
114112 mode = "wb" , filepath_fn = partial (_extracted_filepath_fn , root , source )
0 commit comments