2222from torch import Tensor
2323from torch .utils .data import Dataset
2424
25- from tests import _PROJECT_ROOT
26-
27- #: local path to test datasets
28- PATH_DATASETS = os .path .join (_PROJECT_ROOT , 'Datasets' )
29-
3025
3126class MNIST (Dataset ):
3227 """
@@ -47,7 +42,7 @@ class MNIST(Dataset):
4742 downloaded again.
4843
4944 Examples:
50- >>> dataset = MNIST(download=True)
45+ >>> dataset = MNIST(".", download=True)
5146 >>> len(dataset)
5247 60000
5348 >>> torch.bincount(dataset.targets)
@@ -65,7 +60,7 @@ class MNIST(Dataset):
6560
6661 def __init__ (
6762 self ,
68- root : str = PATH_DATASETS ,
63+ root : str ,
6964 train : bool = True ,
7065 normalize : tuple = (0.1307 , 0.3081 ),
7166 download : bool = True ,
@@ -152,7 +147,7 @@ class TrialMNIST(MNIST):
152147 kwargs: Same as MNIST
153148
154149 Examples:
155- >>> dataset = TrialMNIST(download=True)
150+ >>> dataset = TrialMNIST(".", download=True)
156151 >>> len(dataset)
157152 300
158153 >>> sorted(set([d.item() for d in dataset.targets]))
@@ -161,15 +156,15 @@ class TrialMNIST(MNIST):
161156 tensor([100, 100, 100])
162157 """
163158
164- def __init__ (self , num_samples : int = 100 , digits : Optional [Sequence ] = (0 , 1 , 2 ), ** kwargs ):
159+ def __init__ (self , root : str , num_samples : int = 100 , digits : Optional [Sequence ] = (0 , 1 , 2 ), ** kwargs ):
165160 # number of examples per class
166161 self .num_samples = num_samples
167162 # take just a subset of MNIST dataset
168163 self .digits = sorted (digits ) if digits else list (range (10 ))
169164
170165 self .cache_folder_name = f"digits-{ '-' .join (str (d ) for d in self .digits )} _nb-{ self .num_samples } "
171166
172- super ().__init__ (normalize = (0.5 , 1.0 ), ** kwargs )
167+ super ().__init__ (root , normalize = (0.5 , 1.0 ), ** kwargs )
173168
174169 @staticmethod
175170 def _prepare_subset (full_data : torch .Tensor , full_targets : torch .Tensor , num_samples : int , digits : Sequence ):
0 commit comments