Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ General use cases are as follows: ::

def tokenize(label, line):
return line.split()

tokens = []
for label, line in train_iter:
tokens += tokenize(label, line)
Expand Down Expand Up @@ -73,6 +73,11 @@ IMDb

.. autofunction:: IMDB

SST2
~~~~

.. autofunction:: SST2


Language Modeling
^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -152,4 +157,3 @@ EnWik9
~~~~~~

.. autofunction:: EnWik9

42 changes: 0 additions & 42 deletions test/experimental/test_datasets.py

This file was deleted.

45 changes: 24 additions & 21 deletions torchtext/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib

from .ag_news import AG_NEWS
from .amazonreviewfull import AmazonReviewFull
from .amazonreviewpolarity import AmazonReviewPolarity
Expand All @@ -8,39 +9,41 @@
from .imdb import IMDB
from .iwslt2016 import IWSLT2016
from .iwslt2017 import IWSLT2017
from .multi30k import Multi30k
from .penntreebank import PennTreebank
from .sogounews import SogouNews
from .squad1 import SQuAD1
from .squad2 import SQuAD2
from .sst2 import SST2
from .udpos import UDPOS
from .wikitext103 import WikiText103
from .wikitext2 import WikiText2
from .yahooanswers import YahooAnswers
from .yelpreviewfull import YelpReviewFull
from .yelpreviewpolarity import YelpReviewPolarity
from .multi30k import Multi30k

DATASETS = {
'AG_NEWS': AG_NEWS,
'AmazonReviewFull': AmazonReviewFull,
'AmazonReviewPolarity': AmazonReviewPolarity,
'CoNLL2000Chunking': CoNLL2000Chunking,
'DBpedia': DBpedia,
'EnWik9': EnWik9,
'IMDB': IMDB,
'IWSLT2016': IWSLT2016,
'IWSLT2017': IWSLT2017,
'PennTreebank': PennTreebank,
'SQuAD1': SQuAD1,
'SQuAD2': SQuAD2,
'SogouNews': SogouNews,
'UDPOS': UDPOS,
'WikiText103': WikiText103,
'WikiText2': WikiText2,
'YahooAnswers': YahooAnswers,
'YelpReviewFull': YelpReviewFull,
'YelpReviewPolarity': YelpReviewPolarity,
'Multi30k': Multi30k
"AG_NEWS": AG_NEWS,
"AmazonReviewFull": AmazonReviewFull,
"AmazonReviewPolarity": AmazonReviewPolarity,
"CoNLL2000Chunking": CoNLL2000Chunking,
"DBpedia": DBpedia,
"EnWik9": EnWik9,
"IMDB": IMDB,
"IWSLT2016": IWSLT2016,
"IWSLT2017": IWSLT2017,
"Multi30k": Multi30k,
"PennTreebank": PennTreebank,
"SQuAD1": SQuAD1,
"SQuAD2": SQuAD2,
"SogouNews": SogouNews,
"SST2": SST2,
"UDPOS": UDPOS,
"WikiText103": WikiText103,
"WikiText2": WikiText2,
"YahooAnswers": YahooAnswers,
"YelpReviewFull": YelpReviewFull,
"YelpReviewPolarity": YelpReviewPolarity,
}

URLS = {}
Expand Down
82 changes: 82 additions & 0 deletions torchtext/datasets/sst2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import os

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
    _add_docstring_header,
    _create_dataset_directory,
    _wrap_split_argument,
)

# torchdata is an optional dependency; only pull in datapipe building blocks
# when it is installed (SST2() below raises a clear error otherwise).
if is_module_available("torchdata"):
    from torchdata.datapipes.iter import IterableWrapper, FileOpener

    # we import HttpReader from _download_hooks so we can swap out public URLs
    # with internal URLs when the dataset is used within Facebook
    from torchtext._download_hooks import HttpReader


# Public download location of the GLUE SST-2 archive.
URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"

# Expected MD5 checksum of the downloaded SST-2.zip archive.
MD5 = "9f81648d4199384278b86e315dac217c"

# Number of data rows per split (excluding the TSV header line).
NUM_LINES = {
    "train": 67349,
    "dev": 872,
    "test": 1821,
}

# Filename of the compressed archive as saved under `root`.
_PATH = "SST-2.zip"

# Name used by _create_dataset_directory to build the dataset's cache dir.
DATASET_NAME = "SST2"

# Relative paths (inside the extracted archive) of each split's TSV file.
_EXTRACTED_FILES = {
    "train": os.path.join("SST-2", "train.tsv"),
    "dev": os.path.join("SST-2", "dev.tsv"),
    "test": os.path.join("SST-2", "test.tsv"),
}


@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def SST2(root, split):
    """Build a datapipe over the SST-2 (Stanford Sentiment Treebank, binary)
    dataset, downloading and extracting it under ``root`` on first use.

    Yields ``(text, label)`` tuples for the ``train``/``dev`` splits and
    1-tuples ``(text,)`` for the unlabeled ``test`` split.
    """
    # TODO Remove this after removing conditional dependency
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    # Stage 1: cache the compressed archive on disk. The download is skipped
    # when a file matching the expected MD5 already exists at the target path.
    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
        hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(
        mode="wb", same_filepath_fn=True
    )

    # Stage 2: cache the extracted TSV for the requested split. The lambda
    # ignores its input because the on-disk location depends only on `split`.
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
    )
    cache_decompressed_dp = (
        FileOpener(cache_decompressed_dp, mode="b")
        # Keep only the archive member belonging to the requested split.
        # NOTE(review): the substring match uses os.path.join separators,
        # while zip member names use "/" — presumably fine on POSIX; confirm
        # behavior on Windows.
        .read_from_zip()
        .filter(lambda x: _EXTRACTED_FILES[split] in x[0])
    )
    cache_decompressed_dp = cache_decompressed_dp.end_caching(
        mode="wb", same_filepath_fn=True
    )

    # Parse the TSV (skipping its header row) into per-example tuples.
    data_dp = FileOpener(cache_decompressed_dp, mode="b")
    # test split for SST2 doesn't have labels
    if split == "test":
        # test.tsv columns: (index, sentence) — keep only the sentence.
        parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(
            lambda t: (t[1].strip(),)
        )
    else:
        # train/dev columns: (sentence, label) — label parsed as int (0 or 1).
        parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(
            lambda t: (t[0].strip(), int(t[1]))
        )
    return parsed_data
3 changes: 1 addition & 2 deletions torchtext/experimental/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from . import raw
from . import sst2

__all__ = ["raw", "sst2"]
__all__ = ["raw"]
109 changes: 0 additions & 109 deletions torchtext/experimental/datasets/sst2.py

This file was deleted.