Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit ec20f88

Browse files
authored
Add support for CoLA dataset with unit tests (#1711)
* Add support for CoLA dataset + unit tests
* Better test with differentiated rand_string
* Remove lambda functions
* Add dataset documentation
* Add shuffle and sharding
1 parent 2a712f4 commit ec20f88

File tree

4 files changed

+171
-0
lines changed

4 files changed

+171
-0
lines changed

docs/source/datasets.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ AmazonReviewPolarity
4242

4343
.. autofunction:: AmazonReviewPolarity
4444

45+
CoLA
46+
~~~~~~~~~~~~~~~~~~~~
47+
48+
.. autofunction:: CoLA
49+
4550
DBpedia
4651
~~~~~~~
4752

test/datasets/test_cola.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import os
2+
import zipfile
3+
from collections import defaultdict
4+
from unittest.mock import patch
5+
6+
from parameterized import parameterized
7+
from torchtext.datasets.cola import CoLA
8+
9+
from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
10+
from ..common.torchtext_test_case import TorchtextTestCase
11+
12+
13+
def _get_mock_dataset(root_dir):
    """Create a mocked CoLA archive on disk and return the expected samples.

    root_dir: directory to the mocked dataset

    Returns a dict mapping each split name ("train"/"dev"/"test") to the list
    of (source, label, sentence) tuples that CoLA() should yield for it.
    """
    base_dir = os.path.join(root_dir, "CoLA")
    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    # Map raw archive file names to the split names accepted by CoLA(split=...)
    # so the tests can look samples up by split (the previous keying by filename
    # stem never matched and made the tests compare empty lists).
    split_map = {
        "in_domain_train.tsv": "train",
        "in_domain_dev.tsv": "dev",
        "out_of_domain_dev.tsv": "test",
    }

    seed = 1
    mocked_data = defaultdict(list)
    for file_name, split in split_map.items():
        txt_file = os.path.join(temp_dataset_dir, file_name)
        with open(txt_file, "w", encoding="utf-8") as f:
            # CoLA() parses with skip_lines=1, so write one header row it drops.
            f.write("source\tlabel\tnotes\tsentence\n")
            for _ in range(5):
                label = seed % 2
                rand_string_1 = get_random_unicode(seed)
                rand_string_2 = get_random_unicode(seed + 1)
                # CoLA() yields (source, label, sentence) from columns 0, 1, 3.
                mocked_data[split].append((rand_string_1, label, rand_string_2))
                # Write four unquoted tab-separated columns: CoLA() filters for
                # exactly 4 fields and parses with csv.QUOTE_NONE, so quoted
                # fields would keep their quotes and int(label) would fail.
                f.write(f"{rand_string_1}\t{label}\t*\t{rand_string_2}\n")
                seed += 1

    # Zip the raw tsv files with the same internal layout as the real archive.
    compressed_dataset_path = os.path.join(base_dir, "cola_public_1.1.zip")
    with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
        for file_name in split_map:
            txt_file = os.path.join(temp_dataset_dir, file_name)
            zip_file.write(txt_file, arcname=os.path.join("cola_public", "raw", file_name))

    return mocked_data
44+
45+
46+
class TestCoLA(TempDirMixin, TorchtextTestCase):
    """Exercise the CoLA dataset against a locally mocked archive."""

    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        # The mocked zip cannot match the published MD5, so bypass hash checking.
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "test", "dev"])
    def test_cola(self, split):
        """Each row yielded by the dataset must equal the mocked row."""
        produced = list(CoLA(root=self.root_dir, split=split))
        expected = self.samples[split]
        for got, want in zip_equal(produced, expected):
            self.assertEqual(got, want)

    @parameterized.expand(["train", "test", "dev"])
    def test_cola_split_argument(self, split):
        """A bare split string and a one-element tuple must yield the same rows."""
        dataset_from_str = CoLA(root=self.root_dir, split=split)
        (dataset_from_tuple,) = CoLA(root=self.root_dir, split=(split,))

        for row_a, row_b in zip_equal(dataset_from_str, dataset_from_tuple):
            self.assertEqual(row_a, row_b)

torchtext/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .amazonreviewfull import AmazonReviewFull
55
from .amazonreviewpolarity import AmazonReviewPolarity
66
from .cc100 import CC100
7+
from .cola import CoLA
78
from .conll2000chunking import CoNLL2000Chunking
89
from .dbpedia import DBpedia
910
from .enwik9 import EnWik9
@@ -28,6 +29,7 @@
2829
"AmazonReviewFull": AmazonReviewFull,
2930
"AmazonReviewPolarity": AmazonReviewPolarity,
3031
"CC100": CC100,
32+
"CoLA": CoLA,
3133
"CoNLL2000Chunking": CoNLL2000Chunking,
3234
"DBpedia": DBpedia,
3335
"EnWik9": EnWik9,

torchtext/datasets/cola.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import csv
2+
import os
3+
from typing import Union, Tuple
4+
5+
from torchtext._internal.module_utils import is_module_available
6+
from torchtext.data.datasets_utils import _create_dataset_directory, _wrap_split_argument
7+
8+
if is_module_available("torchdata"):
9+
from torchdata.datapipes.iter import FileOpener, IterableWrapper
10+
from torchtext._download_hooks import HttpReader
11+
12+
URL = "https://nyu-mll.github.io/CoLA/cola_public_1.1.zip"
13+
14+
MD5 = "9f6d88c3558ec424cd9d66ea03589aba"
15+
16+
_PATH = "cola_public_1.1.zip"
17+
18+
NUM_LINES = {"train": 8551, "dev": 527, "test": 516}
19+
20+
_EXTRACTED_FILES = {
21+
"train": os.path.join("cola_public", "raw", "in_domain_train.tsv"),
22+
"dev": os.path.join("cola_public", "raw", "in_domain_dev.tsv"),
23+
"test": os.path.join("cola_public", "raw", "out_of_domain_dev.tsv"),
24+
}
25+
26+
DATASET_NAME = "CoLA"
27+
28+
29+
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def CoLA(root: str, split: Union[Tuple[str], str]):
    """CoLA dataset

    For additional details refer to https://nyu-mll.github.io/CoLA/

    Number of lines per split:
        - train: 8551
        - dev: 527
        - test: 516

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)

    :returns: DataPipe that yields rows from CoLA dataset (source (str), label (int), sentence (str))
    :rtype: (str, int, str)
    """
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    def _archive_path(_=None):
        # Where the downloaded zip lives under the dataset root.
        return os.path.join(root, _PATH)

    def _split_path(_=None):
        # Where the extracted tsv for the requested split lives.
        return os.path.join(root, _EXTRACTED_FILES[split])

    def _is_requested_split(item):
        return _EXTRACTED_FILES[split] in item[0]

    def _has_four_fields(row):
        return len(row) == 4

    def _to_sample(row):
        # Keep (source, label, sentence); the third column (index 2) is dropped.
        return (row[0], int(row[1]), row[3])

    # Download the archive once, validated against the published MD5.
    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=_archive_path,
        hash_dict={_archive_path(): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    # Extract only the tsv belonging to the requested split, cached on disk.
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_split_path)
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_is_requested_split)
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
    # some context stored at top of the file needs to be removed
    parsed_rows = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE)
    samples = parsed_rows.filter(_has_four_fields).map(_to_sample)
    return samples.shuffle().set_shuffle(False).sharding_filter()

0 commit comments

Comments
 (0)