|
10 | 10 | import torchvision |
11 | 11 | from torchvision.datasets import utils |
12 | 12 | from common_utils import get_tmp_dir |
13 | | -from fakedata_generation import mnist_root, \ |
14 | | - cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root |
| 13 | + |
| 14 | +from fakedata_generation import cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root |
15 | 15 | import xml.etree.ElementTree as ET |
16 | 16 | from urllib.request import Request, urlopen |
17 | 17 | import itertools |
@@ -119,37 +119,6 @@ def test_imagefolder_empty(self): |
119 | 119 | root, loader=lambda x: x, is_valid_file=lambda x: False |
120 | 120 | ) |
121 | 121 |
|
122 | | - @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') |
123 | | - @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True) |
124 | | - def test_mnist(self, mock_download_extract, mock_check_integrity): |
125 | | - num_examples = 30 |
126 | | - with mnist_root(num_examples, "MNIST") as root: |
127 | | - dataset = torchvision.datasets.MNIST(root, download=True) |
128 | | - self.generic_classification_dataset_test(dataset, num_images=num_examples) |
129 | | - img, target = dataset[0] |
130 | | - self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) |
131 | | - |
132 | | - @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') |
133 | | - @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True) |
134 | | - def test_kmnist(self, mock_download_extract, mock_check_integrity): |
135 | | - num_examples = 30 |
136 | | - with mnist_root(num_examples, "KMNIST") as root: |
137 | | - dataset = torchvision.datasets.KMNIST(root, download=True) |
138 | | - self.generic_classification_dataset_test(dataset, num_images=num_examples) |
139 | | - img, target = dataset[0] |
140 | | - self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) |
141 | | - |
142 | | - @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') |
143 | | - @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True) |
144 | | - def test_fashionmnist(self, mock_download_extract, mock_check_integrity): |
145 | | - num_examples = 30 |
146 | | - with mnist_root(num_examples, "FashionMNIST") as root: |
147 | | - dataset = torchvision.datasets.FashionMNIST(root, download=True) |
148 | | - self.generic_classification_dataset_test(dataset, num_images=num_examples) |
149 | | - img, target = dataset[0] |
150 | | - self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) |
151 | | - |
152 | | - @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows') |
153 | 122 | def test_cityscapes(self): |
154 | 123 | with cityscapes_root() as root: |
155 | 124 |
|
@@ -1499,5 +1468,131 @@ def _create_annotations_file(self, root, name, images, num_captions_per_image): |
1499 | 1468 | fh.write(f"{image.name}#{idx}\t{caption}\n") |
1500 | 1469 |
|
1501 | 1470 |
|
| 1471 | +class MNISTTestCase(datasets_utils.ImageDatasetTestCase): |
| 1472 | + DATASET_CLASS = datasets.MNIST |
| 1473 | + |
| 1474 | + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False)) |
| 1475 | + |
| 1476 | + _MAGIC_DTYPES = { |
| 1477 | + torch.uint8: 8, |
| 1478 | + torch.int8: 9, |
| 1479 | + torch.int16: 11, |
| 1480 | + torch.int32: 12, |
| 1481 | + torch.float32: 13, |
| 1482 | + torch.float64: 14, |
| 1483 | + } |
| 1484 | + |
| 1485 | + _IMAGES_SIZE = (28, 28) |
| 1486 | + _IMAGES_DTYPE = torch.uint8 |
| 1487 | + |
| 1488 | + _LABELS_SIZE = () |
| 1489 | + _LABELS_DTYPE = torch.uint8 |
| 1490 | + |
| 1491 | + def inject_fake_data(self, tmpdir, config): |
| 1492 | + raw_dir = pathlib.Path(tmpdir) / self.DATASET_CLASS.__name__ / "raw" |
| 1493 | + os.makedirs(raw_dir, exist_ok=True) |
| 1494 | + |
| 1495 | + num_images = self._num_images(config) |
| 1496 | + self._create_binary_file( |
| 1497 | + raw_dir, self._images_file(config), (num_images, *self._IMAGES_SIZE), self._IMAGES_DTYPE |
| 1498 | + ) |
| 1499 | + self._create_binary_file( |
| 1500 | + raw_dir, self._labels_file(config), (num_images, *self._LABELS_SIZE), self._LABELS_DTYPE |
| 1501 | + ) |
| 1502 | + return num_images |
| 1503 | + |
| 1504 | + def _num_images(self, config): |
| 1505 | + return 2 if config["train"] else 1 |
| 1506 | + |
| 1507 | + def _images_file(self, config): |
| 1508 | + return f"{self._prefix(config)}-images-idx3-ubyte" |
| 1509 | + |
| 1510 | + def _labels_file(self, config): |
| 1511 | + return f"{self._prefix(config)}-labels-idx1-ubyte" |
| 1512 | + |
| 1513 | + def _prefix(self, config): |
| 1514 | + return "train" if config["train"] else "t10k" |
| 1515 | + |
| 1516 | + def _create_binary_file(self, root, filename, size, dtype): |
| 1517 | + with open(pathlib.Path(root) / filename, "wb") as fh: |
| 1518 | + for meta in (self._magic(dtype, len(size)), *size): |
| 1519 | + fh.write(self._encode(meta)) |
| 1520 | + |
| 1521 | + # If ever an MNIST variant is added that uses floating point data, this should be adapted. |
| 1522 | + data = torch.randint(0, torch.iinfo(dtype).max + 1, size, dtype=dtype) |
| 1523 | + fh.write(data.numpy().tobytes()) |
| 1524 | + |
| 1525 | + def _magic(self, dtype, dims): |
| 1526 | + return self._MAGIC_DTYPES[dtype] * 256 + dims |
| 1527 | + |
| 1528 | + def _encode(self, v): |
| 1529 | + return torch.tensor(v, dtype=torch.int32).numpy().tobytes()[::-1] |
| 1530 | + |
| 1531 | + |
| 1532 | +class FashionMNISTTestCase(MNISTTestCase): |
| 1533 | + DATASET_CLASS = datasets.FashionMNIST |
| 1534 | + |
| 1535 | + |
| 1536 | +class KMNISTTestCase(MNISTTestCase): |
| 1537 | + DATASET_CLASS = datasets.KMNIST |
| 1538 | + |
| 1539 | + |
| 1540 | +class EMNISTTestCase(MNISTTestCase): |
| 1541 | + DATASET_CLASS = datasets.EMNIST |
| 1542 | + |
| 1543 | + DEFAULT_CONFIG = dict(split="byclass") |
| 1544 | + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( |
| 1545 | + split=("byclass", "bymerge", "balanced", "letters", "digits", "mnist"), train=(True, False) |
| 1546 | + ) |
| 1547 | + |
| 1548 | + def _prefix(self, config): |
| 1549 | + return f"emnist-{config['split']}-{'train' if config['train'] else 'test'}" |
| 1550 | + |
| 1551 | + |
| 1552 | +class QMNISTTestCase(MNISTTestCase): |
| 1553 | + DATASET_CLASS = datasets.QMNIST |
| 1554 | + |
| 1555 | + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(what=("train", "test", "test10k", "nist")) |
| 1556 | + |
| 1557 | + _LABELS_SIZE = (8,) |
| 1558 | + _LABELS_DTYPE = torch.int32 |
| 1559 | + |
| 1560 | + def _num_images(self, config): |
| 1561 | + if config["what"] == "nist": |
| 1562 | + return 3 |
| 1563 | + elif config["what"] == "train": |
| 1564 | + return 2 |
| 1565 | + elif config["what"] == "test50k": |
| 1566 | + # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create |
| 1567 | + # more than 10000 images for the dataset to not be empty. Since this takes significantly longer than the |
| 1568 | + # creation of all other splits, this is excluded from the 'ADDITIONAL_CONFIGS' and is tested only once in |
| 1569 | + # 'test_num_examples_test50k'. |
| 1570 | + return 10001 |
| 1571 | + else: |
| 1572 | + return 1 |
| 1573 | + |
| 1574 | + def _labels_file(self, config): |
| 1575 | + return f"{self._prefix(config)}-labels-idx2-int" |
| 1576 | + |
| 1577 | + def _prefix(self, config): |
| 1578 | + if config["what"] == "nist": |
| 1579 | + return "xnist" |
| 1580 | + |
| 1581 | + if config["what"] is None: |
| 1582 | + what = "train" if config["train"] else "test" |
| 1583 | + elif config["what"].startswith("test"): |
| 1584 | + what = "test" |
| 1585 | + else: |
| 1586 | + what = config["what"] |
| 1587 | + |
| 1588 | + return f"qmnist-{what}" |
| 1589 | + |
| 1590 | + def test_num_examples_test50k(self): |
| 1591 | + with self.create_dataset(what="test50k") as (dataset, info): |
| 1592 | + # Since the split 'test50k' selects all images beginning from the index 10000, we subtract the number of |
| 1593 | + # created examples by this. |
| 1594 | + self.assertEqual(len(dataset), info["num_examples"] - 10000) |
| 1595 | + |
| 1596 | + |
1502 | 1597 | if __name__ == "__main__": |
1503 | 1598 | unittest.main() |
0 commit comments