From 8941f7dd42e3ba66cd4e4423da49dbd98d098e46 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 30 Sep 2021 15:41:06 +0200
Subject: [PATCH 1/8] add prototype datasets for MNIST and variants

---
 .../prototype/datasets/_builtin/__init__.py   |   1 +
 .../prototype/datasets/_builtin/mnist.py      | 374 ++++++++++++++++++
 .../prototype/datasets/utils/_internal.py     |  88 ++++-
 3 files changed, 461 insertions(+), 2 deletions(-)
 create mode 100644 torchvision/prototype/datasets/_builtin/mnist.py

diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py
index 7d6961fa920..2b4b1df20c9 100644
--- a/torchvision/prototype/datasets/_builtin/__init__.py
+++ b/torchvision/prototype/datasets/_builtin/__init__.py
@@ -1 +1,2 @@
 from .caltech import Caltech101, Caltech256
+from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
new file mode 100644
index 00000000000..172db2762d1
--- /dev/null
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -0,0 +1,374 @@
+import abc
+import codecs
+import functools
+import io
+import operator
+import pathlib
+import string
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torch.utils.data import IterDataPipe
+from torch.utils.data.datapipes.iter import (
+    Demultiplexer,
+    Mapper,
+    ZipArchiveReader,
+    Zipper,
+    Shuffler,
+)
+
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    HttpResource,
+    OnlineResource,
+)
+from torchvision.prototype.datasets.utils._internal import (
+    image_buffer_from_array,
+    Decompressor,
+    Slicer,
+    INFINITE_BUFFER_SIZE,
+)
+
+
+__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
+
+prod = functools.partial(functools.reduce, operator.mul)
+
+
+class MNISTFileReader(IterDataPipe):
+    _DTYPE_MAP = {
+        8: "u1",  # uint8
+        9: "i1",  # int8
+        11: "i2",  # int16
+        12: "i4",  # int32
+        13: "f4",  # float32
+        14: "f8",  # float64
+    }
+
+    def __init__(self, datapipe: IterDataPipe) -> None:
+        self.datapipe = datapipe
+
+    @staticmethod
+    def _decode(bytes):
+        return int(codecs.encode(bytes, "hex"), 16)
+
+    def __iter__(self) -> Iterator[np.ndarray]:
+        for _, file in self.datapipe:
+            magic = self._decode(file.read(4))
+            dtype_type = self._DTYPE_MAP[magic // 256]
+            ndim = magic % 256 - 1
+
+            num_samples = self._decode(file.read(4))
+            shape = [self._decode(file.read(4)) for _ in range(ndim)]
+
+            in_dtype = np.dtype(f">{dtype_type}")
+            out_dtype = np.dtype(dtype_type)
+            chunk_size = (prod(shape) if shape else 1) * in_dtype.itemsize
+            for _ in range(num_samples):
+                chunk = file.read(chunk_size)
+                yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape)
+
+
+class _MNISTBase(Dataset):
+    _FORMAT = "png"
+    _URL_BASE: str
+
+    @abc.abstractmethod
+    def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
+        pass
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        (images_file, images_sha256), (
+            labels_file,
+            labels_sha256,
+        ) = self._files_and_checksums(config)
+
+        images = HttpResource(f"{self._URL_BASE}/{images_file}", sha256=images_sha256)
+        labels = HttpResource(f"{self._URL_BASE}/{labels_file}", sha256=labels_sha256)
+
+        return [images, labels]
+
+    def _collate_and_decode(
+        self,
+        data: Tuple[np.ndarray, np.ndarray],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ):
+        image_array, label_array = data
+
+        image_buffer = image_buffer_from_array(image_array)
+        image = decoder(image_buffer) if decoder else image_buffer
+
+        label = torch.tensor(label_array, dtype=torch.int64)
+        category = self.info.categories[int(label)]
+
+        return dict(image=image, label=label, category=category)
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        images_dp, labels_dp = resource_dps
+
+        images_dp = Decompressor(images_dp)
+        images_dp = MNISTFileReader(images_dp)
+
+        labels_dp = Decompressor(labels_dp)
+        labels_dp = MNISTFileReader(labels_dp)
+
+        dp: IterDataPipe = Zipper(images_dp, labels_dp)
+        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
+
+
+class MNIST(_MNISTBase):
+    @property
+    def info(self):
+        return DatasetInfo(
+            "mnist",
+            categories=10,
+            homepage="http://yann.lecun.com/exdb/mnist",
+            valid_options=dict(
+                split=("train", "test"),
+            ),
+        )
+
+    _URL_BASE = "http://yann.lecun.com/exdb/mnist"
+    _CHECKSUMS = {
+        "train-images-idx3-ubyte.gz": "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609",
+        "train-labels-idx1-ubyte.gz": "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c",
+        "t10k-images-idx3-ubyte.gz": "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6",
+        "t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6",
+    }
+
+    def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
+        prefix = "train" if config.split == "train" else "t10k"
+        images_file = f"{prefix}-images-idx3-ubyte.gz"
+        labels_file = f"{prefix}-labels-idx1-ubyte.gz"
+        return (images_file, self._CHECKSUMS[images_file]), (
+            labels_file,
+            self._CHECKSUMS[labels_file],
+        )
+
+
+class FashionMNIST(MNIST):
+    @property
+    def info(self):
+        return DatasetInfo(
+            "fashionmnist",
+            categories=(
+                "T-shirt/top",
+                "Trouser",
+                "Pullover",
+                "Dress",
+                "Coat",
+                "Sandal",
+                "Shirt",
+                "Sneaker",
+                "Bag",
+                "Ankle boot",
+            ),
+            homepage="https://github.com/zalandoresearch/fashion-mnist",
+            valid_options=dict(
+                split=("train", "test"),
+            ),
+        )
+
+    _URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com"
+    _CHECKSUMS = {
+        "train-images-idx3-ubyte.gz": "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84",
+        "train-labels-idx1-ubyte.gz": "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845",
+        "t10k-images-idx3-ubyte.gz": "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073",
+        "t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5",
+    }
+
+
+class KMNIST(MNIST):
+    @property
+    def info(self):
+        return DatasetInfo(
+            "kmnist",
+            categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
+            homepage="http://codh.rois.ac.jp/kmnist/index.html.en",
+            valid_options=dict(
+                split=("train", "test"),
+            ),
+        )
+
+    _URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist"
+    _CHECKSUMS = {
+        "train-images-idx3-ubyte.gz": "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4",
+        "train-labels-idx1-ubyte.gz": "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17",
+        "t10k-images-idx3-ubyte.gz": "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5",
+        "t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c",
"20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c", + } + + +class EMNIST(_MNISTBase): + @property + def info(self): + return DatasetInfo( + "emnist", + # FIXME: shift the labels at runtime to always return a static label + categories=list(string.digits + string.ascii_letters), + homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist", + valid_options=dict( + split=("train", "test"), + image_set=( + "mnist", + "byclass", + "bymerge", + "balanced", + "digits", + "letters", + ), + ), + ) + + _URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST" + + def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: + prefix = f"emnist-{config.image_set}-{config.split}" + images_file = f"{prefix}-images-idx3-ubyte.gz" + labels_file = f"{prefix}-labels-idx1-ubyte.gz" + # Since EMNIST provides the data files inside an archive, we don't need provide checksums for them + return (images_file, ""), (labels_file, "") + + def resources(self, config: Optional[DatasetConfig] = None) -> List[OnlineResource]: + return [ + HttpResource( + f"{self._URL_BASE}/emnist-gzip.zip", + sha256="909a2a39c5e86bdd7662425e9b9c4a49bb582bf8d0edad427f3c3a9d0c6f7259", + ) + ] + + def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]: + path = pathlib.Path(data[0]) + (images_file, _), (labels_file, _) = self._files_and_checksums(config) + if path.name == images_file: + return 0 + elif path.name == labels_file: + return 1 + else: + return None + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> IterDataPipe[Dict[str, Any]]: + archive_dp = resource_dps[0] + archive_dp = ZipArchiveReader(archive_dp) + images_dp, labels_dp = Demultiplexer( + archive_dp, + 2, + functools.partial(self._classify_archive, config=config), # type:ignore[arg-type] + drop_none=True, + buffer_size=INFINITE_BUFFER_SIZE, + ) + return super()._make_datapipe([images_dp, labels_dp], config=config, decoder=decoder) + + +class QMNIST(_MNISTBase): + @property + def info(self): + return DatasetInfo( + "qmnist", + categories=10, + homepage="https://github.com/facebookresearch/qmnist", + valid_options=dict( + split=("train", "test", "test10k", "test50k", "nist"), + ), + ) + + _URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master" + _CHECKSUMS = { + "qmnist-train-images-idx3-ubyte.gz": "9e26a7bf1683614e065d7b76460ccd52807165b3f22561fb782bd9f38c52b51d", + "qmnist-train-labels-idx2-int.gz": "2c05dc77f6b916b38e455e97ab129a42a444f3dbef09b278a366f82904e0dd9f", + "qmnist-test-images-idx3-ubyte.gz": "43fc22bf7498b8fc98de98369d72f752d0deabc280a43a7bcc364ab19e57b375", + "qmnist-test-labels-idx2-int.gz": "9fbcbe594c3766fdf4f0b15c5165dc0d1e57ac604e01422608bb72c906030d06", + "xnist-images-idx3-ubyte.xz": "f075553993026d4359ded42208eff77a1941d3963c1eff49d6015814f15f0984", + "xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f", + } + + def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: + prefix = "xnist" if config.split == "nist" else f"qmnist-{'train' if config.split== 'train' else 'test'}" + suffix = "xz" if config.split == "nist" else "gz" + images_file = f"{prefix}-images-idx3-ubyte.{suffix}" + labels_file = f"{prefix}-labels-idx2-int.{suffix}" + return (images_file, self._CHECKSUMS[images_file]), 
+            labels_file,
+            self._CHECKSUMS[labels_file],
+        )
+
+    def _split_label(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        parts = [part.squeeze(0) for part in sample.pop("label").split(1)]
+        sample.update(
+            dict(
+                zip(
+                    (
+                        "label",
+                        "nist_hsf_series",
+                        "nist_writer_id",
+                        "digit_index",
+                        "nist_label",
+                        "global_digit_index",
+                    ),
+                    parts[:6],
+                )
+            )
+        )
+        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in parts[-2:]])))
+        return sample
+
+    def _collate_and_decode(
+        self,
+        data: Tuple[np.ndarray, np.ndarray],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ):
+        image_array, label_array = data
+        label_parts = label_array.tolist()
+        sample = super()._collate_and_decode((image_array, label_parts[0]), decoder=decoder)
+
+        sample.update(
+            dict(
+                zip(
+                    ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"),
+                    label_parts[1:6],
+                )
+            )
+        )
+        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in label_parts[-2:]])))
+        return sample
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        dp = super()._make_datapipe(resource_dps, config=config, decoder=decoder)
+        # dp = Mapper(dp, self._split_label)
+        if config.split not in ("test10k", "test50k"):
+            return dp
+
+        start: Optional[int]
+        stop: Optional[int]
+        if config.split == "test10k":
+            start = 0
+            stop = 10000
+        else:  # config.split == "test50k"
+            start = 10000
+            stop = None
+
+        return Slicer(dp, start=start, stop=stop)
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index 56c9a2d8c07..3f184e62836 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -1,8 +1,17 @@
 import collections.abc
 import difflib
+import enum
+import gzip
 import io
+import itertools
+import lzma
+import os.path
 import pathlib
-from typing import Collection, Sequence, Callable, Union, Any
+from typing import Collection, Sequence, Callable, Any, Iterator, Optional, Tuple, TypeVar, Union
+
+import numpy as np
+import PIL.Image
+from torch.utils.data import IterDataPipe
 
 
 __all__ = [
@@ -10,12 +19,17 @@
     "sequence_to_str",
     "add_suggestion",
     "create_categories_file",
-    "read_mat"
+    "read_mat",
+    "image_buffer_from_array",
+    "Decompressor",
+    "Slicer",
 ]
 
 # pseudo-infinite until a true infinite buffer is supported by all datapipes
 INFINITE_BUFFER_SIZE = 1_000_000_000
 
+D = TypeVar("D")
+
 
 def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
     if len(seq) == 1:
@@ -66,3 +80,73 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
         ) from error
 
     return sio.loadmat(buffer, **kwargs)
+
+
+def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO:
+    image = PIL.Image.fromarray(array)
+    buffer = io.BytesIO()
+    image.save(buffer, format=format)
+    buffer.seek(0)
+    return buffer
+
+
+class CompressionType(enum.Enum):
+    GZIP = "gzip"
+    LZMA = "lzma"
+
+
+class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]):
+    types = CompressionType
+
+    _DECOMPRESSORS = {
+        types.GZIP: lambda file: gzip.GzipFile(fileobj=file),
+        types.LZMA: lambda file: lzma.LZMAFile(file),
+    }
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe[Tuple[str, io.IOBase]],
+        *,
+        type: Optional[Union[str, CompressionType]] = None,
+    ) -> None:
+        self.datapipe = datapipe
+        if isinstance(type, str):
+            type = self.types(type.upper())
+        self.type = type
+
+    def _detect_compression_type(self, path: str) -> CompressionType:
+        if self.type:
+            return self.type
+
+        # TODO: this needs to be more elaborate
+        ext = os.path.splitext(path)[1]
+        if ext == ".gz":
+            return self.types.GZIP
+        elif ext == ".xz":
+            return self.types.LZMA
+        else:
+            raise RuntimeError("FIXME")
+
+    def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
+        for path, file in self.datapipe:
+            type = self._detect_compression_type(path)
+            decompressor = self._DECOMPRESSORS[type]
+            yield path, decompressor(file)
+
+
+class Slicer(IterDataPipe[D]):
+    def __init__(
+        self,
+        datapipe: IterDataPipe[D],
+        *,
+        start: Optional[int] = None,
+        stop: Optional[int] = None,
+        step: Optional[int] = None,
+    ):
+        self.datapipe = datapipe
+        self.start = start
+        self.stop = stop
+        self.step = step
+
+    def __iter__(self) -> Iterator[D]:
+        yield from itertools.islice(self.datapipe, self.start, self.stop, self.step)

From 6fff66d44d88cedb0654e3c07d57165045c88d61 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 30 Sep 2021 16:01:55 +0200
Subject: [PATCH 2/8] fix mypy

---
 torchvision/prototype/datasets/_builtin/mnist.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index 172db2762d1..e54dde8d9b3 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -5,7 +5,7 @@
 import operator
 import pathlib
 import string
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast
 
 import numpy as np
 import torch
@@ -66,7 +66,7 @@ def __iter__(self) -> Iterator[np.ndarray]:
 
             in_dtype = np.dtype(f">{dtype_type}")
             out_dtype = np.dtype(dtype_type)
-            chunk_size = (prod(shape) if shape else 1) * in_dtype.itemsize
+            chunk_size = (cast(int, prod(shape)) if shape else 1) * in_dtype.itemsize
             for _ in range(num_samples):
                 chunk = file.read(chunk_size)
                 yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape)

From 68f23aae673d5bfb0f33e350549a38042c5408ac Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 5 Oct 2021 08:06:40 +0200
Subject: [PATCH 3/8] fix EMNIST labels

---
 .../prototype/datasets/_builtin/mnist.py      | 55 +++++++++++++++----
 1 file changed, 44 insertions(+), 11 deletions(-)

diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index e54dde8d9b3..bfdc4fa725b 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -95,6 +95,7 @@ def _collate_and_decode(
         self,
         data: Tuple[np.ndarray, np.ndarray],
         *,
+        config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ):
         image_array, label_array = data
@@ -124,7 +125,7 @@ def _make_datapipe(
 
         dp: IterDataPipe = Zipper(images_dp, labels_dp)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
 
 
 class MNIST(_MNISTBase):
@@ -215,18 +216,17 @@ class EMNIST(_MNISTBase):
     def info(self):
         return DatasetInfo(
             "emnist",
-            # FIXME: shift the labels at runtime to always return a static label
-            categories=list(string.digits + string.ascii_letters),
+            categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist", valid_options=dict( split=("train", "test"), image_set=( - "mnist", - "byclass", - "bymerge", - "balanced", - "digits", - "letters", + "Balanced", + "By_Merge", + "By_Class", + "Letters", + "Digits", + "MNIST", ), ), ) @@ -234,7 +234,7 @@ def info(self): _URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST" def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = f"emnist-{config.image_set}-{config.split}" + prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}" images_file = f"{prefix}-images-idx3-ubyte.gz" labels_file = f"{prefix}-labels-idx1-ubyte.gz" # Since EMNIST provides the data files inside an archive, we don't need provide checksums for them @@ -258,6 +258,38 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> else: return None + _LABEL_OFFSETS = { + 38: 1, + 39: 1, + 40: 1, + 41: 1, + 42: 1, + 43: 6, + 44: 8, + 45: 8, + 46: 9, + } + + def _collate_and_decode( + self, + data: Tuple[np.ndarray, np.ndarray], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ): + image_array, label_array = data + # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). + # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, + # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example, + # since there is no 'c', 'd' corresponds to + # label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing), + # and at the same time corresponds to + # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing) + # in self.categories. Thus, we need to add 1 to the label to correct this. 
+        if config.image_set in ("Balanced", "By_Merge"):
+            label_array += np.array(self._LABEL_OFFSETS.get(int(label_array), 0), dtype=label_array.dtype)
+        return super()._collate_and_decode((image_array, label_array), config=config, decoder=decoder)
+
     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
@@ -333,11 +365,12 @@ def _collate_and_decode(
         self,
         data: Tuple[np.ndarray, np.ndarray],
         *,
+        config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ):
         image_array, label_array = data
         label_parts = label_array.tolist()
-        sample = super()._collate_and_decode((image_array, label_parts[0]), decoder=decoder)
+        sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder)
 
         sample.update(
             dict(

From 9f3896dce5e53512f1ff1b363c74107800697bae Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 5 Oct 2021 08:27:09 +0200
Subject: [PATCH 4/8] fix code format

---
 torchvision/prototype/datasets/_builtin/mnist.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index bfdc4fa725b..d41a6b9f4c4 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -17,7 +17,6 @@
     Zipper,
     Shuffler,
 )
-
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,

From a7ffeb029eb92d27e96f46fac8f61e2efe0e0a5c Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 5 Oct 2021 18:31:36 +0200
Subject: [PATCH 5/8] avoid encoding + decoding in every step

---
 .../prototype/datasets/_builtin/mnist.py      | 40 +++++++------------
 1 file changed, 15 insertions(+), 25 deletions(-)

diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index d41a6b9f4c4..23a5537bdb1 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -5,7 +5,7 @@
 import operator
 import pathlib
 import string
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union
 
 import numpy as np
 import torch
@@ -17,8 +17,10 @@
     Zipper,
     Shuffler,
 )
+from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
+    DatasetType,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
@@ -92,15 +94,19 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
 
     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[np.ndarray, Union[np.ndarray, int]],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ):
         image_array, label_array = data
 
-        image_buffer = image_buffer_from_array(image_array)
-        image = decoder(image_buffer) if decoder else image_buffer
+        image: Union[torch.Tensor, io.BytesIO]
+        if decoder is raw:
+            image = torch.from_numpy(image_array)
+        else:
+            image_buffer = image_buffer_from_array(image_array)
+            image = decoder(image_buffer) if decoder else image_buffer
 
         label = torch.tensor(label_array, dtype=torch.int64)
         category = self.info.categories[int(label)]
@@ -132,6 +138,7 @@ class MNIST(_MNISTBase):
     def info(self):
         return DatasetInfo(
             "mnist",
+            type=DatasetType.RAW,
             categories=10,
             homepage="http://yann.lecun.com/exdb/mnist",
             valid_options=dict(
@@ -162,6 +169,7 @@ class FashionMNIST(MNIST):
     def info(self):
         return DatasetInfo(
             "fashionmnist",
+            type=DatasetType.RAW,
             categories=(
                 "T-shirt/top",
                 "Trouser",
@@ -194,6 +202,7 @@ class KMNIST(MNIST):
     def info(self):
         return DatasetInfo(
             "kmnist",
+            type=DatasetType.RAW,
             categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
             homepage="http://codh.rois.ac.jp/kmnist/index.html.en",
             valid_options=dict(
@@ -215,6 +224,7 @@ class EMNIST(_MNISTBase):
     def info(self):
         return DatasetInfo(
             "emnist",
+            type=DatasetType.RAW,
             categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
             homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist",
             valid_options=dict(
@@ -313,6 +323,7 @@ class QMNIST(_MNISTBase):
     def info(self):
         return DatasetInfo(
             "qmnist",
+            type=DatasetType.RAW,
             categories=10,
             homepage="https://github.com/facebookresearch/qmnist",
             valid_options=dict(
@@ -340,26 +351,6 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str],
             self._CHECKSUMS[labels_file],
         )
 
-    def _split_label(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        parts = [part.squeeze(0) for part in sample.pop("label").split(1)]
-        sample.update(
-            dict(
-                zip(
-                    (
-                        "label",
-                        "nist_hsf_series",
-                        "nist_writer_id",
-                        "digit_index",
-                        "nist_label",
-                        "global_digit_index",
-                    ),
-                    parts[:6],
-                )
-            )
-        )
-        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in parts[-2:]])))
-        return sample
-
     def _collate_and_decode(
         self,
         data: Tuple[np.ndarray, np.ndarray],
@@ -390,7 +381,6 @@ def _make_datapipe(
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = super()._make_datapipe(resource_dps, config=config, decoder=decoder)
-        # dp = Mapper(dp, self._split_label)
         if config.split not in ("test10k", "test50k"):
             return dp
 

From 47a0872b376425841b00b391a741f817747162c0 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 6 Oct 2021 10:27:47 +0200
Subject: [PATCH 6/8] discard data at the binary level instead of after decoding

---
 .../prototype/datasets/_builtin/mnist.py      | 56 ++++++++++---------
 .../prototype/datasets/utils/_internal.py     | 20 -------
 2 files changed, 29 insertions(+), 47 deletions(-)

diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index 23a5537bdb1..837ce2f47dc 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -29,7 +29,6 @@
 from torchvision.prototype.datasets.utils._internal import (
     image_buffer_from_array,
     Decompressor,
-    Slicer,
     INFINITE_BUFFER_SIZE,
 )
 
@@ -49,8 +48,10 @@ class MNISTFileReader(IterDataPipe):
         14: "f8",  # float64
     }
 
-    def __init__(self, datapipe: IterDataPipe) -> None:
+    def __init__(self, datapipe: IterDataPipe, *, start: Optional[int], stop: Optional[int]) -> None:
         self.datapipe = datapipe
+        self.start = start
+        self.stop = stop
 
     @staticmethod
     def _decode(bytes):
@@ -68,7 +69,12 @@ def __iter__(self) -> Iterator[np.ndarray]:
 
             in_dtype = np.dtype(f">{dtype_type}")
             out_dtype = np.dtype(dtype_type)
             chunk_size = (cast(int, prod(shape)) if shape else 1) * in_dtype.itemsize
-            for _ in range(num_samples):
+
+            start = self.start or 0
+            stop = self.stop or num_samples
+
+            file.seek(file.tell() + start * chunk_size)
+            for _ in range(stop - start):
                 chunk = file.read(chunk_size)
                 yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape)
 
@@ -92,6 +98,9 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
 
         return [images, labels]
 
+    def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
+        return None, None
+
     def _collate_and_decode(
         self,
         data: Tuple[np.ndarray, Union[np.ndarray, int]],
         *,
@@ -121,12 +130,13 @@ def _make_datapipe(
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, labels_dp = resource_dps
+        start, stop = self.start_and_stop(config)
 
         images_dp = Decompressor(images_dp)
-        images_dp = MNISTFileReader(images_dp)
+        images_dp = MNISTFileReader(images_dp, start=start, stop=stop)
 
         labels_dp = Decompressor(labels_dp)
-        labels_dp = MNISTFileReader(labels_dp)
+        labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)
 
         dp: IterDataPipe = Zipper(images_dp, labels_dp)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
@@ -351,6 +361,20 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str],
             self._CHECKSUMS[labels_file],
         )
 
+    def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
+        start: Optional[int]
+        stop: Optional[int]
+        if config.split == "test10k":
+            start = 0
+            stop = 10000
+        elif config.split == "test50k":
+            start = 10000
+            stop = None
+        else:
+            start = stop = None
+
+        return start, stop
+
     def _collate_and_decode(
         self,
         data: Tuple[np.ndarray, np.ndarray],
         *,
@@ -372,25 +396,3 @@ def _collate_and_decode(
             )
         )
         sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in label_parts[-2:]])))
         return sample
-
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> IterDataPipe[Dict[str, Any]]:
-        dp = super()._make_datapipe(resource_dps, config=config, decoder=decoder)
-        if config.split not in ("test10k", "test50k"):
-            return dp
-
-        start: Optional[int]
-        stop: Optional[int]
-        if config.split == "test10k":
-            start = 0
-            stop = 10000
-        else:  # config.split == "test50k"
-            start = 10000
-            stop = None
-
-        return Slicer(dp, start=start, stop=stop)
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index b00a96db8ef..72c55233e7d 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -3,7 +3,6 @@
 import enum
 import gzip
 import io
-import itertools
 import lzma
 import os.path
 import pathlib
@@ -28,7 +27,6 @@
     "path_accessor",
     "path_comparator",
     "Decompressor",
-    "Slicer",
 ]
 
 K = TypeVar("K")
@@ -185,21 +183,3 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
         for path, file in self.datapipe:
             type = self._detect_compression_type(path)
             decompressor = self._DECOMPRESSORS[type]
             yield path, decompressor(file)
-
-
-class Slicer(IterDataPipe[D]):
-    def __init__(
-        self,
-        datapipe: IterDataPipe[D],
-        *,
-        start: Optional[int] = None,
-        stop: Optional[int] = None,
-        step: Optional[int] = None,
-    ):
-        self.datapipe = datapipe
-        self.start = start
-        self.stop = stop
-        self.step = step
-
-    def __iter__(self) -> Iterator[D]:
-        yield from itertools.islice(self.datapipe, self.start, self.stop, self.step)

From a0b8480f5f41fdbde794116e1fa9275648fe6288 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 6 Oct 2021 10:29:36 +0200
Subject: [PATCH 7/8] cleanup

---
 torchvision/prototype/datasets/_builtin/mnist.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index 837ce2f47dc..dd94a1912cb 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -73,7 +73,7 @@ def __iter__(self) -> Iterator[np.ndarray]:
             start = self.start or 0
             stop = self.stop or num_samples
 
-            file.seek(file.tell() + start * chunk_size)
+            file.seek(start * chunk_size, 1)
             for _ in range(stop - start):
                 chunk = file.read(chunk_size)
                 yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape)

From 05b6eac2d352a521dac27574fb29900f3d36cef0 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 7 Oct 2021 11:10:48 +0200
Subject: [PATCH 8/8] fix mypy

---
 torchvision/prototype/datasets/_builtin/mnist.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index dd94a1912cb..b20c3ed6266 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -103,7 +103,7 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional
 
     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, Union[np.ndarray, int]],
+        data: Tuple[np.ndarray, np.ndarray],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
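
Note on the IDX parsing used throughout this series: `MNISTFileReader` decodes a 4-byte magic number in which the third byte selects the dtype and the fourth byte gives the number of dimensions, then yields one sample per chunk. As a quick way to sanity-check that logic outside the datapipe machinery, the standalone sketch below reads a whole gzipped IDX file at once with plain NumPy. It is only a sketch under stated assumptions: the helper name read_idx and the local file paths are hypothetical and not part of the patch, and unlike the reader above it keeps the sample count as the leading dimension instead of splitting it off.

import gzip

import numpy as np

# Same dtype table as MNISTFileReader._DTYPE_MAP.
_DTYPE_MAP = {8: "u1", 9: "i1", 11: "i2", 12: "i4", 13: "f4", 14: "f8"}


def read_idx(path: str) -> np.ndarray:
    """Read a complete gzipped IDX file into one array (hypothetical helper)."""
    with gzip.open(path, "rb") as file:
        # Magic number: byte 3 encodes the dtype, byte 4 the number of
        # dimensions; the full shape follows as big-endian 32-bit integers.
        magic = int.from_bytes(file.read(4), "big")
        in_dtype = np.dtype(">" + _DTYPE_MAP[magic // 256])
        ndim = magic % 256
        shape = [int.from_bytes(file.read(4), "big") for _ in range(ndim)]

        count = int(np.prod(shape)) if shape else 1
        data = np.frombuffer(file.read(count * in_dtype.itemsize), dtype=in_dtype)
        # Convert from big-endian storage to native byte order, mirroring the
        # astype(out_dtype) step in MNISTFileReader.
        return data.astype(in_dtype.newbyteorder("=")).reshape(shape)


# Assuming the files from the MNIST URLs above were downloaded locally:
# images = read_idx("train-images-idx3-ubyte.gz")  # shape (60000, 28, 28), uint8
# labels = read_idx("train-labels-idx1-ubyte.gz")  # shape (60000,), uint8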