@@ -382,40 +382,51 @@ def __init__(
382382 self .test_file = self .data_file
383383 super (QMNIST , self ).__init__ (root , train , ** kwargs )
384384
385+ @property
386+ def images_file (self ) -> str :
387+ (url , _ ), _ = self .resources [self .subsets [self .what ]]
388+ return os .path .join (self .raw_folder , os .path .splitext (os .path .basename (url ))[0 ])
389+
390+ @property
391+ def labels_file (self ) -> str :
392+ _ , (url , _ ) = self .resources [self .subsets [self .what ]]
393+ return os .path .join (self .raw_folder , os .path .splitext (os .path .basename (url ))[0 ])
394+
395+ def _check_exists (self ) -> bool :
396+ return all (check_integrity (file ) for file in (self .images_file , self .labels_file ))
397+
398+ def _load_data (self ):
399+ data = read_sn3_pascalvincent_tensor (self .images_file )
400+ assert (data .dtype == torch .uint8 )
401+ assert (data .ndimension () == 3 )
402+
403+ targets = read_sn3_pascalvincent_tensor (self .labels_file ).long ()
404+ assert (targets .ndimension () == 2 )
405+
406+ if self .what == 'test10k' :
407+ data = data [0 :10000 , :, :].clone ()
408+ targets = targets [0 :10000 , :].clone ()
409+ elif self .what == 'test50k' :
410+ data = data [10000 :, :, :].clone ()
411+ targets = targets [10000 :, :].clone ()
412+
413+ return data , targets
414+
385415 def download (self ) -> None :
386- """Download the QMNIST data if it doesn't exist in processed_folder already.
416+ """Download the QMNIST data if it doesn't exist already.
387417 Note that we only download what has been asked for (argument 'what').
388418 """
389419 if self ._check_exists ():
390420 return
421+
391422 os .makedirs (self .raw_folder , exist_ok = True )
392- os .makedirs (self .processed_folder , exist_ok = True )
393423 split = self .resources [self .subsets [self .what ]]
394- files = []
395424
396- # download data files if not already there
397425 for url , md5 in split :
398426 filename = url .rpartition ('/' )[2 ]
399427 file_path = os .path .join (self .raw_folder , filename )
400428 if not os .path .isfile (file_path ):
401- download_url (url , root = self .raw_folder , filename = filename , md5 = md5 )
402- files .append (file_path )
403-
404- # process and save as torch files
405- print ('Processing...' )
406- data = read_sn3_pascalvincent_tensor (files [0 ])
407- assert (data .dtype == torch .uint8 )
408- assert (data .ndimension () == 3 )
409- targets = read_sn3_pascalvincent_tensor (files [1 ]).long ()
410- assert (targets .ndimension () == 2 )
411- if self .what == 'test10k' :
412- data = data [0 :10000 , :, :].clone ()
413- targets = targets [0 :10000 , :].clone ()
414- if self .what == 'test50k' :
415- data = data [10000 :, :, :].clone ()
416- targets = targets [10000 :, :].clone ()
417- with open (os .path .join (self .processed_folder , self .data_file ), 'wb' ) as f :
418- torch .save ((data , targets ), f )
429+ download_and_extract_archive (url , self .raw_folder , filename = filename , md5 = md5 )
419430
420431 def __getitem__ (self , index : int ) -> Tuple [Any , Any ]:
421432 # redefined to handle the compat flag
0 commit comments