From d47ede984b08b5f6bf7d5b867a92b78f9c6c6cfd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 30 Sep 2021 15:54:48 +0200 Subject: [PATCH 01/11] make mypy more strict for prototype datasets --- mypy.ini | 16 ++++++++++++++++ .../prototype/datasets/_builtin/caltech.py | 13 ++++++++----- torchvision/prototype/datasets/_folder.py | 6 +++--- torchvision/prototype/datasets/decoder.py | 3 ++- torchvision/prototype/datasets/utils/_dataset.py | 12 +++++++----- .../prototype/datasets/utils/_resource.py | 2 +- 6 files changed, 37 insertions(+), 15 deletions(-) diff --git a/mypy.ini b/mypy.ini index dac60e11ce0..2016b93fd2d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,6 +4,22 @@ files = torchvision show_error_codes = True pretty = True +[mypy-torchvision.prototype.*] + +; untyped definitions and calls +disallow_untyped_defs = True + +; None and Optional handling +no_implicit_optional = True + +; warnings +warn_unused_ignores = True +warn_return_any = True +warn_unreachable = True + +; miscellaneous strictness flags +allow_redefinition = True + [mypy-torchvision.io._video_opt.*] ignore_errors = True diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 7f6021522c8..8a608f54bd1 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -84,7 +84,10 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: return category, id def _collate_and_decode_sample( - self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] + self, + data: Tuple[Tuple[str, str], Tuple[str, io.IOBase], Tuple[str, io.IOBase]], + *, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> Dict[str, Any]: key, image_data, ann_data = data category, _ = key @@ -119,12 +122,12 @@ def _make_datapipe( images_dp, anns_dp = resource_dps images_dp = TarArchiveReader(images_dp) - images_dp = Filter(images_dp, self._is_not_background_image) + images_dp: IterDataPipe = Filter(images_dp, self._is_not_background_image) # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved # images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE) anns_dp = TarArchiveReader(anns_dp) - anns_dp = Filter(anns_dp, self._is_ann) + anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann) dp = KeyZipper( images_dp, @@ -139,7 +142,7 @@ def _make_datapipe( def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None: dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) dp = TarArchiveReader(dp) - dp = Filter(dp, self._is_not_background_image) + dp: IterDataPipe = Filter(dp, self._is_not_background_image) dir_names = {pathlib.Path(path).parent.name for path, _ in dp} create_categories_file(HERE, self.name, sorted(dir_names)) @@ -188,7 +191,7 @@ def _make_datapipe( ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = TarArchiveReader(dp) - dp = Filter(dp, self._is_not_rogue_file) + dp: IterDataPipe = Filter(dp, self._is_not_rogue_file) # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved # dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index 5626f68650f..0e8bb56fc36 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -26,7 +26,7 @@ def _collate_and_decode_data( *, root: pathlib.Path, categories: List[str], - decoder, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> Dict[str, Any]: path, buffer = data data = decoder(buffer) if decoder else buffer @@ -50,8 +50,8 @@ def from_data_folder( root = pathlib.Path(root).expanduser().resolve() categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir()) masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else "" - dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks) - dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root)) + dp = FileLister(str(root), recursive=recursive, masks=masks) + dp: IterDataPipe = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root)) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = FileLoader(dp) return ( diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py index d4897bebc91..f280d81a36f 100644 --- a/torchvision/prototype/datasets/decoder.py +++ b/torchvision/prototype/datasets/decoder.py @@ -1,4 +1,5 @@ import io +from typing import cast import PIL.Image import torch @@ -9,4 +10,4 @@ def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor: - return pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper())) + return cast(torch.Tensor, pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper()))) diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 19fb3b1d596..1937f51f9d5 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -15,6 +15,8 @@ NoReturn, Iterable, Tuple, + Iterator, + cast, ) import torch @@ -27,7 +29,7 @@ from ._resource import OnlineResource -def make_repr(name: str, items: Iterable[Tuple[str, Any]]): +def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str: def to_str(sep: str) -> str: return sep.join([f"{key}={value}" for key, value in items]) @@ -46,7 +48,7 @@ def to_str(sep: str) -> str: class DatasetConfig(Mapping): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: data = dict(*args, **kwargs) self.__dict__["__data__"] = data self.__dict__["__final_hash__"] = hash(tuple(data.items())) @@ -54,10 +56,10 @@ def __init__(self, *args, **kwargs): def __getitem__(self, name: str) -> Any: return self.__dict__["__data__"][name] - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return iter(self.__dict__["__data__"].keys()) - def __len__(self): + def __len__(self) -> int: return len(self.__dict__["__data__"]) def __getattr__(self, name: str) -> Any: @@ -81,7 +83,7 @@ def __delattr__(self, item: Any) -> NoReturn: raise RuntimeError(f"'{type(self).__name__}' object is immutable") def __hash__(self) -> int: - return self.__dict__["__final_hash__"] + return cast(int, self.__dict__["__final_hash__"]) def __eq__(self, other: Any) -> bool: if not isinstance(other, DatasetConfig): diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 3f372d0f5b7..522fd4e3c46 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -8,7 +8,7 @@ # FIXME -def compute_sha256(_) -> str: +def compute_sha256(path: pathlib.Path) -> str: return "" From be3babf3ea6be13fd0c0ff6c17038871c0635d7e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Oct 2021 08:29:46 +0200 Subject: [PATCH 02/11] fix code format --- torchvision/prototype/datasets/_builtin/caltech.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 50b0b96a88f..7c25037691f 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -121,8 +121,7 @@ def _make_datapipe( images_dp = TarArchiveReader(images_dp) images_dp: IterDataPipe = Filter(images_dp, self._is_not_background_image) - # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved - # images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE) + images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE) anns_dp = TarArchiveReader(anns_dp) anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann) @@ -190,8 +189,7 @@ def _make_datapipe( dp = resource_dps[0] dp = TarArchiveReader(dp) dp: IterDataPipe = Filter(dp, self._is_not_rogue_file) - # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved - # dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) + dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None: From b437fa31b9a3527ba6c76a6b6d09e95263801caf Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 20 Oct 2021 15:54:47 +0200 Subject: [PATCH 03/11] apply strictness only to datasets --- mypy.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index ffb113081f6..6069eece9bc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,7 +5,7 @@ show_error_codes = True pretty = True allow_redefinition = True -[mypy-torchvision.prototype.*] +[mypy-torchvision.prototype.datasets.*] ; untyped definitions and calls disallow_untyped_defs = True From 182a4ea6ffac6ac9e78d12215e67e056937ad497 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 20 Oct 2021 15:54:58 +0200 Subject: [PATCH 04/11] fix more mypy issues --- .../prototype/datasets/_builtin/celeba.py | 28 ++++++++++--------- .../prototype/datasets/_builtin/cifar.py | 10 +++---- .../prototype/datasets/_builtin/mnist.py | 20 ++++++------- .../prototype/datasets/_builtin/sbd.py | 22 +++++++++------ .../prototype/datasets/_builtin/voc.py | 10 +++++-- torchvision/prototype/datasets/benchmark.py | 2 ++ .../datasets/generate_category_files.py | 4 ++- .../prototype/datasets/utils/_internal.py | 6 ++-- 8 files changed, 59 insertions(+), 43 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index d86eaf27fab..1a604cd756e 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -1,6 +1,6 @@ import csv import io -from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union, Iterator import torch from torchdata.datapipes.iter import ( @@ -23,18 +23,20 @@ from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor +csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) + + class CelebACSVParser(IterDataPipe): def __init__( self, - datapipe, + datapipe: IterDataPipe[Tuple[str, io.IOBase]], *, - has_header, - ): + has_header: bool, + ) -> None: self.datapipe = datapipe self.has_header = has_header - self._fmtparams = dict(delimiter=" ", skipinitialspace=True) - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[str, Union[Dict[str, str], List[str]]]]: for _, file in self.datapipe: file = (line.decode() for line in file) @@ -42,18 +44,18 @@ def __iter__(self): # The first row is skipped, because it only contains the number of samples next(file) - # Empty field names are filtered out, because some files have an extr white space after the header + # Empty field names are filtered out, because some files have an extra white space after the header # line, which is recognized as extra column - fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name] + fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name] # Some files do not include a label for the image ID column if fieldnames[0] != "image_id": fieldnames.insert(0, "image_id") - for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams): - yield line.pop("image_id"), line + for line_dict in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"): + yield line_dict.pop("image_id"), line_dict else: - for line in csv.reader(file, **self._fmtparams): - yield line[0], line[1:] + for line_list in csv.reader(file, dialect="celeba"): + yield line_list[0], line_list[1:] class CelebA(Dataset): @@ -104,7 +106,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: "2": "test", } - def _filter_split(self, data: Tuple[str, str], *, split): + def _filter_split(self, data: Tuple[str, str], *, split: str) -> bool: _, split_id = data return self._SPLIT_ID_TO_NAME[split_id[0]] == split diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 4fbd993d311..4949c56f6cb 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -3,7 +3,7 @@ import io import pathlib import pickle -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast import numpy as np import torch @@ -56,7 +56,7 @@ def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) - def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]: _, file = data - return pickle.load(file, encoding="latin1") + return cast(Dict[str, Any], pickle.load(file, encoding="latin1")) def _collate_and_decode( self, @@ -98,7 +98,7 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]: dp = TarArchiveReader(dp) dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME)) dp: IterDataPipe = Mapper(dp, self._unpickle) - return next(iter(dp))[self._CATEGORIES_KEY] + return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) class Cifar10(_CifarBase): @@ -133,9 +133,9 @@ class Cifar100(_CifarBase): _META_FILE_NAME = "meta" _CATEGORIES_KEY = "fine_label_names" - def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool: + def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool: path = pathlib.Path(data[0]) - return path.name == config.split + return path.name == cast(str, config.split) @property def info(self) -> DatasetInfo: diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 2413a2fb084..9778c920f37 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -54,7 +54,7 @@ def __init__(self, datapipe: IterDataPipe, *, start: Optional[int], stop: Option self.stop = stop @staticmethod - def _decode(bytes): + def _decode(bytes: bytes) -> int: return int(codecs.encode(bytes, "hex"), 16) def __iter__(self) -> Iterator[np.ndarray]: @@ -107,7 +107,7 @@ def _collate_and_decode( *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ): + ) -> Dict[str, Any]: image_array, label_array = data image: Union[torch.Tensor, io.BytesIO] @@ -145,7 +145,7 @@ def _make_datapipe( class MNIST(_MNISTBase): @property - def info(self): + def info(self) -> DatasetInfo: return DatasetInfo( "mnist", type=DatasetType.RAW, @@ -176,7 +176,7 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], class FashionMNIST(MNIST): @property - def info(self): + def info(self) -> DatasetInfo: return DatasetInfo( "fashionmnist", type=DatasetType.RAW, @@ -209,7 +209,7 @@ def info(self): class KMNIST(MNIST): @property - def info(self): + def info(self) -> DatasetInfo: return DatasetInfo( "kmnist", type=DatasetType.RAW, @@ -231,7 +231,7 @@ def info(self): class EMNIST(_MNISTBase): @property - def info(self): + def info(self) -> DatasetInfo: return DatasetInfo( "emnist", type=DatasetType.RAW, @@ -295,7 +295,7 @@ def _collate_and_decode( *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ): + ) -> Dict[str, Any]: image_array, label_array = data # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, @@ -321,7 +321,7 @@ def _make_datapipe( images_dp, labels_dp = Demultiplexer( archive_dp, 2, - functools.partial(self._classify_archive, config=config), # type:ignore[arg-type] + functools.partial(self._classify_archive, config=config), drop_none=True, buffer_size=INFINITE_BUFFER_SIZE, ) @@ -330,7 +330,7 @@ def _make_datapipe( class QMNIST(_MNISTBase): @property - def info(self): + def info(self) -> DatasetInfo: return DatasetInfo( "qmnist", type=DatasetType.RAW, @@ -381,7 +381,7 @@ def _collate_and_decode( *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ): + ) -> Dict[str, Any]: image_array, label_array = data label_parts = label_array.tolist() sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder) diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index c0244aa534a..4654a20c61f 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -1,7 +1,7 @@ import io import pathlib import re -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, cast import numpy as np import torch @@ -135,7 +135,7 @@ def _make_datapipe( split_dp, images_dp, anns_dp = Demultiplexer( archive_dp, 3, - self._classify_archive, # type: ignore[arg-type] + self._classify_archive, buffer_size=INFINITE_BUFFER_SIZE, drop_none=True, ) @@ -165,9 +165,15 @@ def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: lines = tuple(zip(*iter(dp)))[1] pattern = re.compile(r"\s*'(?P\w+)';\s*%(?P