Skip to content

Commit 508c79d

Browse files
NicolasHugpmeier
andauthored
Add GTSRB dataset to prototypes (#5214)
Co-authored-by: Philip Meier <[email protected]>
1 parent 8886a3c commit 508c79d

File tree

6 files changed

+337
-34
lines changed

6 files changed

+337
-34
lines changed

test/builtin_dataset_mocks.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,76 @@ def fer2013(info, root, config):
10171017
return num_samples
10181018

10191019

1020+
@DATASET_MOCKS.set_from_named_callable
1021+
def gtsrb(info, root, config):
1022+
num_examples_per_class = 5 if config.split == "train" else 3
1023+
classes = ("00000", "00042", "00012")
1024+
num_examples = num_examples_per_class * len(classes)
1025+
1026+
csv_columns = ["Filename", "Width", "Height", "Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2", "ClassId"]
1027+
1028+
def _make_ann_file(path, num_examples, class_idx):
1029+
if class_idx == "random":
1030+
class_idx = torch.randint(1, len(classes) + 1, size=(1,)).item()
1031+
1032+
with open(path, "w") as csv_file:
1033+
writer = csv.DictWriter(csv_file, fieldnames=csv_columns, delimiter=";")
1034+
writer.writeheader()
1035+
for image_idx in range(num_examples):
1036+
writer.writerow(
1037+
{
1038+
"Filename": f"{image_idx:05d}.ppm",
1039+
"Width": torch.randint(1, 100, size=()).item(),
1040+
"Height": torch.randint(1, 100, size=()).item(),
1041+
"Roi.X1": torch.randint(1, 100, size=()).item(),
1042+
"Roi.Y1": torch.randint(1, 100, size=()).item(),
1043+
"Roi.X2": torch.randint(1, 100, size=()).item(),
1044+
"Roi.Y2": torch.randint(1, 100, size=()).item(),
1045+
"ClassId": class_idx,
1046+
}
1047+
)
1048+
1049+
if config["split"] == "train":
1050+
train_folder = root / "GTSRB" / "Training"
1051+
train_folder.mkdir(parents=True)
1052+
1053+
for class_idx in classes:
1054+
create_image_folder(
1055+
train_folder,
1056+
name=class_idx,
1057+
file_name_fn=lambda image_idx: f"{class_idx}_{image_idx:05d}.ppm",
1058+
num_examples=num_examples_per_class,
1059+
)
1060+
_make_ann_file(
1061+
path=train_folder / class_idx / f"GT-{class_idx}.csv",
1062+
num_examples=num_examples_per_class,
1063+
class_idx=int(class_idx),
1064+
)
1065+
make_zip(root, "GTSRB-Training_fixed.zip", train_folder)
1066+
else:
1067+
test_folder = root / "GTSRB" / "Final_Test"
1068+
test_folder.mkdir(parents=True)
1069+
1070+
create_image_folder(
1071+
test_folder,
1072+
name="Images",
1073+
file_name_fn=lambda image_idx: f"{image_idx:05d}.ppm",
1074+
num_examples=num_examples,
1075+
)
1076+
1077+
make_zip(root, "GTSRB_Final_Test_Images.zip", test_folder)
1078+
1079+
_make_ann_file(
1080+
path=root / "GT-final_test.csv",
1081+
num_examples=num_examples,
1082+
class_idx="random",
1083+
)
1084+
1085+
make_zip(root, "GTSRB_Final_Test_GT.zip", "GT-final_test.csv")
1086+
1087+
return num_examples
1088+
1089+
10201090
@DATASET_MOCKS.set_from_named_callable
10211091
def clevr(info, root, config):
10221092
data_folder = root / "CLEVR_v1.0"

test/datasets_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
881881
files, dirs = _split_files_or_dirs(root, *files_or_dirs)
882882

883883
with opener(archive) as fh:
884-
for file in files:
884+
for file in sorted(files):
885885
adder(fh, file, file.relative_to(root))
886886

887887
if remove:

test/test_prototype_builtin_datasets.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
from pathlib import Path
23

34
import pytest
45
import torch
@@ -123,7 +124,7 @@ def scan(graph):
123124
if type(dp) is annotation_dp_type:
124125
break
125126
else:
126-
raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
127+
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
127128

128129

129130
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
@@ -143,3 +144,19 @@ def test_extra_label(self, dataset_mock, config):
143144
("unused", bool),
144145
):
145146
assert key in sample and isinstance(sample[key], type)
147+
148+
149+
@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"])
150+
class TestGTSRB:
151+
def test_label_matches_path(self, dataset_mock, config):
152+
# We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
153+
# This test makes sure that they're both the same
154+
if config.split != "train":
155+
return
156+
157+
with dataset_mock.prepare(config):
158+
dataset = datasets.load(dataset_mock.name, **config)
159+
160+
for sample in dataset:
161+
label_from_path = int(Path(sample["image_path"]).parent.name)
162+
assert sample["label"] == label_from_path

0 commit comments

Comments
 (0)