Skip to content

Commit 3a159df

Browse files
authored
add typehints for torchvision.datasets.stl10 (#2540)
* add typehints for torchvision.datasets.stl10 * move annotation from class to instance scope
1 parent 7fc47ea commit 3a159df

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

torchvision/datasets/stl10.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import os.path
44
import numpy as np
5+
from typing import Any, Callable, Optional, Tuple
56

67
from .vision import VisionDataset
78
from .utils import check_integrity, download_and_extract_archive, verify_str_arg
@@ -45,8 +46,15 @@ class STL10(VisionDataset):
4546
]
4647
splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
4748

48-
def __init__(self, root, split='train', folds=None, transform=None,
49-
target_transform=None, download=False):
49+
def __init__(
50+
self,
51+
root: str,
52+
split: str = "train",
53+
folds: Optional[int] = None,
54+
transform: Optional[Callable] = None,
55+
target_transform: Optional[Callable] = None,
56+
download: bool = False,
57+
) -> None:
5058
super(STL10, self).__init__(root, transform=transform,
5159
target_transform=target_transform)
5260
self.split = verify_str_arg(split, "split", self.splits)
@@ -60,6 +68,7 @@ def __init__(self, root, split='train', folds=None, transform=None,
6068
'You can use download=True to download it')
6169

6270
# now load the picked numpy arrays
71+
self.labels: np.ndarray
6372
if self.split == 'train':
6473
self.data, self.labels = self.__loadfile(
6574
self.train_list[0][0], self.train_list[1][0])
@@ -87,7 +96,7 @@ def __init__(self, root, split='train', folds=None, transform=None,
8796
with open(class_file) as f:
8897
self.classes = f.read().splitlines()
8998

90-
def _verify_folds(self, folds):
99+
def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
91100
if folds is None:
92101
return folds
93102
elif isinstance(folds, int):
@@ -100,14 +109,15 @@ def _verify_folds(self, folds):
100109
msg = "Expected type None or int for argument folds, but got type {}."
101110
raise ValueError(msg.format(type(folds)))
102111

103-
def __getitem__(self, index):
112+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
104113
"""
105114
Args:
106115
index (int): Index
107116
108117
Returns:
109118
tuple: (image, target) where target is index of the target class.
110119
"""
120+
target: Optional[int]
111121
if self.labels is not None:
112122
img, target = self.data[index], int(self.labels[index])
113123
else:
@@ -125,10 +135,10 @@ def __getitem__(self, index):
125135

126136
return img, target
127137

128-
def __len__(self):
138+
def __len__(self) -> int:
129139
return self.data.shape[0]
130140

131-
def __loadfile(self, data_file, labels_file=None):
141+
def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
132142
labels = None
133143
if labels_file:
134144
path_to_labels = os.path.join(
@@ -145,7 +155,7 @@ def __loadfile(self, data_file, labels_file=None):
145155

146156
return images, labels
147157

148-
def _check_integrity(self):
158+
def _check_integrity(self) -> bool:
149159
root = self.root
150160
for fentry in (self.train_list + self.test_list):
151161
filename, md5 = fentry[0], fentry[1]
@@ -154,17 +164,17 @@ def _check_integrity(self):
154164
return False
155165
return True
156166

157-
def download(self):
167+
def download(self) -> None:
158168
if self._check_integrity():
159169
print('Files already downloaded and verified')
160170
return
161171
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
162172
self._check_integrity()
163173

164-
def extra_repr(self):
174+
def extra_repr(self) -> str:
165175
return "Split: {split}".format(**self.__dict__)
166176

167-
def __load_folds(self, folds):
177+
def __load_folds(self, folds: Optional[int]) -> None:
168178
# loads one of the folds if specified
169179
if folds is None:
170180
return

0 commit comments

Comments
 (0)