From a30515f4e3d49ee2135ab93b372c3771577684df Mon Sep 17 00:00:00 2001 From: Nayef Ahmed Date: Wed, 10 Nov 2021 18:29:03 -0800 Subject: [PATCH] Fixed file filtering bug in SST2 dataset Summary: - Stopped copying the partial SST2 asset file to a temp dir; the test now works directly with the file in the asset folder - Fixed a bug where path names affected which files were selected from the zip file - For example, if the value of `split` is "test", the following snippet of code `filter(lambda x: split in x[0])` might match all of the "train", "test", and "dev" files depending on the location of the dataset asset file - When testing with buck, the location of the extracted files could look something like `/data/users/nayef211/fbsource/fbcode/buck-out/dev/gen/pytorch/text/test/experimental_test_datasets#binary,link-tree/test/asset/SST2/SST-2.zip/train.tsv`. Since the word "test" is contained in this path string, the filtering logic would incorrectly select the "train" file even though what we want is the "test" file - To resolve this, the filtering logic now matches on the value of `split` with the file extension (in this case ".tsv") appended Reviewed By: parmeet Differential Revision: D32329831 fbshipit-source-id: dbb4803a04f6cd50fab3f7ce5530d3258b2db012 --- test/asset/{ => SST2}/SST-2.zip | Bin test/experimental/test_datasets.py | 67 ++++++++++-------------- torchtext/experimental/datasets/sst2.py | 4 +- 3 files changed, 30 insertions(+), 41 deletions(-) rename test/asset/{ => SST2}/SST-2.zip (100%) diff --git a/test/asset/SST-2.zip b/test/asset/SST2/SST-2.zip similarity index 100% rename from test/asset/SST-2.zip rename to test/asset/SST2/SST-2.zip diff --git a/test/experimental/test_datasets.py b/test/experimental/test_datasets.py index 50af68ac4b..18807c4958 100644 --- a/test/experimental/test_datasets.py +++ b/test/experimental/test_datasets.py @@ -1,12 +1,9 @@ import hashlib import json -import os -import shutil -import tempfile from torchtext.experimental.datasets import sst2 
-from ..common.assets import get_asset_path +from ..common.assets import _ASSET_DIR from ..common.case_utils import skipIfNoModule from ..common.torchtext_test_case import TorchtextTestCase @@ -14,42 +11,32 @@ class TestDataset(TorchtextTestCase): @skipIfNoModule("torchdata") def test_sst2__dataset(self): - # copy the asset file into the expected download location - # note that this is just a zip file with the first 10 lines of the SST2 dataset - # test if providing a custom hash works with the dummy dataset - with tempfile.TemporaryDirectory() as dir_name: - asset_path = get_asset_path(sst2._PATH) - data_path = os.path.join(dir_name, sst2.DATASET_NAME, sst2._PATH) - os.makedirs(os.path.join(dir_name, sst2.DATASET_NAME)) - shutil.copy(asset_path, data_path) - split = ("train", "dev", "test") - train_dataset, dev_dataset, test_dataset = sst2.SST2( - split=split, root=dir_name, validate_hash=False - ) + split = ("train", "dev", "test") + train_dataset, dev_dataset, test_dataset = sst2.SST2( + split=split, root=_ASSET_DIR, validate_hash=False + ) - # verify datasets objects are instances of SST2Dataset - for dataset in (train_dataset, dev_dataset, test_dataset): - self.assertTrue(isinstance(dataset, sst2.SST2Dataset)) + # verify datasets objects are instances of SST2Dataset + for dataset in (train_dataset, dev_dataset, test_dataset): + self.assertTrue(isinstance(dataset, sst2.SST2Dataset)) - # verify hashes of first line in dataset - self.assertEqual( - hashlib.md5( - json.dumps(next(iter(train_dataset)), sort_keys=True).encode( - "utf-8" - ) - ).hexdigest(), - sst2._FIRST_LINE_MD5["train"], - ) - self.assertEqual( - hashlib.md5( - json.dumps(next(iter(dev_dataset)), sort_keys=True).encode("utf-8") - ).hexdigest(), - sst2._FIRST_LINE_MD5["dev"], - ) - self.assertEqual( - hashlib.md5( - json.dumps(next(iter(test_dataset)), sort_keys=True).encode("utf-8") - ).hexdigest(), - sst2._FIRST_LINE_MD5["test"], - ) + # verify hashes of first line in dataset + self.assertEqual( 
+ hashlib.md5( + json.dumps(next(iter(train_dataset)), sort_keys=True).encode("utf-8") + ).hexdigest(), + sst2._FIRST_LINE_MD5["train"], + ) + self.assertEqual( + hashlib.md5( + json.dumps(next(iter(dev_dataset)), sort_keys=True).encode("utf-8") + ).hexdigest(), + sst2._FIRST_LINE_MD5["dev"], + ) + self.assertEqual( + hashlib.md5( + json.dumps(next(iter(test_dataset)), sort_keys=True).encode("utf-8") + ).hexdigest(), + sst2._FIRST_LINE_MD5["test"], + ) diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py index 95bac217f7..77d3134ac6 100644 --- a/torchtext/experimental/datasets/sst2.py +++ b/torchtext/experimental/datasets/sst2.py @@ -92,7 +92,9 @@ def _get_datapipe(self, root, split, validate_hash): ) # extract data from zip - extracted_files = check_cache_dp.read_from_zip().filter(lambda x: split in x[0]) + extracted_files = check_cache_dp.read_from_zip().filter( + lambda x: f"{split}.tsv" in x[0] + ) # Parse CSV file and yield data samples return extracted_files.parse_csv(skip_lines=1, delimiter="\t").map(