Skip to content

Commit aa36599

Browse files
authored
add tests for MNIST and variants (#3423)
* add tests for MNIST and variants * remove old tests and fakedata generation * fix default config detection for if dataset has variable keywords * use split="mnist" as default for EMNIST * fix QMNIST tests * lint * fix special kwargs detection * Revert "use split="mnist" as default for EMNIST" This reverts commit 62c9b23. * fix tests * fix QMNIST test case name * remove dead code from test * Revert "remove old tests and fakedata generation" This reverts commit a285b97. * remove old tests * readd removed import
1 parent 0818c68 commit aa36599

File tree

1 file changed

+128
-33
lines changed

1 file changed

+128
-33
lines changed

test/test_datasets.py

Lines changed: 128 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import torchvision
1111
from torchvision.datasets import utils
1212
from common_utils import get_tmp_dir
13-
from fakedata_generation import mnist_root, \
14-
cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
13+
14+
from fakedata_generation import cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
1515
import xml.etree.ElementTree as ET
1616
from urllib.request import Request, urlopen
1717
import itertools
@@ -119,37 +119,6 @@ def test_imagefolder_empty(self):
119119
root, loader=lambda x: x, is_valid_file=lambda x: False
120120
)
121121

122-
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
123-
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
124-
def test_mnist(self, mock_download_extract, mock_check_integrity):
125-
num_examples = 30
126-
with mnist_root(num_examples, "MNIST") as root:
127-
dataset = torchvision.datasets.MNIST(root, download=True)
128-
self.generic_classification_dataset_test(dataset, num_images=num_examples)
129-
img, target = dataset[0]
130-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
131-
132-
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
133-
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
134-
def test_kmnist(self, mock_download_extract, mock_check_integrity):
135-
num_examples = 30
136-
with mnist_root(num_examples, "KMNIST") as root:
137-
dataset = torchvision.datasets.KMNIST(root, download=True)
138-
self.generic_classification_dataset_test(dataset, num_images=num_examples)
139-
img, target = dataset[0]
140-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
141-
142-
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
143-
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
144-
def test_fashionmnist(self, mock_download_extract, mock_check_integrity):
145-
num_examples = 30
146-
with mnist_root(num_examples, "FashionMNIST") as root:
147-
dataset = torchvision.datasets.FashionMNIST(root, download=True)
148-
self.generic_classification_dataset_test(dataset, num_images=num_examples)
149-
img, target = dataset[0]
150-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
151-
152-
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
153122
def test_cityscapes(self):
154123
with cityscapes_root() as root:
155124

@@ -1499,5 +1468,131 @@ def _create_annotations_file(self, root, name, images, num_captions_per_image):
14991468
fh.write(f"{image.name}#{idx}\t{caption}\n")
15001469

15011470

