 import torch
 import codecs
 import string
-import gzip
-import lzma
-from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from urllib.error import URLError
-from .utils import download_url, download_and_extract_archive, extract_archive, \
-    verify_str_arg
+from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
+import shutil


 class MNIST(VisionDataset):
@@ -81,18 +79,42 @@ def __init__(
                                     target_transform=target_transform)
         self.train = train  # training set or test set

+        if self._check_legacy_exist():
+            self.data, self.targets = self._load_legacy_data()
+            return
+
         if download:
             self.download()

         if not self._check_exists():
             raise RuntimeError('Dataset not found.' +
                                ' You can use download=True to download it')

-        if self.train:
-            data_file = self.training_file
-        else:
-            data_file = self.test_file
-        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
+        self.data, self.targets = self._load_data()
+
+    def _check_legacy_exist(self):
+        processed_folder_exists = os.path.exists(self.processed_folder)
+        if not processed_folder_exists:
+            return False
+
+        return all(
+            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
+        )
+
+    def _load_legacy_data(self):
+        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
+        # directly.
+        data_file = self.training_file if self.train else self.test_file
+        return torch.load(os.path.join(self.processed_folder, data_file))
+
+    def _load_data(self):
+        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
+        data = read_image_file(os.path.join(self.raw_folder, image_file))
+
+        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
+        targets = read_label_file(os.path.join(self.raw_folder, label_file))
+
+        return data, targets

     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """
@@ -132,19 +154,18 @@ def class_to_idx(self) -> Dict[str, int]:
         return {_class: i for i, _class in enumerate(self.classes)}

     def _check_exists(self) -> bool:
-        return (os.path.exists(os.path.join(self.processed_folder,
-                                            self.training_file)) and
-                os.path.exists(os.path.join(self.processed_folder,
-                                            self.test_file)))
+        return all(
+            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
+            for url, _ in self.resources
+        )

     def download(self) -> None:
-        """Download the MNIST data if it doesn't exist in processed_folder already."""
+        """Download the MNIST data if it doesn't exist already."""

         if self._check_exists():
             return

         os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)

         # download files
         for filename, md5 in self.resources:
@@ -168,24 +189,6 @@ def download(self) -> None:
             else:
                 raise RuntimeError("Error downloading {}".format(filename))

-        # process and save as torch files
-        print('Processing...')
-
-        training_set = (
-            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
-            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
-        )
-        test_set = (
-            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
-            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
-        )
-        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
-            torch.save(training_set, f)
-        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
-            torch.save(test_set, f)
-
-        print('Done!')
-
     def extra_repr(self) -> str:
         return "Split: {}".format("Train" if self.train is True else "Test")

@@ -298,44 +301,39 @@ def _training_file(split) -> str:
     def _test_file(split) -> str:
         return 'test_{}.pt'.format(split)

+    @property
+    def _file_prefix(self) -> str:
+        return f"emnist-{self.split}-{'train' if self.train else 'test'}"
+
+    @property
+    def images_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
+
+    @property
+    def labels_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
+
+    def _load_data(self):
+        return read_image_file(self.images_file), read_label_file(self.labels_file)
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
     def download(self) -> None:
-        """Download the EMNIST data if it doesn't exist in processed_folder already."""
-        import shutil
+        """Download the EMNIST data if it doesn't exist already."""

         if self._check_exists():
             return

         os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)

-        # download files
-        print('Downloading and extracting zip archive')
-        download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
-                                     remove_finished=True, md5=self.md5)
+        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
         gzip_folder = os.path.join(self.raw_folder, 'gzip')
         for gzip_file in os.listdir(gzip_folder):
             if gzip_file.endswith('.gz'):
-                extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)
-
-        # process and save as torch files
-        for split in self.splits:
-            print('Processing ' + split)
-            training_set = (
-                read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
-                read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
-            )
-            test_set = (
-                read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
-                read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
-            )
-            with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
-                torch.save(training_set, f)
-            with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
-                torch.save(test_set, f)
+                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
         shutil.rmtree(gzip_folder)

-        print('Done!')
-

 class QMNIST(MNIST):
     """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
@@ -404,40 +402,51 @@ def __init__(
         self.test_file = self.data_file
         super(QMNIST, self).__init__(root, train, **kwargs)

+    @property
+    def images_file(self) -> str:
+        (url, _), _ = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    @property
+    def labels_file(self) -> str:
+        _, (url, _) = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def _load_data(self):
+        data = read_sn3_pascalvincent_tensor(self.images_file)
+        assert(data.dtype == torch.uint8)
+        assert(data.ndimension() == 3)
+
+        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
+        assert(targets.ndimension() == 2)
+
+        if self.what == 'test10k':
+            data = data[0:10000, :, :].clone()
+            targets = targets[0:10000, :].clone()
+        elif self.what == 'test50k':
+            data = data[10000:, :, :].clone()
+            targets = targets[10000:, :].clone()
+
+        return data, targets
+
     def download(self) -> None:
-        """Download the QMNIST data if it doesn't exist in processed_folder already.
+        """Download the QMNIST data if it doesn't exist already.
            Note that we only download what has been asked for (argument 'what').
         """
         if self._check_exists():
             return
+
         os.makedirs(self.raw_folder, exist_ok=True)
-        os.makedirs(self.processed_folder, exist_ok=True)
         split = self.resources[self.subsets[self.what]]
-        files = []

-        # download data files if not already there
         for url, md5 in split:
             filename = url.rpartition('/')[2]
             file_path = os.path.join(self.raw_folder, filename)
             if not os.path.isfile(file_path):
-                download_url(url, root=self.raw_folder, filename=filename, md5=md5)
-            files.append(file_path)
-
-        # process and save as torch files
-        print('Processing...')
-        data = read_sn3_pascalvincent_tensor(files[0])
-        assert(data.dtype == torch.uint8)
-        assert(data.ndimension() == 3)
-        targets = read_sn3_pascalvincent_tensor(files[1]).long()
-        assert(targets.ndimension() == 2)
-        if self.what == 'test10k':
-            data = data[0:10000, :, :].clone()
-            targets = targets[0:10000, :].clone()
-        if self.what == 'test50k':
-            data = data[10000:, :, :].clone()
-            targets = targets[10000:, :].clone()
-        with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
-            torch.save((data, targets), f)
+                download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)

     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         # redefined to handle the compat flag
@@ -459,19 +468,6 @@ def get_int(b: bytes) -> int:
     return int(codecs.encode(b, 'hex'), 16)


-def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
-    """Return a file object that possibly decompresses 'path' on the fly.
-       Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
-    """
-    if not isinstance(path, torch._six.string_classes):
-        return path
-    if path.endswith('.gz'):
-        return gzip.open(path, 'rb')
-    if path.endswith('.xz'):
-        return lzma.open(path, 'rb')
-    return open(path, 'rb')
-
-
 SN3_PASCALVINCENT_TYPEMAP = {
     8: (torch.uint8, np.uint8, np.uint8),
     9: (torch.int8, np.int8, np.int8),
@@ -482,12 +478,12 @@ def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]
 }


-def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
+def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
     """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
        Argument may be a filename, compressed filename, or file object.
     """
     # read
-    with open_maybe_compressed_file(path) as f:
+    with open(path, "rb") as f:
         data = f.read()
     # parse
     magic = get_int(data[0:4])
@@ -503,16 +499,14 @@ def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) ->


 def read_label_file(path: str) -> torch.Tensor:
-    with open(path, 'rb') as f:
-        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
     assert(x.dtype == torch.uint8)
     assert(x.ndimension() == 1)
     return x.long()


 def read_image_file(path: str) -> torch.Tensor:
-    with open(path, 'rb') as f:
-        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
     assert(x.dtype == torch.uint8)
     assert(x.ndimension() == 3)
     return x
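
For context, a minimal usage sketch of the patched dataset classes (illustrative only; it assumes this module ships as torchvision.datasets and that the standard MNIST mirrors are reachable). With this change, samples are decoded on the fly from the raw IDX files under <root>/MNIST/raw/ rather than from a cached processed/*.pt binary, so nothing below depends on a "Processing..." step.

# Minimal sketch (assumption: the patched module is importable as torchvision.datasets).
from torchvision import datasets

train_set = datasets.MNIST(root="./data", train=True, download=True)
test_set = datasets.MNIST(root="./data", train=False, download=True)

img, label = train_set[0]  # PIL image (28x28) and its integer label when no transform is given
print(len(train_set), len(test_set), img.size, label)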