Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 128 additions & 33 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, \
cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root

from fakedata_generation import cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
Expand Down Expand Up @@ -119,37 +119,6 @@ def test_imagefolder_empty(self):
root, loader=lambda x: x, is_valid_file=lambda x: False
)

@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_mnist(self, mock_check_integrity, mock_download_extract):
    # Smoke-test MNIST against a fake root with downloading and integrity checks patched out.
    # NOTE: stacked ``mock.patch`` decorators inject their mocks bottom-up, so the first mock
    # argument belongs to the decorator closest to the ``def`` (``check_integrity``).
    num_examples = 30
    with mnist_root(num_examples, "MNIST") as root:
        dataset = torchvision.datasets.MNIST(root, download=True)
        self.generic_classification_dataset_test(dataset, num_images=num_examples)
        # Only the target is inspected; the image itself is irrelevant for this check.
        _, target = dataset[0]
        self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_kmnist(self, mock_check_integrity, mock_download_extract):
    # Smoke-test KMNIST against a fake root with downloading and integrity checks patched out.
    # NOTE: stacked ``mock.patch`` decorators inject their mocks bottom-up, so the first mock
    # argument belongs to the decorator closest to the ``def`` (``check_integrity``).
    num_examples = 30
    with mnist_root(num_examples, "KMNIST") as root:
        dataset = torchvision.datasets.KMNIST(root, download=True)
        self.generic_classification_dataset_test(dataset, num_images=num_examples)
        # Only the target is inspected; the image itself is irrelevant for this check.
        _, target = dataset[0]
        self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_fashionmnist(self, mock_check_integrity, mock_download_extract):
    # Smoke-test FashionMNIST against a fake root with downloading and integrity checks
    # patched out.
    # NOTE: stacked ``mock.patch`` decorators inject their mocks bottom-up, so the first mock
    # argument belongs to the decorator closest to the ``def`` (``check_integrity``).
    num_examples = 30
    with mnist_root(num_examples, "FashionMNIST") as root:
        dataset = torchvision.datasets.FashionMNIST(root, download=True)
        self.generic_classification_dataset_test(dataset, num_images=num_examples)
        # Only the target is inspected; the image itself is irrelevant for this check.
        _, target = dataset[0]
        self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_cityscapes(self):
with cityscapes_root() as root:

Expand Down Expand Up @@ -1499,5 +1468,131 @@ def _create_annotations_file(self, root, name, images, num_captions_per_image):
fh.write(f"{image.name}#{idx}\t{caption}\n")


class MNISTTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.MNIST`` backed by tiny, randomly generated binary files.

    The fake files are written to ``<tmpdir>/<DATASET_CLASS.__name__>/raw``. Each file consists of
    a header of big-endian 32-bit integers (a magic number followed by one entry per dimension)
    and the raw tensor data. Subclasses adapt the layout by overriding ``DATASET_CLASS`` and the
    ``_num_images`` / ``_images_file`` / ``_labels_file`` / ``_prefix`` hooks.
    """

    DATASET_CLASS = datasets.MNIST

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))

    # Type code per dtype. ``_magic`` combines it with the number of dimensions to form the
    # magic number that starts every file.
    _MAGIC_DTYPES = {
        torch.uint8: 8,
        torch.int8: 9,
        torch.int16: 11,
        torch.int32: 12,
        torch.float32: 13,
        torch.float64: 14,
    }

    # Per-sample shape and dtype of the fake images.
    _IMAGES_SIZE = (28, 28)
    _IMAGES_DTYPE = torch.uint8

    # Per-sample shape and dtype of the fake labels; ``()`` means one scalar per image.
    _LABELS_SIZE = ()
    _LABELS_DTYPE = torch.uint8

    def inject_fake_data(self, tmpdir, config):
        """Create the fake images and labels files for ``config`` below ``tmpdir``.

        Returns:
            The number of images that were written.
        """
        raw_dir = pathlib.Path(tmpdir) / self.DATASET_CLASS.__name__ / "raw"
        os.makedirs(raw_dir, exist_ok=True)

        num_images = self._num_images(config)
        self._create_binary_file(
            raw_dir, self._images_file(config), (num_images, *self._IMAGES_SIZE), self._IMAGES_DTYPE
        )
        self._create_binary_file(
            raw_dir, self._labels_file(config), (num_images, *self._LABELS_SIZE), self._LABELS_DTYPE
        )
        return num_images

    def _num_images(self, config):
        """Return how many fake images to create: 2 for the train split, 1 otherwise."""
        return 2 if config["train"] else 1

    def _images_file(self, config):
        """Return the file name of the images file for ``config``."""
        return f"{self._prefix(config)}-images-idx3-ubyte"

    def _labels_file(self, config):
        """Return the file name of the labels file for ``config``."""
        return f"{self._prefix(config)}-labels-idx1-ubyte"

    def _prefix(self, config):
        """Return the split dependent file name prefix."""
        return "train" if config["train"] else "t10k"

    def _create_binary_file(self, root, filename, size, dtype):
        """Write ``root / filename`` with a header for ``size`` / ``dtype`` and random data.

        Args:
            root: Directory the file is created in.
            filename: Name of the file to create.
            size: Shape of the stored tensor; every entry becomes one header field.
            dtype: Integer torch dtype of the data; must be a key of ``_MAGIC_DTYPES``.
        """
        with open(pathlib.Path(root) / filename, "wb") as fh:
            # Header: magic number first, then the extent of every dimension.
            for meta in (self._magic(dtype, len(size)), *size):
                fh.write(self._encode(meta))

            # If ever an MNIST variant is added that uses floating point data, this should be adapted.
            data = torch.randint(0, torch.iinfo(dtype).max + 1, size, dtype=dtype)
            fh.write(data.numpy().tobytes())

    def _magic(self, dtype, dims):
        """Return the magic number: the dtype code shifted by one byte plus the number of dims."""
        return self._MAGIC_DTYPES[dtype] * 256 + dims

    def _encode(self, v):
        """Encode ``v`` as a signed, big-endian, 4 byte integer.

        The byte order is requested explicitly instead of reversing the host-native
        representation, so the output is the same on little- and big-endian machines.
        """
        return int(v).to_bytes(4, byteorder="big", signed=True)


class FashionMNISTTestCase(MNISTTestCase):
    """Runs the checks of ``MNISTTestCase`` against ``datasets.FashionMNIST``."""

    DATASET_CLASS = datasets.FashionMNIST


class KMNISTTestCase(MNISTTestCase):
    """Runs the checks of ``MNISTTestCase`` against ``datasets.KMNIST``."""

    DATASET_CLASS = datasets.KMNIST


class EMNISTTestCase(MNISTTestCase):
    """Runs the checks of ``MNISTTestCase`` against ``datasets.EMNIST`` for every split."""

    DATASET_CLASS = datasets.EMNIST

    DEFAULT_CONFIG = dict(split="byclass")
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("byclass", "bymerge", "balanced", "letters", "digits", "mnist"), train=(True, False)
    )

    def _prefix(self, config):
        """Return the file name prefix, which contains both the split and the subset."""
        split = config["split"]
        subset = "train" if config["train"] else "test"
        return f"emnist-{split}-{subset}"


class QMNISTTestCase(MNISTTestCase):
    """Runs the checks of ``MNISTTestCase`` against ``datasets.QMNIST``.

    The labels are stored as 8 ``int32`` values per image and the files are selected through the
    ``what`` config entry.
    """

    DATASET_CLASS = datasets.QMNIST

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(what=("train", "test", "test10k", "nist"))

    _LABELS_SIZE = (8,)
    _LABELS_DTYPE = torch.int32

    def _num_images(self, config):
        """Return how many fake images to create for the requested ``what``."""
        num_images_by_what = {
            "nist": 3,
            "train": 2,
            # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
            # more than 10000 images for the dataset to not be empty. Since this takes significantly longer than the
            # creation of all other splits, this is excluded from the 'ADDITIONAL_CONFIGS' and is tested only once in
            # 'test_num_examples_test50k'.
            "test50k": 10001,
        }
        # Every other value of 'what' gets a single image.
        return num_images_by_what.get(config["what"], 1)

    def _labels_file(self, config):
        """Return the file name of the labels file for ``config``."""
        return f"{self._prefix(config)}-labels-idx2-int"

    def _prefix(self, config):
        """Return the file name prefix that belongs to the requested ``what``."""
        what = config["what"]
        if what == "nist":
            return "xnist"

        if what is None:
            # Without an explicit 'what', the 'train' flag selects the files.
            what = "train" if config["train"] else "test"
        elif what.startswith("test"):
            # All 'test*' variants share the files with the 'test' prefix.
            what = "test"

        return f"qmnist-{what}"

    def test_num_examples_test50k(self):
        with self.create_dataset(what="test50k") as (dataset, info):
            # The split 'test50k' selects all images beginning from the index 10000, so the dataset is expected to
            # be shorter than the number of created examples by exactly this offset.
            expected_length = info["num_examples"] - 10000
            self.assertEqual(len(dataset), expected_length)


# Run the tests of this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()