1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import logging
1415import os
1516import platform
16- from typing import Optional
17+ import random
18+ import time
19+ import urllib
20+ from typing import Optional , Tuple
1721from urllib .error import HTTPError
1822from warnings import warn
1923
20- from torch .utils .data import DataLoader , random_split
24+ import torch
25+ from torch .utils .data import DataLoader , Dataset , random_split
2126
2227from pl_examples import _DATASETS_PATH
2328from pytorch_lightning import LightningDataModule
2732 from torchvision import transforms as transform_lib
2833
2934
35+ class _MNIST (Dataset ):
36+ """Carbon copy of ``tests.helpers.datasets.MNIST``.
37+
38+ We cannot import the tests as they are not distributed with the package.
39+ See https://github.com/PyTorchLightning/pytorch-lightning/pull/7614#discussion_r671183652 for more context.
40+ """
41+
42+ RESOURCES = (
43+ "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt" ,
44+ "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt" ,
45+ )
46+
47+ TRAIN_FILE_NAME = "training.pt"
48+ TEST_FILE_NAME = "test.pt"
49+ cache_folder_name = "complete"
50+
51+ def __init__ (
52+ self , root : str , train : bool = True , normalize : tuple = (0.1307 , 0.3081 ), download : bool = True , ** kwargs
53+ ):
54+ super ().__init__ ()
55+ self .root = root
56+ self .train = train # training set or test set
57+ self .normalize = normalize
58+
59+ self .prepare_data (download )
60+
61+ data_file = self .TRAIN_FILE_NAME if self .train else self .TEST_FILE_NAME
62+ self .data , self .targets = self ._try_load (os .path .join (self .cached_folder_path , data_file ))
63+
64+ def __getitem__ (self , idx : int ) -> Tuple [torch .Tensor , int ]:
65+ img = self .data [idx ].float ().unsqueeze (0 )
66+ target = int (self .targets [idx ])
67+
68+ if self .normalize is not None and len (self .normalize ) == 2 :
69+ img = self .normalize_tensor (img , * self .normalize )
70+
71+ return img , target
72+
73+ def __len__ (self ) -> int :
74+ return len (self .data )
75+
76+ @property
77+ def cached_folder_path (self ) -> str :
78+ return os .path .join (self .root , "MNIST" , self .cache_folder_name )
79+
80+ def _check_exists (self , data_folder : str ) -> bool :
81+ existing = True
82+ for fname in (self .TRAIN_FILE_NAME , self .TEST_FILE_NAME ):
83+ existing = existing and os .path .isfile (os .path .join (data_folder , fname ))
84+ return existing
85+
86+ def prepare_data (self , download : bool = True ):
87+ if download and not self ._check_exists (self .cached_folder_path ):
88+ self ._download (self .cached_folder_path )
89+ if not self ._check_exists (self .cached_folder_path ):
90+ raise RuntimeError ("Dataset not found." )
91+
92+ def _download (self , data_folder : str ) -> None :
93+ os .makedirs (data_folder , exist_ok = True )
94+ for url in self .RESOURCES :
95+ logging .info (f"Downloading { url } " )
96+ fpath = os .path .join (data_folder , os .path .basename (url ))
97+ urllib .request .urlretrieve (url , fpath )
98+
99+ @staticmethod
100+ def _try_load (path_data , trials : int = 30 , delta : float = 1.0 ):
101+ """Resolving loading from the same time from multiple concurrent processes."""
102+ res , exception = None , None
103+ assert trials , "at least some trial has to be set"
104+ assert os .path .isfile (path_data ), f"missing file: { path_data } "
105+ for _ in range (trials ):
106+ try :
107+ res = torch .load (path_data )
108+ # todo: specify the possible exception
109+ except Exception as e :
110+ exception = e
111+ time .sleep (delta * random .random ())
112+ else :
113+ break
114+ if exception is not None :
115+ # raise the caught exception
116+ raise exception
117+ return res
118+
119+ @staticmethod
120+ def normalize_tensor (tensor : torch .Tensor , mean : float = 0.0 , std : float = 1.0 ) -> torch .Tensor :
121+ mean = torch .as_tensor (mean , dtype = tensor .dtype , device = tensor .device )
122+ std = torch .as_tensor (std , dtype = tensor .dtype , device = tensor .device )
123+ return tensor .sub (mean ).div (std )
124+
125+
30126def MNIST (* args , ** kwargs ):
31127 torchvision_mnist_available = not bool (os .getenv ("PL_USE_MOCKED_MNIST" , False ))
32128 if torchvision_mnist_available :
@@ -39,7 +135,7 @@ def MNIST(*args, **kwargs):
39135 torchvision_mnist_available = False
40136 if not torchvision_mnist_available :
41137 print ("`torchvision.datasets.MNIST` not available. Using our hosted version" )
42- from tests . helpers . datasets import MNIST
138+ MNIST = _MNIST
43139 return MNIST (* args , ** kwargs )
44140
45141
0 commit comments