Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f61a0b9
add FloReader datapipe
pmeier Nov 8, 2021
675eaa0
add NumericBinaryReader
pmeier Nov 8, 2021
e0157bc
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 8, 2021
05e934f
revert unrelated change
pmeier Nov 8, 2021
3a2d812
cleanup
pmeier Nov 8, 2021
2d7111d
cleanup
pmeier Nov 8, 2021
f984983
add comment for byte reversal
pmeier Nov 8, 2021
c4b46b7
use numpy after all
pmeier Nov 8, 2021
ba362a7
appease mypy
pmeier Nov 8, 2021
263b454
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 8, 2021
3bb9256
use .astype() with copy=False
pmeier Nov 9, 2021
5e029a9
add docstring and cleanuo
pmeier Nov 9, 2021
e9c5584
reuse current _read_flo and revert MNIST changes
pmeier Nov 10, 2021
fa4fafb
cleanup
pmeier Nov 10, 2021
61a71a1
revert demonstration
pmeier Nov 16, 2021
950dc49
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 16, 2021
68f2d95
refactor
pmeier Nov 16, 2021
a3823ba
cleanup
pmeier Nov 16, 2021
de865cf
add support for mutable memory
pmeier Nov 18, 2021
c3fd445
add test
pmeier Nov 18, 2021
7c3a33f
add comments
pmeier Nov 18, 2021
2c62670
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 18, 2021
aa780fd
catch more exceptions
pmeier Nov 18, 2021
e9031af
fix mypy
pmeier Nov 18, 2021
1d55fc0
fix variable names
pmeier Nov 18, 2021
5ebb5ae
hardcode flow sizes in test
pmeier Nov 18, 2021
ac3e4c2
add fix dtype docstring
pmeier Nov 18, 2021
507681a
expand comment on different reading modes
pmeier Nov 18, 2021
c52c547
add comment about files in update mode
pmeier Nov 18, 2021
80e8f25
add tests for fromfile
pmeier Nov 18, 2021
1979d17
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 18, 2021
2bb491b
cleanup
pmeier Nov 19, 2021
388ccb1
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 19, 2021
0969cf9
cleanup
pmeier Nov 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 10 additions & 33 deletions torchvision/prototype/datasets/_builtin/mnist.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import abc
import codecs
import functools
import io
import operator
import pathlib
import string
import sys
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast

import torch
Expand All @@ -30,14 +27,13 @@
image_buffer_from_array,
Decompressor,
INFINITE_BUFFER_SIZE,
NumericBinaryReader,
)
from torchvision.prototype.features import Image, Label


__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]

prod = functools.partial(functools.reduce, operator.mul)


class MNISTFileReader(IterDataPipe[torch.Tensor]):
_DTYPE_MAP = {
Expand All @@ -56,44 +52,25 @@ def __init__(
self.start = start
self.stop = stop

@staticmethod
def _decode(input: bytes) -> int:
return int(codecs.encode(input, "hex"), 16)

@staticmethod
def _to_tensor(chunk: bytes, *, dtype: torch.dtype, shape: List[int], reverse_bytes: bool) -> torch.Tensor:
# As is, the chunk is not writeable, because it is read from a file and not from memory. Thus, we copy here to
# avoid the warning that torch.frombuffer would emit otherwise. This also enables inplace operations on the
# contents, which would otherwise fail.
chunk = bytearray(chunk)
if reverse_bytes:
chunk.reverse()
tensor = torch.frombuffer(chunk, dtype=dtype).flip(0)
else:
tensor = torch.frombuffer(chunk, dtype=dtype)
return tensor.reshape(shape)

def __iter__(self) -> Iterator[torch.Tensor]:
for _, file in self.datapipe:
magic = self._decode(file.read(4))
reader = NumericBinaryReader(file, byte_order="big")

magic = int(reader.read(torch.int32))
dtype = self._DTYPE_MAP[magic // 256]
ndim = magic % 256 - 1

num_samples = self._decode(file.read(4))
shape = [self._decode(file.read(4)) for _ in range(ndim)]

num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
reverse_bytes = sys.byteorder == "little" and num_bytes_per_value > 1
chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value
num_samples = int(reader.read(torch.int32))
shape = cast(List[int], reader.read(torch.int32, shape=(ndim,)).tolist()) if ndim else []

start = self.start or 0
stop = min(self.stop, num_samples) if self.stop else num_samples

file.seek(start * chunk_size, 1)
if start:
reader.skip(dtype, shape=(start,))

for _ in range(stop - start):
yield self._to_tensor(file.read(chunk_size), dtype=dtype, shape=shape, reverse_bytes=reverse_bytes)
yield reader.read(dtype, shape=shape)


class _MNISTBase(Dataset):
Expand Down
49 changes: 49 additions & 0 deletions torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import csv
import enum
import functools
import gzip
import io
import lzma
import operator
import os
import os.path
import pathlib
import pickle
import sys
import textwrap
from typing import (
Sequence,
Expand All @@ -28,6 +31,7 @@

import numpy as np
import PIL.Image
import torch
import torch.distributed as dist
import torch.utils.data
from torch.utils.data import IterDataPipe
Expand All @@ -51,6 +55,7 @@
"path_accessor",
"path_comparator",
"Decompressor",
"read_flo",
]

K = TypeVar("K")
Expand Down Expand Up @@ -335,3 +340,47 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
# dp = dp.cycle(2)
dp = TakerDataPipe(dp, dataset_size)
return dp


prod = functools.partial(functools.reduce, operator.mul)


class NumericBinaryReader:
def __init__(self, file: IO, *, byte_order: str = sys.byteorder) -> None:
self._file = file
# torch.frombuffer interprets the bytes in the same byte order as the system. Thus, if the data is stored in
# the opposite byte order, we need to reverse the bytes before feeding them to torch.frombuffer().
self._reverse = byte_order != sys.byteorder

def _compute_params(self, dtype: torch.dtype, shape: Sequence[int]) -> Tuple[int, bool]:
num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
num_values = prod(shape) if shape else 1
chunk_size = num_bytes_per_value * num_values
reverse = num_bytes_per_value > 1 and self._reverse
return chunk_size, reverse

def read(self, dtype: torch.dtype, *, shape: Sequence[int] = ()) -> torch.Tensor:
chunk_size, reverse = self._compute_params(dtype, shape)
# As is, the chunk we read is not writeable, because it is read from a file and not from memory. Thus, we copy
# here to a bytearray in order to avoid the warning that torch.frombuffer would emit otherwise. This also
# enables inplace operations on the contents, which would otherwise fail.
chunk = bytearray(self._file.read(chunk_size))
if reverse:
chunk.reverse()
tensor = torch.frombuffer(chunk, dtype=dtype).flip(0)
else:
tensor = torch.frombuffer(chunk, dtype=dtype)
return tensor.reshape(tuple(shape))

def skip(self, dtype: torch.dtype, *, shape: Sequence[int] = ()) -> None:
chunk_size, _ = self._compute_params(dtype, shape)
self._file.seek(chunk_size, 1)


def read_flo(file: IO) -> torch.Tensor:
if file.read(4) != b"PIEH":
raise ValueError("Magic number incorrect. Invalid .flo file")

reader = NumericBinaryReader(file, byte_order="little")
width, height = reader.read(torch.int32, shape=(2,)).tolist()
return reader.read(torch.float32, shape=(height, width, 2)).permute((2, 0, 1))