Merged
12 changes: 8 additions & 4 deletions test/experimental/test_datasets.py
@@ -11,24 +11,28 @@ class TestDataset(TorchtextTestCase):
     @skipIfNoModule("torchdata")
     def test_sst2_dataset(self):
         split = ("train", "dev", "test")
-        train_dp, dev_dp, test_dp = sst2.SST2(split=split)
+        train_dataset, dev_dataset, test_dataset = sst2.SST2(split=split)
+
+        # verify dataset objects are instances of SST2Dataset
+        for dataset in (train_dataset, dev_dataset, test_dataset):
+            self.assertTrue(isinstance(dataset, sst2.SST2Dataset))
 
         # verify hashes of first line in dataset
         self.assertEqual(
             hashlib.md5(
-                json.dumps(next(iter(train_dp)), sort_keys=True).encode("utf-8")
+                json.dumps(next(iter(train_dataset)), sort_keys=True).encode("utf-8")
             ).hexdigest(),
             sst2._FIRST_LINE_MD5["train"],
         )
         self.assertEqual(
             hashlib.md5(
-                json.dumps(next(iter(dev_dp)), sort_keys=True).encode("utf-8")
+                json.dumps(next(iter(dev_dataset)), sort_keys=True).encode("utf-8")
             ).hexdigest(),
             sst2._FIRST_LINE_MD5["dev"],
         )
         self.assertEqual(
             hashlib.md5(
-                json.dumps(next(iter(test_dp)), sort_keys=True).encode("utf-8")
+                json.dumps(next(iter(test_dataset)), sort_keys=True).encode("utf-8")
             ).hexdigest(),
             sst2._FIRST_LINE_MD5["test"],
         )
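For readers skimming the diff, a minimal sketch of what the reworked API looks like from the caller's side — assuming torchtext's experimental sst2 module and its _FIRST_LINE_MD5 table are importable as in the test above:

import hashlib
import json

from torchtext.experimental.datasets import sst2

# SST2() now returns SST2Dataset instances (one per requested split)
# instead of raw datapipes
train_dataset, dev_dataset, test_dataset = sst2.SST2(split=("train", "dev", "test"))

# each dataset is directly iterable, so the first sample can be hashed
# the same way the updated test does
first_sample = next(iter(train_dataset))
digest = hashlib.md5(
    json.dumps(first_sample, sort_keys=True).encode("utf-8")
).hexdigest()
assert digest == sst2._FIRST_LINE_MD5["train"]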
22 changes: 12 additions & 10 deletions torchtext/experimental/datasets/sst2.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import os
 
+from torch.utils.data.dataset import IterableDataset
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
     _add_docstring_header,
@@ -50,10 +51,10 @@
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "dev", "test"))
 def SST2(root, split):
-    return SST2Dataset(root, split).get_datapipe()
+    return SST2Dataset(root, split)
 
 
-class SST2Dataset:
+class SST2Dataset(IterableDataset):
     """The SST2 dataset uses torchdata datapipes end-2-end.
     To avoid download at every epoch, we cache the data on-disk
     We do sanity check on downloaded and extracted data
@@ -67,26 +68,27 @@ def __init__(self, root, split):
                 "how to install the package."
             )
 
-        self.root = root
-        self.split = split
+        self._dp = self._get_datapipe(root, split)
 
-    def get_datapipe(self):
+    def __iter__(self):
+        for data in self._dp:
+            yield data
+
+    def _get_datapipe(self, root, split):
         # cache data on-disk
         cache_dp = IterableWrapper([URL]).on_disk_cache(
             HttpReader,
             op_map=lambda x: (x[0], x[1].read()),
-            filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)),
+            filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
         )
 
         # do sanity check
         check_cache_dp = cache_dp.check_hash(
-            {os.path.join(self.root, "SST-2.zip"): MD5}, "md5"
+            {os.path.join(root, "SST-2.zip"): MD5}, "md5"
         )
 
         # extract data from zip
-        extracted_files = check_cache_dp.read_from_zip().filter(
-            lambda x: self.split in x[0]
-        )
+        extracted_files = check_cache_dp.read_from_zip().filter(lambda x: split in x[0])
 
         # Parse CSV file and yield data samples
         return extracted_files.parse_csv(skip_lines=1, delimiter="\t").map(
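Since SST2Dataset now subclasses IterableDataset, the returned object can be handed straight to a torch DataLoader. A hedged usage sketch, assuming a single-split call returns one dataset and each sample is a (text, label) row produced by the parse_csv step:

from torch.utils.data import DataLoader

from torchtext.experimental.datasets import sst2

# a single split yields a single SST2Dataset rather than a tuple
train_dataset = sst2.SST2(split="train")

# IterableDataset works with DataLoader out of the box; batching groups
# consecutive samples and no shuffling sampler is applied
train_loader = DataLoader(train_dataset, batch_size=32)
for batch in train_loader:
    ...  # feed the batch to a model

The usual IterableDataset caveat applies: with num_workers > 0, each worker would iterate the full datapipe unless worker sharding is added.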