Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions torchvision/datasets/stl10.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple

from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive, verify_str_arg
Expand Down Expand Up @@ -45,8 +46,15 @@ class STL10(VisionDataset):
]
splits = ('train', 'train+unlabeled', 'unlabeled', 'test')

def __init__(self, root, split='train', folds=None, transform=None,
target_transform=None, download=False):
def __init__(
self,
root: str,
split: str = "train",
folds: Optional[int] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(STL10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = verify_str_arg(split, "split", self.splits)
Expand All @@ -60,6 +68,7 @@ def __init__(self, root, split='train', folds=None, transform=None,
'You can use download=True to download it')

# now load the picked numpy arrays
self.labels: np.ndarray
if self.split == 'train':
self.data, self.labels = self.__loadfile(
self.train_list[0][0], self.train_list[1][0])
Expand Down Expand Up @@ -87,7 +96,7 @@ def __init__(self, root, split='train', folds=None, transform=None,
with open(class_file) as f:
self.classes = f.read().splitlines()

def _verify_folds(self, folds):
def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
if folds is None:
return folds
elif isinstance(folds, int):
Expand All @@ -100,14 +109,15 @@ def _verify_folds(self, folds):
msg = "Expected type None or int for argument folds, but got type {}."
raise ValueError(msg.format(type(folds)))

def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target class.
"""
target: Optional[int]
if self.labels is not None:
img, target = self.data[index], int(self.labels[index])
else:
Expand All @@ -125,10 +135,10 @@ def __getitem__(self, index):

return img, target

def __len__(self):
def __len__(self) -> int:
return self.data.shape[0]

def __loadfile(self, data_file, labels_file=None):
def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
labels = None
if labels_file:
path_to_labels = os.path.join(
Expand All @@ -145,7 +155,7 @@ def __loadfile(self, data_file, labels_file=None):

return images, labels

def _check_integrity(self):
def _check_integrity(self) -> bool:
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
Expand All @@ -154,17 +164,17 @@ def _check_integrity(self):
return False
return True

def download(self):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
self._check_integrity()

def extra_repr(self):
def extra_repr(self) -> str:
return "Split: {split}".format(**self.__dict__)

def __load_folds(self, folds):
def __load_folds(self, folds: Optional[int]) -> None:
# loads one of the folds if specified
if folds is None:
return
Expand Down