From a821bd2fcc603263d846aaad8461accbd5e86527 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sat, 1 Aug 2020 12:48:12 +0200 Subject: [PATCH 1/2] add typehints for torchvision.datasets.stl10 --- torchvision/datasets/stl10.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 6bec45afe2b..29a37ae8014 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -2,6 +2,7 @@ import os import os.path import numpy as np +from typing import Any, Callable, Optional, Tuple from .vision import VisionDataset from .utils import check_integrity, download_and_extract_archive, verify_str_arg @@ -45,8 +46,17 @@ class STL10(VisionDataset): ] splits = ('train', 'train+unlabeled', 'unlabeled', 'test') - def __init__(self, root, split='train', folds=None, transform=None, - target_transform=None, download=False): + labels: np.ndarray + + def __init__( + self, + root: str, + split: str = "train", + folds: Optional[int] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: super(STL10, self).__init__(root, transform=transform, target_transform=target_transform) self.split = verify_str_arg(split, "split", self.splits) @@ -87,7 +97,7 @@ def __init__(self, root, split='train', folds=None, transform=None, with open(class_file) as f: self.classes = f.read().splitlines() - def _verify_folds(self, folds): + def _verify_folds(self, folds: Optional[int]) -> Optional[int]: if folds is None: return folds elif isinstance(folds, int): @@ -100,7 +110,7 @@ def _verify_folds(self, folds): msg = "Expected type None or int for argument folds, but got type {}." raise ValueError(msg.format(type(folds))) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index @@ -108,6 +118,7 @@ def __getitem__(self, index): Returns: tuple: (image, target) where target is index of the target class. """ + target: Optional[int] if self.labels is not None: img, target = self.data[index], int(self.labels[index]) else: @@ -125,10 +136,10 @@ def __getitem__(self, index): return img, target - def __len__(self): + def __len__(self) -> int: return self.data.shape[0] - def __loadfile(self, data_file, labels_file=None): + def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]: labels = None if labels_file: path_to_labels = os.path.join( @@ -145,7 +156,7 @@ def __loadfile(self, data_file, labels_file=None): return images, labels - def _check_integrity(self): + def _check_integrity(self) -> bool: root = self.root for fentry in (self.train_list + self.test_list): filename, md5 = fentry[0], fentry[1] @@ -154,17 +165,17 @@ def _check_integrity(self): return False return True - def download(self): + def download(self) -> None: if self._check_integrity(): print('Files already downloaded and verified') return download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) self._check_integrity() - def extra_repr(self): + def extra_repr(self) -> str: return "Split: {split}".format(**self.__dict__) - def __load_folds(self, folds): + def __load_folds(self, folds: Optional[int]) -> None: # loads one of the folds if specified if folds is None: return From 9f6f45ebead1a3ffd81ac688e9416fa8cbbfedb2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 3 Aug 2020 13:25:53 +0200 Subject: [PATCH 2/2] move annotation from class to instance scope --- torchvision/datasets/stl10.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 29a37ae8014..1d619183330 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -46,8 +46,6 @@ class STL10(VisionDataset): ] splits = ('train', 'train+unlabeled', 'unlabeled', 'test') - labels: np.ndarray - def __init__( self, root: str, @@ -70,6 +68,7 @@ def __init__( 'You can use download=True to download it') # now load the picked numpy arrays + self.labels: np.ndarray if self.split == 'train': self.data, self.labels = self.__loadfile( self.train_list[0][0], self.train_list[1][0])