diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py index 85b892eb69..8c678de865 100644 --- a/torchtext/experimental/datasets/sst2.py +++ b/torchtext/experimental/datasets/sst2.py @@ -33,10 +33,12 @@ MD5 = "9f81648d4199384278b86e315dac217c" URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip" +_PATH = "SST-2.zip" + _EXTRACTED_FILES = { - "train": f"{os.sep}".join(["SST-2", "train.tsv"]), - "dev": f"{os.sep}".join(["SST-2", "dev.tsv"]), - "test": f"{os.sep}".join(["SST-2", "test.tsv"]), + "train": f"{os.sep}".join([_PATH, "SST-2", "train.tsv"]), + "dev": f"{os.sep}".join([_PATH, "SST-2", "dev.tsv"]), + "test": f"{os.sep}".join([_PATH, "SST-2", "test.tsv"]), } _EXTRACTED_FILES_MD5 = { @@ -79,12 +81,17 @@ def get_datapipe(self): filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)), ) + # do sanity check + check_cache_dp = cache_dp.check_hash( + {os.path.join(self.root, "SST-2.zip"): MD5}, "md5" + ) + # extract data from zip - extracted_files = cache_dp.read_from_zip() + extracted_files = check_cache_dp.read_from_zip().filter( + lambda x: self.split in x[0] + ) # Parse CSV file and yield data samples - return ( - extracted_files.filter(lambda x: self.split in x[0]) - .parse_csv(skip_lines=1, delimiter="\t") - .map(lambda x: (x[0], x[1])) + return extracted_files.parse_csv(skip_lines=1, delimiter="\t").map( + lambda x: (x[0], x[1]) )