From 576b3e83b10f63da349919d9fd19e6578ecb869e Mon Sep 17 00:00:00 2001
From: nayef211
Date: Fri, 22 Oct 2021 13:09:05 -0700
Subject: [PATCH 1/2] Updated SST2Dataset to subclass IterableDataset. Updated
 SST2 functional call to return SST2Dataset object

---
 test/experimental/test_datasets.py      | 12 ++++++++----
 torchtext/experimental/datasets/sst2.py | 10 ++++++++--
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/test/experimental/test_datasets.py b/test/experimental/test_datasets.py
index 2a9ff700ff..31f52f3193 100644
--- a/test/experimental/test_datasets.py
+++ b/test/experimental/test_datasets.py
@@ -11,24 +11,28 @@ class TestDataset(TorchtextTestCase):
     @skipIfNoModule("torchdata")
     def test_sst2_dataset(self):
         split = ("train", "dev", "test")
-        train_dp, dev_dp, test_dp = sst2.SST2(split=split)
+        train_dataset, dev_dataset, test_dataset = sst2.SST2(split=split)
+
+        # verify dataset objects are instances of SST2Dataset
+        for dataset in (train_dataset, dev_dataset, test_dataset):
+            self.assertTrue(isinstance(dataset, sst2.SST2Dataset))
 
         # verify hashes of first line in dataset
         self.assertEqual(
             hashlib.md5(
-                json.dumps(next(iter(train_dp)), sort_keys=True).encode("utf-8")
+                json.dumps(next(iter(train_dataset)), sort_keys=True).encode("utf-8")
             ).hexdigest(),
             sst2._FIRST_LINE_MD5["train"],
         )
         self.assertEqual(
             hashlib.md5(
-                json.dumps(next(iter(dev_dp)), sort_keys=True).encode("utf-8")
+                json.dumps(next(iter(dev_dataset)), sort_keys=True).encode("utf-8")
             ).hexdigest(),
             sst2._FIRST_LINE_MD5["dev"],
         )
         self.assertEqual(
             hashlib.md5(
-                json.dumps(next(iter(test_dp)), sort_keys=True).encode("utf-8")
+                json.dumps(next(iter(test_dataset)), sort_keys=True).encode("utf-8")
             ).hexdigest(),
             sst2._FIRST_LINE_MD5["test"],
         )
diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py
index 71774a3e58..f8cfc53158 100644
--- a/torchtext/experimental/datasets/sst2.py
+++ b/torchtext/experimental/datasets/sst2.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import os
 
+from torch.utils.data.dataset import IterableDataset
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
     _add_docstring_header,
@@ -50,10 +51,10 @@
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "dev", "test"))
 def SST2(root, split):
-    return SST2Dataset(root, split).get_datapipe()
+    return SST2Dataset(root, split)
 
 
-class SST2Dataset:
+class SST2Dataset(IterableDataset):
     """The SST2 dataset uses torchdata datapipes end-2-end.
     To avoid download at every epoch, we cache the data on-disk.
     We do a sanity check on downloaded and extracted data.
@@ -69,6 +70,11 @@ def __init__(self, root, split):
         self.root = root
         self.split = split
+        self.dp = self.get_datapipe()
+
+    def __iter__(self):
+        for data in self.dp:
+            yield data
 
     def get_datapipe(self):
         # cache data on-disk
         cache_dp = IterableWrapper([URL]).on_disk_cache(

From 9503cee8218fa3f74083a20d4511f97a27b800bb Mon Sep 17 00:00:00 2001
From: nayef211
Date: Fri, 22 Oct 2021 14:40:57 -0700
Subject: [PATCH 2/2] Updated get_datapipe to be private, passed class
 parameters directly into get_datapipe function

---
 torchtext/experimental/datasets/sst2.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py
index f8cfc53158..fa15b73304 100644
--- a/torchtext/experimental/datasets/sst2.py
+++ b/torchtext/experimental/datasets/sst2.py
@@ -68,31 +68,27 @@ def __init__(self, root, split):
                 "how to install the package."
             )
 
-        self.root = root
-        self.split = split
-        self.dp = self.get_datapipe()
+        self._dp = self._get_datapipe(root, split)
 
     def __iter__(self):
-        for data in self.dp:
+        for data in self._dp:
             yield data
 
-    def get_datapipe(self):
+    def _get_datapipe(self, root, split):
         # cache data on-disk
         cache_dp = IterableWrapper([URL]).on_disk_cache(
             HttpReader,
             op_map=lambda x: (x[0], x[1].read()),
-            filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)),
+            filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
         )
 
         # do sanity check
         check_cache_dp = cache_dp.check_hash(
-            {os.path.join(self.root, "SST-2.zip"): MD5}, "md5"
+            {os.path.join(root, "SST-2.zip"): MD5}, "md5"
         )
 
         # extract data from zip
-        extracted_files = check_cache_dp.read_from_zip().filter(
-            lambda x: self.split in x[0]
-        )
+        extracted_files = check_cache_dp.read_from_zip().filter(lambda x: split in x[0])
 
         # Parse CSV file and yield data samples
         return extracted_files.parse_csv(skip_lines=1, delimiter="\t").map(
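
For context, a minimal sketch of how the reworked API is consumed after these two
patches (not part of the diffs). It assumes torchdata is installed; the DataLoader
usage and batch size are illustrative, and the exact per-sample structure depends
on the final .map(...) call, which is cut off above:

    from torch.utils.data import DataLoader
    from torchtext.experimental.datasets import sst2

    # SST2 now returns SST2Dataset instances (one per requested split)
    # rather than raw datapipes, as the updated test asserts.
    train_dataset, dev_dataset, test_dataset = sst2.SST2(split=("train", "dev", "test"))

    # Because SST2Dataset subclasses IterableDataset, a split can be
    # iterated directly, exactly as the first-line hash check in the test does ...
    first_sample = next(iter(train_dataset))

    # ... or handed straight to a DataLoader. batch_size=2 is illustrative;
    # the default collate function must be able to batch whatever the
    # truncated .map(...) call ultimately yields per sample.
    train_loader = DataLoader(train_dataset, batch_size=2)

Returning the IterableDataset itself rather than the result of get_datapipe() also
lets the second patch make the pipeline construction private (_get_datapipe),
so callers depend only on iteration, not on the datapipe internals.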