From 97f3344164f10ee8e09319751856d120b02f14af Mon Sep 17 00:00:00 2001 From: Sergii Khomenko Date: Tue, 12 Oct 2021 11:18:01 +0100 Subject: [PATCH 1/3] Migrate mnist dataset from np.frombuffer --- torchvision/datasets/mnist.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index cedc3ca5e77..e9697497109 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -489,12 +489,12 @@ def get_int(b: bytes) -> int: SN3_PASCALVINCENT_TYPEMAP = { - 8: (torch.uint8, np.uint8, np.uint8), - 9: (torch.int8, np.int8, np.int8), - 11: (torch.int16, np.dtype(">i2"), "i2"), - 12: (torch.int32, np.dtype(">i4"), "i4"), - 13: (torch.float32, np.dtype(">f4"), "f4"), - 14: (torch.float64, np.dtype(">f8"), "f8"), + 8: torch.uint8, + 9: torch.int8, + 11: torch.int16, + 12: torch.int32, + 13: torch.float32, + 14: torch.float64, } @@ -511,11 +511,11 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso ty = magic // 256 assert 1 <= nd <= 3 assert 8 <= ty <= 14 - m = SN3_PASCALVINCENT_TYPEMAP[ty] + torch_type = SN3_PASCALVINCENT_TYPEMAP[ty] s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] - parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) + parsed = torch.frombuffer(data, dtype=torch_type, offset=(4 * (nd + 1))) assert parsed.shape[0] == np.prod(s) or not strict - return torch.from_numpy(parsed.astype(m[2])).view(*s) + return parsed.view(*s) def read_label_file(path: str) -> torch.Tensor: From 02cb43a030e69b4e73e83090a919ef1f77ff46f2 Mon Sep 17 00:00:00 2001 From: Sergii Khomenko Date: Wed, 20 Oct 2021 22:36:02 +0100 Subject: [PATCH 2/3] Add a copy with bytearray for non-writable buffers --- torchvision/datasets/mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index e9697497109..6faae6b9393 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -513,7 +513,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso assert 8 <= ty <= 14 torch_type = SN3_PASCALVINCENT_TYPEMAP[ty] s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] - parsed = torch.frombuffer(data, dtype=torch_type, offset=(4 * (nd + 1))) + parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1))) assert parsed.shape[0] == np.prod(s) or not strict return parsed.view(*s) From 38f189ae6f42cdb2defe70a312c9d648849eda69 Mon Sep 17 00:00:00 2001 From: Sergii Khomenko Date: Thu, 21 Oct 2021 13:28:30 +0100 Subject: [PATCH 3/3] Add byte reversal for mnist --- torchvision/datasets/mnist.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 6faae6b9393..3164b9857e2 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -3,6 +3,7 @@ import os.path import shutil import string +import sys import warnings from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.error import URLError @@ -513,7 +514,15 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso assert 8 <= ty <= 14 torch_type = SN3_PASCALVINCENT_TYPEMAP[ty] s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] + + num_bytes_per_value = torch.iinfo(torch_type).bits // 8 + # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, + # we need to reverse the bytes before we can read them with torch.frombuffer(). + needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1))) + if needs_byte_reversal: + parsed = parsed.flip(0) + assert parsed.shape[0] == np.prod(s) or not strict return parsed.view(*s)