@@ -37,27 +37,28 @@ def __init__(
3737
3838 def __iter__ (self ) -> Iterator [torch .Tensor ]:
3939 for _ , file in self .datapipe :
40- read = functools .partial (fromfile , file , byte_order = "big" )
40+ try :
41+ read = functools .partial (fromfile , file , byte_order = "big" )
4142
42- magic = int (read (dtype = torch .int32 , count = 1 ))
43- dtype = self ._DTYPE_MAP [magic // 256 ]
44- ndim = magic % 256 - 1
43+ magic = int (read (dtype = torch .int32 , count = 1 ))
44+ dtype = self ._DTYPE_MAP [magic // 256 ]
45+ ndim = magic % 256 - 1
4546
46- num_samples = int (read (dtype = torch .int32 , count = 1 ))
47- shape = cast (List [int ], read (dtype = torch .int32 , count = ndim ).tolist ()) if ndim else []
48- count = prod (shape ) if shape else 1
47+ num_samples = int (read (dtype = torch .int32 , count = 1 ))
48+ shape = cast (List [int ], read (dtype = torch .int32 , count = ndim ).tolist ()) if ndim else []
49+ count = prod (shape ) if shape else 1
4950
50- start = self .start or 0
51- stop = min (self .stop , num_samples ) if self .stop else num_samples
51+ start = self .start or 0
52+ stop = min (self .stop , num_samples ) if self .stop else num_samples
5253
53- if start :
54- num_bytes_per_value = (torch .finfo if dtype .is_floating_point else torch .iinfo )(dtype ).bits // 8
55- file .seek (num_bytes_per_value * count * start , 1 )
54+ if start :
55+ num_bytes_per_value = (torch .finfo if dtype .is_floating_point else torch .iinfo )(dtype ).bits // 8
56+ file .seek (num_bytes_per_value * count * start , 1 )
5657
57- for _ in range (stop - start ):
58- yield read (dtype = dtype , count = count ).reshape (shape )
59-
60- file .close ()
58+ for _ in range (stop - start ):
59+ yield read (dtype = dtype , count = count ).reshape (shape )
60+ finally :
61+ file .close ()
6162
6263
6364class _MNISTBase (Dataset ):
0 commit comments