36 changes: 12 additions & 24 deletions examples/text/amazonreviewpolarity.py
@@ -18,13 +18,8 @@
 _PATH = "amazon_review_polarity_csv.tar.gz"
 
 _EXTRACTED_FILES = {
-    "train": os.path.join(_PATH, "amazon_review_polarity_csv", "train.csv"),
-    "test": os.path.join(_PATH, "amazon_review_polarity_csv", "test.csv"),
-}
-
-_EXTRACTED_FILES_MD5 = {
-    "train": "520937107c39a2d1d1f66cd410e9ed9e",
-    "test": "f4c8bded2ecbde5f996b675db6228f16",
+    "train": os.path.join("amazon_review_polarity_csv", "train.csv"),
+    "test": os.path.join("amazon_review_polarity_csv", "test.csv"),
 }
 
 DATASET_NAME = "AmazonReviewPolarity"
@@ -37,25 +32,18 @@ def AmazonReviewPolarity(root, split):
     """Demonstrating caching, extraction and sanity check pipelines."""
 
     url_dp = IterableWrapper([URL])
-    # cache data on-disk with sanity check
-    cache_dp = url_dp.on_disk_cache(
+    cache_compressed_dp = url_dp.on_disk_cache(
         filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
     )
-    cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
-
-    cache_dp = FileOpener(cache_dp, mode="b")
+    cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
-    # stack TAR extractor on top of loader DP
-    extracted_files = cache_dp.read_from_tar()
-
-    # filter files as necessary
-    filter_extracted_files = extracted_files.filter(lambda x: split in x[0])
-
-    # stack sanity checker on top of extracted files
-    check_filter_extracted_files = filter_extracted_files.check_hash(
-        {os.path.normpath(os.path.join(root, _EXTRACTED_FILES[split])): _EXTRACTED_FILES_MD5[split]},
-        "md5",
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
     )
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    )
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
-    # stack CSV reader and do some mapping
-    return check_filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), t[1]))
+    data_dp = FileOpener(cache_decompressed_dp, mode="b")
+    return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
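
A minimal usage sketch of the rewritten example, for context. It is not part of the PR: it assumes torchdata (with its Google Drive reader dependencies) is installed and that the example file is importable as amazonreviewpolarity; the root directory name is an arbitrary choice. On the first pass the pipeline downloads the archive into the compressed on-disk cache, extracts the split's CSV into the decompressed cache, then yields parsed rows; later runs reuse the cached files.

# Hypothetical usage of the updated AmazonReviewPolarity example datapipe.
# Module name and root path below are assumptions for illustration only.
from amazonreviewpolarity import AmazonReviewPolarity

train_dp = AmazonReviewPolarity(root=".data", split="train")

# Each sample is a (label, text) tuple produced by parse_csv() + map().
for label, text in train_dp:
    print(label, text[:80])
    break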