Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit d72124c

Browse files
Nayef211 and nayef211 authored
Migrate SST2 from experimental to datasets folder (#1538)
* Migrating SST2 from experimental to datasets folder

* Added SST2 to docs and to init file

* Removing empty line from docs

Co-authored-by: nayef211 <[email protected]>
1 parent ce1ce99 commit d72124c

File tree

6 files changed

+113
-176
lines changed

6 files changed

+113
-176
lines changed

docs/source/datasets.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ General use cases are as follows: ::
1313

1414
def tokenize(label, line):
1515
return line.split()
16-
16+
1717
tokens = []
1818
for label, line in train_iter:
1919
tokens += tokenize(label, line)
@@ -73,6 +73,11 @@ IMDb
7373

7474
.. autofunction:: IMDB
7575

SST2
~~~~

.. autofunction:: SST2

7681

7782
Language Modeling
7883
^^^^^^^^^^^^^^^^^
@@ -152,4 +157,3 @@ EnWik9
152157
~~~~~~
153158

154159
.. autofunction:: EnWik9
155-

test/experimental/test_datasets.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

torchtext/datasets/__init__.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib
2+
23
from .ag_news import AG_NEWS
34
from .amazonreviewfull import AmazonReviewFull
45
from .amazonreviewpolarity import AmazonReviewPolarity
@@ -8,39 +9,41 @@
89
from .imdb import IMDB
910
from .iwslt2016 import IWSLT2016
1011
from .iwslt2017 import IWSLT2017
12+
from .multi30k import Multi30k
1113
from .penntreebank import PennTreebank
1214
from .sogounews import SogouNews
1315
from .squad1 import SQuAD1
1416
from .squad2 import SQuAD2
17+
from .sst2 import SST2
1518
from .udpos import UDPOS
1619
from .wikitext103 import WikiText103
1720
from .wikitext2 import WikiText2
1821
from .yahooanswers import YahooAnswers
1922
from .yelpreviewfull import YelpReviewFull
2023
from .yelpreviewpolarity import YelpReviewPolarity
21-
from .multi30k import Multi30k
2224

2325
DATASETS = {
24-
'AG_NEWS': AG_NEWS,
25-
'AmazonReviewFull': AmazonReviewFull,
26-
'AmazonReviewPolarity': AmazonReviewPolarity,
27-
'CoNLL2000Chunking': CoNLL2000Chunking,
28-
'DBpedia': DBpedia,
29-
'EnWik9': EnWik9,
30-
'IMDB': IMDB,
31-
'IWSLT2016': IWSLT2016,
32-
'IWSLT2017': IWSLT2017,
33-
'PennTreebank': PennTreebank,
34-
'SQuAD1': SQuAD1,
35-
'SQuAD2': SQuAD2,
36-
'SogouNews': SogouNews,
37-
'UDPOS': UDPOS,
38-
'WikiText103': WikiText103,
39-
'WikiText2': WikiText2,
40-
'YahooAnswers': YahooAnswers,
41-
'YelpReviewFull': YelpReviewFull,
42-
'YelpReviewPolarity': YelpReviewPolarity,
43-
'Multi30k': Multi30k
26+
"AG_NEWS": AG_NEWS,
27+
"AmazonReviewFull": AmazonReviewFull,
28+
"AmazonReviewPolarity": AmazonReviewPolarity,
29+
"CoNLL2000Chunking": CoNLL2000Chunking,
30+
"DBpedia": DBpedia,
31+
"EnWik9": EnWik9,
32+
"IMDB": IMDB,
33+
"IWSLT2016": IWSLT2016,
34+
"IWSLT2017": IWSLT2017,
35+
"Multi30k": Multi30k,
36+
"PennTreebank": PennTreebank,
37+
"SQuAD1": SQuAD1,
38+
"SQuAD2": SQuAD2,
39+
"SogouNews": SogouNews,
40+
"SST2": SST2,
41+
"UDPOS": UDPOS,
42+
"WikiText103": WikiText103,
43+
"WikiText2": WikiText2,
44+
"YahooAnswers": YahooAnswers,
45+
"YelpReviewFull": YelpReviewFull,
46+
"YelpReviewPolarity": YelpReviewPolarity,
4447
}
4548

4649
URLS = {}

torchtext/datasets/sst2.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import os
3+
4+
from torchtext._internal.module_utils import is_module_available
5+
from torchtext.data.datasets_utils import (
6+
_add_docstring_header,
7+
_create_dataset_directory,
8+
_wrap_split_argument,
9+
)
10+
11+
if is_module_available("torchdata"):
12+
from torchdata.datapipes.iter import IterableWrapper, FileOpener
13+
14+
# we import HttpReader from _download_hooks so we can swap out public URLs
15+
# with interal URLs when the dataset is used within Facebook
16+
from torchtext._download_hooks import HttpReader
17+
18+
19+
URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"
20+
21+
MD5 = "9f81648d4199384278b86e315dac217c"
22+
23+
NUM_LINES = {
24+
"train": 67349,
25+
"dev": 872,
26+
"test": 1821,
27+
}
28+
29+
_PATH = "SST-2.zip"
30+
31+
DATASET_NAME = "SST2"
32+
33+
_EXTRACTED_FILES = {
34+
"train": os.path.join("SST-2", "train.tsv"),
35+
"dev": os.path.join("SST-2", "dev.tsv"),
36+
"test": os.path.join("SST-2", "test.tsv"),
37+
}
38+
39+
40+
@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
41+
@_create_dataset_directory(dataset_name=DATASET_NAME)
42+
@_wrap_split_argument(("train", "dev", "test"))
43+
def SST2(root, split):
44+
# TODO Remove this after removing conditional dependency
45+
if not is_module_available("torchdata"):
46+
raise ModuleNotFoundError(
47+
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
48+
)
49+
50+
url_dp = IterableWrapper([URL])
51+
cache_compressed_dp = url_dp.on_disk_cache(
52+
filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
53+
hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
54+
hash_type="md5",
55+
)
56+
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(
57+
mode="wb", same_filepath_fn=True
58+
)
59+
60+
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
61+
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
62+
)
63+
cache_decompressed_dp = (
64+
FileOpener(cache_decompressed_dp, mode="b")
65+
.read_from_zip()
66+
.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
67+
)
68+
cache_decompressed_dp = cache_decompressed_dp.end_caching(
69+
mode="wb", same_filepath_fn=True
70+
)
71+
72+
data_dp = FileOpener(cache_decompressed_dp, mode="b")
73+
# test split for SST2 doesn't have labels
74+
if split == "test":
75+
parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(
76+
lambda t: (t[1].strip(),)
77+
)
78+
else:
79+
parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(
80+
lambda t: (t[0].strip(), int(t[1]))
81+
)
82+
return parsed_data
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from . import raw
2-
from . import sst2
32

4-
__all__ = ["raw", "sst2"]
3+
__all__ = ["raw"]

torchtext/experimental/datasets/sst2.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

0 commit comments

Comments
 (0)