diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index cedc3ca5e77..3164b9857e2 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -3,6 +3,7 @@ import os.path import shutil import string +import sys import warnings from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.error import URLError @@ -489,12 +490,12 @@ def get_int(b: bytes) -> int: SN3_PASCALVINCENT_TYPEMAP = { - 8: (torch.uint8, np.uint8, np.uint8), - 9: (torch.int8, np.int8, np.int8), - 11: (torch.int16, np.dtype(">i2"), "i2"), - 12: (torch.int32, np.dtype(">i4"), "i4"), - 13: (torch.float32, np.dtype(">f4"), "f4"), - 14: (torch.float64, np.dtype(">f8"), "f8"), + 8: torch.uint8, + 9: torch.int8, + 11: torch.int16, + 12: torch.int32, + 13: torch.float32, + 14: torch.float64, } @@ -511,11 +512,19 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso ty = magic // 256 assert 1 <= nd <= 3 assert 8 <= ty <= 14 - m = SN3_PASCALVINCENT_TYPEMAP[ty] + torch_type = SN3_PASCALVINCENT_TYPEMAP[ty] s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] - parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) + + num_bytes_per_value = torch.iinfo(torch_type).bits // 8 + # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, + # we need to reverse the bytes before we can read them with torch.frombuffer(). + needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 + parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1))) + if needs_byte_reversal: + parsed = parsed.flip(0) + assert parsed.shape[0] == np.prod(s) or not strict - return torch.from_numpy(parsed.astype(m[2])).view(*s) + return parsed.view(*s) def read_label_file(path: str) -> torch.Tensor: