1010import gzip
1111import lzma
1212from typing import Any , Callable , Dict , IO , List , Optional , Tuple , Union
13+ from urllib .error import URLError
1314from .utils import download_url , download_and_extract_archive , extract_archive , \
1415 verify_str_arg
1516
@@ -31,11 +32,16 @@ class MNIST(VisionDataset):
3132 target and transforms it.
3233 """
3334
35+ mirrors = [
36+ 'http://yann.lecun.com/exdb/mnist/' ,
37+ 'https://ossci-datasets.s3.amazonaws.com/mnist/' ,
38+ ]
39+
3440 resources = [
35- ("http://yann.lecun.com/exdb/mnist/ train-images-idx3-ubyte.gz" , "f68b3c2dcbeaaa9fbdd348bbdeb94873" ),
36- ("http://yann.lecun.com/exdb/mnist/ train-labels-idx1-ubyte.gz" , "d53e105ee54ea40749a09fcbcd1e9432" ),
37- ("http://yann.lecun.com/exdb/mnist/ t10k-images-idx3-ubyte.gz" , "9fb629c4189551a2d022fa330f9573f3" ),
38- ("http://yann.lecun.com/exdb/mnist/ t10k-labels-idx1-ubyte.gz" , "ec29112dd5afa0611ce80d1b7f02629c" )
41+ ("train-images-idx3-ubyte.gz" , "f68b3c2dcbeaaa9fbdd348bbdeb94873" ),
42+ ("train-labels-idx1-ubyte.gz" , "d53e105ee54ea40749a09fcbcd1e9432" ),
43+ ("t10k-images-idx3-ubyte.gz" , "9fb629c4189551a2d022fa330f9573f3" ),
44+ ("t10k-labels-idx1-ubyte.gz" , "ec29112dd5afa0611ce80d1b7f02629c" )
3945 ]
4046
4147 training_file = 'training.pt'
@@ -141,9 +147,26 @@ def download(self) -> None:
141147 os .makedirs (self .processed_folder , exist_ok = True )
142148
143149 # download files
144- for url , md5 in self .resources :
145- filename = url .rpartition ('/' )[2 ]
146- download_and_extract_archive (url , download_root = self .raw_folder , filename = filename , md5 = md5 )
150+ for filename , md5 in self .resources :
151+ for mirror in self .mirrors :
152+ url = "{}{}" .format (mirror , filename )
153+ try :
154+ print ("Downloading {}" .format (url ))
155+ download_and_extract_archive (
156+ url , download_root = self .raw_folder ,
157+ filename = filename ,
158+ md5 = md5
159+ )
160+ except URLError as error :
161+ print (
162+ "Failed to download (trying next):\n {}" .format (error )
163+ )
164+ continue
165+ finally :
166+ print ()
167+ break
168+ else :
169+ raise RuntimeError ("Error downloading {}" .format (filename ))
147170
148171 # process and save as torch files
149172 print ('Processing...' )
0 commit comments