22import os
33import os .path
44import numpy as np
5+ from typing import Any , Callable , Optional , Tuple
56
67from .vision import VisionDataset
78from .utils import check_integrity , download_and_extract_archive , verify_str_arg
@@ -45,8 +46,15 @@ class STL10(VisionDataset):
4546 ]
4647 splits = ('train' , 'train+unlabeled' , 'unlabeled' , 'test' )
4748
48- def __init__ (self , root , split = 'train' , folds = None , transform = None ,
49- target_transform = None , download = False ):
49+ def __init__ (
50+ self ,
51+ root : str ,
52+ split : str = "train" ,
53+ folds : Optional [int ] = None ,
54+ transform : Optional [Callable ] = None ,
55+ target_transform : Optional [Callable ] = None ,
56+ download : bool = False ,
57+ ) -> None :
5058 super (STL10 , self ).__init__ (root , transform = transform ,
5159 target_transform = target_transform )
5260 self .split = verify_str_arg (split , "split" , self .splits )
@@ -60,6 +68,7 @@ def __init__(self, root, split='train', folds=None, transform=None,
6068 'You can use download=True to download it' )
6169
6270 # now load the picked numpy arrays
71+ self .labels : np .ndarray
6372 if self .split == 'train' :
6473 self .data , self .labels = self .__loadfile (
6574 self .train_list [0 ][0 ], self .train_list [1 ][0 ])
@@ -87,7 +96,7 @@ def __init__(self, root, split='train', folds=None, transform=None,
8796 with open (class_file ) as f :
8897 self .classes = f .read ().splitlines ()
8998
90- def _verify_folds (self , folds ) :
99+ def _verify_folds (self , folds : Optional [ int ]) -> Optional [ int ] :
91100 if folds is None :
92101 return folds
93102 elif isinstance (folds , int ):
@@ -100,14 +109,15 @@ def _verify_folds(self, folds):
100109 msg = "Expected type None or int for argument folds, but got type {}."
101110 raise ValueError (msg .format (type (folds )))
102111
103- def __getitem__ (self , index ) :
112+ def __getitem__ (self , index : int ) -> Tuple [ Any , Any ] :
104113 """
105114 Args:
106115 index (int): Index
107116
108117 Returns:
109118 tuple: (image, target) where target is index of the target class.
110119 """
120+ target : Optional [int ]
111121 if self .labels is not None :
112122 img , target = self .data [index ], int (self .labels [index ])
113123 else :
@@ -125,10 +135,10 @@ def __getitem__(self, index):
125135
126136 return img , target
127137
128- def __len__ (self ):
138+ def __len__ (self ) -> int :
129139 return self .data .shape [0 ]
130140
131- def __loadfile (self , data_file , labels_file = None ):
141+ def __loadfile (self , data_file : str , labels_file : Optional [ str ] = None ) -> Tuple [ np . ndarray , Optional [ np . ndarray ]] :
132142 labels = None
133143 if labels_file :
134144 path_to_labels = os .path .join (
@@ -145,7 +155,7 @@ def __loadfile(self, data_file, labels_file=None):
145155
146156 return images , labels
147157
148- def _check_integrity (self ):
158+ def _check_integrity (self ) -> bool :
149159 root = self .root
150160 for fentry in (self .train_list + self .test_list ):
151161 filename , md5 = fentry [0 ], fentry [1 ]
@@ -154,17 +164,17 @@ def _check_integrity(self):
154164 return False
155165 return True
156166
157- def download (self ):
167+ def download (self ) -> None :
158168 if self ._check_integrity ():
159169 print ('Files already downloaded and verified' )
160170 return
161171 download_and_extract_archive (self .url , self .root , filename = self .filename , md5 = self .tgz_md5 )
162172 self ._check_integrity ()
163173
164- def extra_repr (self ):
174+ def extra_repr (self ) -> str :
165175 return "Split: {split}" .format (** self .__dict__ )
166176
167- def __load_folds (self , folds ) :
177+ def __load_folds (self , folds : Optional [ int ]) -> None :
168178 # loads one of the folds if specified
169179 if folds is None :
170180 return
0 commit comments