Skip to content

Commit 6ea01db

Browse files
committed
Migrate mnist dataset from np.frombuffer
1 parent c790216 commit 6ea01db

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

torchvision/datasets/mnist.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -489,12 +489,12 @@ def get_int(b: bytes) -> int:
489489

490490

491491
SN3_PASCALVINCENT_TYPEMAP = {
492-
8: (torch.uint8, np.uint8, np.uint8),
493-
9: (torch.int8, np.int8, np.int8),
494-
11: (torch.int16, np.dtype(">i2"), "i2"),
495-
12: (torch.int32, np.dtype(">i4"), "i4"),
496-
13: (torch.float32, np.dtype(">f4"), "f4"),
497-
14: (torch.float64, np.dtype(">f8"), "f8"),
492+
8: torch.uint8,
493+
9: torch.int8,
494+
11: torch.int16,
495+
12: torch.int32,
496+
13: torch.float32,
497+
14: torch.float64,
498498
}
499499

500500

@@ -511,11 +511,11 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
511511
ty = magic // 256
512512
assert 1 <= nd <= 3
513513
assert 8 <= ty <= 14
514-
m = SN3_PASCALVINCENT_TYPEMAP[ty]
514+
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
515515
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
516-
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
516+
parsed = torch.frombuffer(data, dtype=torch_type, offset=(4 * (nd + 1)))
517517
assert parsed.shape[0] == np.prod(s) or not strict
518-
return torch.from_numpy(parsed.astype(m[2])).view(*s)
518+
return parsed.view(*s)
519519

520520

521521
def read_label_file(path: str) -> torch.Tensor:

0 commit comments

Comments
 (0)