@@ -23,8 +23,8 @@
_PATH = 'amazon_review_polarity_csv.tar.gz'

_EXTRACTED_FILES = {
-    'train': os.path.join(_PATH, 'amazon_review_polarity_csv', 'train.csv'),
-    'test': os.path.join(_PATH, 'amazon_review_polarity_csv', 'test.csv'),
+    'train': os.path.join('amazon_review_polarity_csv', 'train.csv'),
+    'test': os.path.join('amazon_review_polarity_csv', 'test.csv'),
}


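For reference, a quick sketch of what the two path layouts resolve to. This is illustrative only; the '.data' root is a hypothetical value, while _PATH and the directory names come from the module above.

import os

_PATH = 'amazon_review_polarity_csv.tar.gz'
root = '.data'  # hypothetical cache directory

# before: the extracted CSVs were keyed under the tarball name, e.g.
# '.data/amazon_review_polarity_csv.tar.gz/amazon_review_polarity_csv/train.csv'
old_path = os.path.join(root, _PATH, 'amazon_review_polarity_csv', 'train.csv')

# after: the extracted CSVs live directly under root, e.g.
# '.data/amazon_review_polarity_csv/train.csv'
new_path = os.path.join(root, 'amazon_review_polarity_csv', 'train.csv')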
@@ -44,9 +44,25 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
        filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
    )
    cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
+
+    def extracted_filepath_fn(x):
+        file_path = os.path.join(root, _EXTRACTED_FILES[split])
+        dir_path = os.path.dirname(file_path)
+        if not os.path.exists(dir_path):
+            os.makedirs(dir_path)
+        return file_path
+
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, os.path.dirname(_EXTRACTED_FILES[split]), os.path.basename(x)))
-    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar()
-    cache_compressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
-    data_dp = FileOpener(cache_decompressed_dp.filter(lambda x: _EXTRACTED_FILES[split] in x[0]).map(lambda x: x[0]), mode='b')
+        filepath_fn=extracted_filepath_fn)
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").\
+        read_from_tar().\
+        filter(lambda x: _EXTRACTED_FILES[split] in x[0]).\
+        map(lambda x: (x[0].replace(_PATH + '/', ''), x[1]))
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+    data_dp = FileOpener(cache_decompressed_dp, mode='b')
+
+    # data_dp = FileOpener(cache_compressed_dp, mode='b')
+    # data_dp = data_dp.read_from_tar()
+    # data_dp = data_dp.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+
    return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:])))
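A minimal usage sketch of the resulting datapipe, assuming the function is exposed as torchtext.datasets.AmazonReviewPolarity and that a '.data' cache directory is acceptable; each yielded sample is an (int label, str text) pair produced by the final parse_csv().map(...) step.

from torchtext.datasets import AmazonReviewPolarity

# Downloads (or reuses the on-disk cache of) the tarball, extracts train.csv,
# and streams (label, text) tuples.
train_dp = AmazonReviewPolarity(root='.data', split='train')

for label, text in train_dp:
    print(label, text[:80])
    break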