1471+
class MNISTTestCase(datasets_utils.ImageDatasetTestCase):
1472+
DATASET_CLASS = datasets.MNIST
1473+
1474+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
1475+
1476+
_MAGIC_DTYPES = {
1477+
torch.uint8: 8,
1478+
torch.int8: 9,
1479+
torch.int16: 11,
1480+
torch.int32: 12,
1481+
torch.float32: 13,
1482+
torch.float64: 14,
1483+
}
1484+
1485+
_IMAGES_SIZE = (28, 28)
1486+
_IMAGES_DTYPE = torch.uint8
1487+
1488+
_LABELS_SIZE = ()
1489+
_LABELS_DTYPE = torch.uint8
1490+
1491+
def inject_fake_data(self, tmpdir, config):
1492+
raw_dir = pathlib.Path(tmpdir) / self.DATASET_CLASS.__name__ / "raw"
1493+
os.makedirs(raw_dir, exist_ok=True)
1494+
1495+
num_images = self._num_images(config)
1496+
self._create_binary_file(
1497+
raw_dir, self._images_file(config), (num_images, *self._IMAGES_SIZE), self._IMAGES_DTYPE
1498+
)
1499+
self._create_binary_file(
1500+
raw_dir, self._labels_file(config), (num_images, *self._LABELS_SIZE), self._LABELS_DTYPE
1501+
)
1502+
return num_images
1503+
1504+
def _num_images(self, config):
1505+
return 2 if config["train"] else 1
1506+
1507+
def _images_file(self, config):
1508+
return f"{self._prefix(config)}-images-idx3-ubyte"
1509+
1510+
def _labels_file(self, config):
1511+
return f"{self._prefix(config)}-labels-idx1-ubyte"
1512+
1513+
def _prefix(self, config):
1514+
return "train" if config["train"] else "t10k"
1515+
1516+
def _create_binary_file(self, root, filename, size, dtype):
1517+
with open(pathlib.Path(root) / filename, "wb") as fh:
1518+
for meta in (self._magic(dtype, len(size)), *size):
1519+
fh.write(self._encode(meta))
1520+
1521+
# If ever an MNIST variant is added that uses floating point data, this should be adapted.
1522+
data = torch.randint(0, torch.iinfo(dtype).max + 1, size, dtype=dtype)
1523+
fh.write(data.numpy().tobytes())
1524+
1525+
def _magic(self, dtype, dims):
1526+
return self._MAGIC_DTYPES[dtype] * 256 + dims
1527+
1528+
def _encode(self, v):
1529+
return torch.tensor(v, dtype=torch.int32).numpy().tobytes()[::-1]
1530+
1531+
1532+
class FashionMNISTTestCase(MNISTTestCase):
1533+
DATASET_CLASS = datasets.FashionMNIST
1534+
1535+
1536+
class KMNISTTestCase(MNISTTestCase):
1537+
DATASET_CLASS = datasets.KMNIST
1538+
1539+
1540+
class EMNISTTestCase(MNISTTestCase):
1541+
DATASET_CLASS = datasets.EMNIST
1542+
1543+
DEFAULT_CONFIG = dict(split="byclass")
1544+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
1545+
split=("byclass", "bymerge", "balanced", "letters", "digits", "mnist"), train=(True, False)
1546+
)
1547+
1548+
def _prefix(self, config):
1549+
return f"emnist-{config['split']}-{'train' if config['train'] else 'test'}"
1550+
1551+
1552+
class QMNISTTestCase(MNISTTestCase):
1553+
DATASET_CLASS = datasets.QMNIST
1554+
1555+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(what=("train", "test", "test10k", "nist"))
1556+
1557+
_LABELS_SIZE = (8,)
1558+
_LABELS_DTYPE = torch.int32
1559+
1560+
def _num_images(self, config):
1561+
if config["what"] == "nist":
1562+
return 3
1563+
elif config["what"] == "train":
1564+
return 2
1565+
elif config["what"] == "test50k":
1566+
# The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
1567+
# more than 10000 images for the dataset to not be empty. Since this takes significantly longer than the
1568+
# creation of all other splits, this is excluded from the 'ADDITIONAL_CONFIGS' and is tested only once in
1569+
# 'test_num_examples_test50k'.
1570+
return 10001
1571+
else:
1572+
return 1
1573+
1574+
def _labels_file(self, config):
1575+
return f"{self._prefix(config)}-labels-idx2-int"
1576+
1577+
def _prefix(self, config):
1578+
if config["what"] == "nist":
1579+
return "xnist"
1580+
1581+
if config["what"] is None:
1582+
what = "train" if config["train"] else "test"
1583+
elif config["what"].startswith("test"):
1584+
what = "test"
1585+
else:
1586+
what = config["what"]
1587+
1588+
return f"qmnist-{what}"
1589+
1590+
def test_num_examples_test50k(self):
1591+
with self.create_dataset(what="test50k") as (dataset, info):
1592+
# Since the split 'test50k' selects all images beginning from the index 10000, we subtract the number of
1593+
# created examples by this.
1594+
self.assertEqual(len(dataset), info["num_examples"] - 10000)
1595+
1596+
15021597
if __name__ == "__main__":
15031598
unittest.main()

0 commit comments

Comments
 (0)