Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
5 changes: 5 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ SST2

.. autofunction:: SST2

STSB
~~~~

.. autofunction:: STSB

YahooAnswers
~~~~~~~~~~~~

Expand Down
89 changes: 89 additions & 0 deletions test/datasets/test_stsb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import random
import tarfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.stsb import STSB

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
"""
root_dir: directory to the mocked dataset
"""
base_dir = os.path.join(root_dir, "STSB")
temp_dataset_dir = os.path.join(base_dir, "stsbenchmark")
os.makedirs(temp_dataset_dir, exist_ok=True)

seed = 1
mocked_data = defaultdict(list)
for file_name, name in zip(["sts-train.csv", "sts-dev.csv" "sts-test.csv"], ["train", "dev", "test"]):
txt_file = os.path.join(temp_dataset_dir, file_name)
with open(txt_file, "w", encoding="utf-8") as f:
for i in range(5):
label = random.uniform(0, 5)
rand_string_1 = get_random_unicode(seed)
rand_string_2 = get_random_unicode(seed + 1)
rand_string_3 = get_random_unicode(seed + 2)
rand_string_4 = get_random_unicode(seed + 3)
rand_string_5 = get_random_unicode(seed + 4)
dataset_line = (i, label, rand_string_4, rand_string_5)
# append line to correct dataset split
mocked_data[name].append(dataset_line)
f.write(
f"{rand_string_1}\t{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n"
)
seed += 1
# case with quotes to test arg `quoting=csv.QUOTE_NONE`
dataset_line = (i, label, rand_string_4, rand_string_5)
# append line to correct dataset split
mocked_data[name].append(dataset_line)
f.write(
f'{rand_string_1}"\t"{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n'
)

compressed_dataset_path = os.path.join(base_dir, "Stsbenchmark.tar.gz")
# create tar file from dataset folder
with tarfile.open(compressed_dataset_path, "w:gz") as tar:
tar.add(temp_dataset_dir, arcname="stsbenchmark")

return mocked_data


class TestSTSB(TempDirMixin, TorchtextTestCase):
root_dir = None
samples = []

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.root_dir = cls.get_base_temp_dir()
cls.samples = _get_mock_dataset(cls.root_dir)
cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
cls.patcher.start()

@classmethod
def tearDownClass(cls):
cls.patcher.stop()
super().tearDownClass()

@parameterized.expand(["train", "dev", "test"])
def test_stsb(self, split):
dataset = STSB(root=self.root_dir, split=split)

samples = list(dataset)
expected_samples = self.samples[split]
for sample, expected_sample in zip_equal(samples, expected_samples):
self.assertEqual(sample, expected_sample)

@parameterized.expand(["train", "dev", "test"])
def test_stsb_split_argument(self, split):
dataset1 = STSB(root=self.root_dir, split=split)
(dataset2,) = STSB(root=self.root_dir, split=(split,))

for d1, d2 in zip_equal(dataset1, dataset2):
self.assertEqual(d1, d2)
2 changes: 2 additions & 0 deletions torchtext/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .squad1 import SQuAD1
from .squad2 import SQuAD2
from .sst2 import SST2
from .stsb import STSB
from .udpos import UDPOS
from .wikitext103 import WikiText103
from .wikitext2 import WikiText2
Expand All @@ -40,6 +41,7 @@
"SQuAD2": SQuAD2,
"SogouNews": SogouNews,
"SST2": SST2,
"STSB": STSB,
"UDPOS": UDPOS,
"WikiText103": WikiText103,
"WikiText2": WikiText2,
Expand Down
90 changes: 90 additions & 0 deletions torchtext/datasets/stsb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import csv
import os

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_create_dataset_directory,
_wrap_split_argument,
)

if is_module_available("torchdata"):
from torchdata.datapipes.iter import FileOpener, IterableWrapper

# we import HttpReader from _download_hooks so we can swap out public URLs
# with interal URLs when the dataset is used within Facebook
from torchtext._download_hooks import HttpReader


URL = "http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz"

MD5 = "4eb0065aba063ef77873d3a9c8088811"

NUM_LINES = {
"train": 5749,
"dev": 1500,
"test": 1379,
}

_PATH = "Stsbenchmark.tar.gz"

DATASET_NAME = "STSB"

_EXTRACTED_FILES = {
"train": os.path.join("stsbenchmark", "sts-train.csv"),
"dev": os.path.join("stsbenchmark", "sts-dev.csv"),
"test": os.path.join("stsbenchmark", "sts-test.csv"),
}


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def STSB(root, split):
"""STSB Dataset

For additional details refer to https://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark

Number of lines per split:
- train: 5749
- dev: 1500
- test: 1379

Args:
root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)

:returns: DataPipe that yields tuple of (index (int), label (float), sentence1 (str), sentence2 (str))
:rtype: (int, float, str, str)
"""
# TODO Remove this after removing conditional dependency
if not is_module_available("torchdata"):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(x=_PATH):
return os.path.join(root, os.path.basename(x))

def _extracted_filepath_fn(_=None):
return _filepath_fn(_EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(x):
return (int(x[3]), float(x[4]), x[5], x[6])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(URL): MD5},
hash_type="md5",
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
parsed_data = data_dp.parse_csv(delimiter="\t", quoting=csv.QUOTE_NONE).map(_modify_res)
return parsed_data.shuffle().set_shuffle(False).sharding_filter()