@@ -44,11 +44,26 @@ def SogouNews(root: str, split: Union[Tuple[str], str]):
4444 raise ModuleNotFoundError ("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" )
4545
4646 url_dp = IterableWrapper ([URL ])
47- cache_dp = url_dp .on_disk_cache (
48- filepath_fn = lambda x : os .path .join (root , _PATH ), hash_dict = {os .path .join (root , _PATH ): MD5 }, hash_type = "md5"
47+ cache_compressed_dp = url_dp .on_disk_cache (
48+ filepath_fn = lambda x : os .path .join (root , _PATH ),
49+ hash_dict = {os .path .join (root , _PATH ): MD5 },
50+ hash_type = "md5" ,
4951 )
50- cache_dp = GDriveReader (cache_dp ).end_caching (mode = "wb" , same_filepath_fn = True )
51- cache_dp = FileOpener (cache_dp , mode = "b" )
52- extracted_files = cache_dp .read_from_tar ()
53- filter_extracted_files = extracted_files .filter (lambda x : _EXTRACTED_FILES [split ] in x [0 ])
54- return filter_extracted_files .parse_csv ().map (fn = lambda t : (int (t [0 ]), ' ' .join (t [1 :])))
52+ cache_compressed_dp = GDriveReader (cache_compressed_dp ).end_caching (
53+ mode = "wb" , same_filepath_fn = True
54+ )
55+
56+ cache_decompressed_dp = cache_compressed_dp .on_disk_cache (
57+ filepath_fn = lambda x : os .path .join (root , _EXTRACTED_FILES [split ])
58+ )
59+ cache_decompressed_dp = (
60+ FileOpener (cache_decompressed_dp , mode = "b" )
61+ .read_from_tar ()
62+ .filter (lambda x : _EXTRACTED_FILES [split ] in x [0 ])
63+ )
64+ cache_decompressed_dp = cache_decompressed_dp .end_caching (
65+ mode = "wb" , same_filepath_fn = True
66+ )
67+
68+ data_dp = FileOpener (cache_decompressed_dp , mode = "b" )
69+ return data_dp .parse_csv ().map (fn = lambda t : (int (t [0 ]), " " .join (t [1 :])))
0 commit comments