diff --git a/torchtext/datasets/yahooanswers.py b/torchtext/datasets/yahooanswers.py index 7e95e2c6ca..70c3d8b807 100644 --- a/torchtext/datasets/yahooanswers.py +++ b/torchtext/datasets/yahooanswers.py @@ -33,22 +33,29 @@ @_add_docstring_header(num_lines=NUM_LINES, num_classes=10) @_create_dataset_directory(dataset_name=DATASET_NAME) -@_wrap_split_argument(('train', 'test')) +@_wrap_split_argument(("train", "test")) def YahooAnswers(root: str, split: Union[Tuple[str], str]): if not is_module_available("torchdata"): raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`") url_dp = IterableWrapper([URL]) - cache_dp = url_dp.on_disk_cache( + cache_compressed_dp = url_dp.on_disk_cache( filepath_fn=lambda x: os.path.join(root, _PATH), - hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5" + hash_dict={os.path.join(root, _PATH): MD5}, + hash_type="md5" ) - cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_dp = FileOpener(cache_dp, mode="b") + cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) + cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b") - extracted_files = cache_dp.read_from_tar() + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) + ) + cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") + cache_decompressed_dp = cache_decompressed_dp.read_from_tar() + cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0]) + data_dp = FileOpener(cache_decompressed_dp, mode="b") - return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))