Changes from all commits
64 commits
ca39eb5
Add typing for apply_func
stancld Apr 26, 2021
4146be6
Add typing for argparse
stancld Apr 26, 2021
5582fad
Add typing for cli
stancld Apr 26, 2021
1282cd2
Merge remote-tracking branch 'upstream/master' into typing_for_pl-uti…
stancld Apr 26, 2021
5a432dd
Add typing for cloud_io
stancld Apr 26, 2021
71da685
Add typing for data and debugging
stancld Apr 26, 2021
31a8315
Add typing for device_dtype_mixin.py
stancld Apr 26, 2021
cf1240c
Add typing for device_parser
stancld Apr 26, 2021
65d3bd5
Add typing for distributed
stancld Apr 26, 2021
e67d4c7
Add typing for imports
stancld Apr 26, 2021
a854722
Add typing for memory
stancld Apr 26, 2021
ff7c24a
Add typing for metrics
stancld Apr 26, 2021
bbb01c4
Add typing for model_helpers
stancld Apr 26, 2021
b2cdfc1
Add typing for parsing
stancld Apr 26, 2021
900d994
Add typing for parsing
stancld Apr 26, 2021
24b0125
Add typing to multiple files and fix a typo
stancld Apr 26, 2021
7ca6024
Add typing for xla_device
stancld Apr 26, 2021
eae8620
Merge remote-tracking branch 'upstream/master' into typing_for_pl-uti…
stancld Apr 26, 2021
7576c45
Add missing whitespace after ':' in parsing.py
stancld Apr 26, 2021
ab50848
Merge branch 'master' into typing_for_pl-utilities
tchaton Apr 28, 2021
2cee78f
Add some missing typing
stancld Apr 28, 2021
8ed0117
Merge master into the branch and resolve conflict
stancld Apr 28, 2021
a0de5e3
Fix remote/master to the branch and resolve a conflict
stancld May 3, 2021
702b240
Merge branch 'master' into typing_for_pl-utilities
tchaton May 4, 2021
9b42400
Do some fixes after reviews
stancld May 4, 2021
69c7995
Merge upstream/master into typing_for_pl-utilities
stancld May 4, 2021
9cd7d19
Add some missing commas
stancld May 4, 2021
923144d
Merge upstream/master branch and resolve conflict
stancld May 6, 2021
99811aa
Add import Set to device_parser.py
stancld May 6, 2021
1298ef5
Merge branch 'master' into typing_for_pl-utilities
justusschock May 7, 2021
7908c39
Remove unused import to pass PEP8
stancld May 7, 2021
198aa41
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2021
7fcbf96
Fix some issues after review
stancld May 8, 2021
fef2907
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2021
60d615d
Fix typo and make consistent typing in cli.py
stancld May 8, 2021
0854986
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2021
14b172d
* Reflect typing for recursive functions
stancld May 8, 2021
6690bb9
Capitalize the name of recursive-dict type
stancld May 8, 2021
644776a
mypy
Borda May 11, 2021
bc1f1de
Merge upstream/master to typing_for_pl-utilities
stancld May 15, 2021
e924379
Fix typing for enums.py and xla_device.py
stancld May 15, 2021
2659a92
Remove string types where not needed in 2 files
stancld May 15, 2021
0a16dae
Fix typing for utilities/argparse.py
stancld May 15, 2021
db1a7d1
Add missing typing for utilities/debugging.py
stancld May 15, 2021
1993c79
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2021
c6f2f4e
Change type of NotImplemented to Any
stancld May 15, 2021
9cb30cc
Fix mypy compatibility in a few files
stancld May 15, 2021
5630112
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2021
00a145b
[WIP] Fix mypy compatibility for parsing.py
stancld May 15, 2021
87be8a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2021
8eda497
Import Literal from typing_extensions to support python version < 3.8
stancld May 15, 2021
d198abf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2021
f1e0e7e
Remove unused import and circular import
stancld May 15, 2021
94f7bf7
Fix another bunch of mypy issues
stancld May 15, 2021
3117438
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2021
1888663
typo
tchaton May 17, 2021
4ae0c14
Fix pl.LightningModule type in parsing.py
stancld May 18, 2021
4ee6a8f
Merge 'upstream/master' into typing_for_pl-utilities
stancld May 18, 2021
9a71d85
Add back deleted MisconfigurationException (device)
stancld May 21, 2021
38683f3
Merge upstream/master into typing_for_pl-utilities
stancld May 25, 2021
b7f9ca7
[WIP] Tackle some other mypy issues
stancld May 25, 2021
d2e06b3
Fix some issues after a review
stancld May 27, 2021
e8db47e
Merge 'upstream/master' into typing_for_pl-utilities
stancld May 27, 2021
eb902c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2021
29 changes: 16 additions & 13 deletions pytorch_lightning/utilities/apply_func.py
@@ -16,7 +16,7 @@
 from collections.abc import Mapping, Sequence
 from copy import copy
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -33,13 +33,17 @@
     Batch = type(None)


-def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
+def to_dtype_tensor(
+    value: Union[int, float, List[Union[int, float]]],
+    dtype: Optional[torch.dtype] = None,
+    device: Union[str, torch.device] = None,
+) -> torch.Tensor:
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
     return torch.tensor(value, dtype=dtype, device=device)


-def from_numpy(value, device: torch.device = None):
+def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor:
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
     return torch.from_numpy(value).to(device)
@@ -56,11 +60,11 @@ def from_numpy(value, device: torch.device = None):

 def apply_to_collection(
     data: Any,
-    dtype: Union[type, tuple],
+    dtype: Union[type, Tuple[type]],
     function: Callable,
-    *args,
-    wrong_dtype: Optional[Union[type, tuple]] = None,
-    **kwargs
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
+    **kwargs: Any,
 ) -> Any:
     """
     Recursively applies a function to all elements of a certain dtype.
@@ -123,14 +127,14 @@ class TransferableDataType(ABC):
     """

     @classmethod
-    def __subclasshook__(cls, subclass):
+    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
         if cls is TransferableDataType:
             to = getattr(subclass, "to", None)
             return callable(to)
         return NotImplemented


-def move_data_to_device(batch: Any, device: torch.device):
+def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
     """
     Transfers a collection of data to the given device. Any object that defines a method
     ``to(device)`` will be moved and all other objects in the collection will be left untouched.
@@ -148,7 +152,7 @@ def move_data_to_device(batch: Any, device: torch.device):
         - :class:`torch.device`
     """

-    def batch_to(data):
+    def batch_to(data: Any) -> Any:
         # try to move torchtext data first
         if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):

@@ -168,14 +172,13 @@ def batch_to(data):
     return apply_to_collection(batch, dtype=dtype, function=batch_to)


-def convert_to_tensors(data, device: torch.device = None):
+def convert_to_tensors(data: Any, device: Union[str, torch.device] = None) -> Any:
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")

     for src_dtype, conversion_func in CONVERSION_DTYPES:
         data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))

-    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device):
+    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
         return t.to(device).contiguous()

     data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_make_contiguous, device=device))
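The typed helpers above are easiest to read next to a usage sketch. Here is a minimal one, assuming a recent pytorch_lightning install; the nested batch dict is hypothetical, not taken from the PR:

```python
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device

# Hypothetical nested batch used only for illustration.
batch = {"x": torch.ones(2, 3), "meta": {"lengths": torch.tensor([3, 2])}}

# Moves every tensor in the (arbitrarily nested) collection; a plain string
# is accepted thanks to the widened Union[str, torch.device] annotation.
moved = move_data_to_device(batch, "cpu")

# Recursively applies a function to every element matching the given dtype.
doubled = apply_to_collection(batch, dtype=torch.Tensor, function=lambda t: t * 2)
print(doubled["meta"]["lengths"])  # tensor([6, 4])
```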
35 changes: 18 additions & 17 deletions pytorch_lightning/utilities/argparse.py
@@ -15,12 +15,13 @@
 import os
 from argparse import _ArgumentGroup, ArgumentParser, Namespace
 from contextlib import suppress
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple, Type, Union

+import pytorch_lightning as pl
 from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str


-def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
+def from_argparse_args(cls: Type['pl.Trainer'], args: Union[Namespace, ArgumentParser], **kwargs: Any) -> 'pl.Trainer':
     """Create an instance from CLI arguments.
     Eventually use variables from OS environment which are defined as "PL_<CLASS-NAME>_<CLASS_ARGUMENT_NAME>"

@@ -52,7 +53,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
     return cls(**trainer_kwargs)


-def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+def parse_argparser(cls: Type['pl.Trainer'], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
     """Parse CLI arguments, required for custom bool types."""
     args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser

@@ -77,7 +78,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
     return Namespace(**modified_args)


-def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
+def parse_env_variables(cls: Type['pl.Trainer'], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
     """Parse environment arguments if they are defined.

     Example:
@@ -106,7 +107,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s")
     return Namespace(**env_args)


-def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
+def get_init_arguments_and_types(cls: Type['pl.Trainer']) -> List[Tuple[str, Tuple, Any]]:
     r"""Scans the class signature and returns argument names, types and default values.

     Returns:
@@ -134,7 +135,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
     return name_type_default


-def _get_abbrev_qualified_cls_name(cls):
+def _get_abbrev_qualified_cls_name(cls: Type['pl.Trainer']) -> str:
     assert isinstance(cls, type), repr(cls)
     if cls.__module__.startswith("pytorch_lightning."):
         # Abbreviate.
@@ -145,11 +146,11 @@ def _get_abbrev_qualified_cls_name(cls):


 def add_argparse_args(
-    cls,
+    cls: Type['pl.Trainer'],
     parent_parser: ArgumentParser,
     *,
-    use_argument_group=True,
-) -> ArgumentParser:
+    use_argument_group: bool = True,
+) -> Union[_ArgumentGroup, ArgumentParser]:
     r"""Extends existing argparse by default attributes for ``cls``.

     Args:
@@ -189,7 +190,7 @@ def add_argparse_args(
         raise RuntimeError("Please only pass an ArgumentParser instance.")
     if use_argument_group:
         group_name = _get_abbrev_qualified_cls_name(cls)
-        parser = parent_parser.add_argument_group(group_name)
+        parser: Union[_ArgumentGroup, ArgumentParser] = parent_parser.add_argument_group(group_name)
     else:
         parser = ArgumentParser(
             parents=[parent_parser],
@@ -212,16 +213,16 @@ def add_argparse_args(
     args_help = _parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "")

     for arg, arg_types, arg_default in args_and_types:
-        arg_types = [at for at in allowed_types if at in arg_types]
+        arg_types = tuple(at for at in allowed_types if at in arg_types)
         if not arg_types:
             # skip arguments with unsupported types
             continue
-        arg_kwargs = {}
+        arg_kwargs: Dict[str, Any] = {}
         if bool in arg_types:
             arg_kwargs.update(nargs="?", const=True)
             # if the only arg type is bool
             if len(arg_types) == 1:
-                use_type = str_to_bool
+                use_type: Callable[[str], Union[bool, int, float, str]] = str_to_bool
             elif int in arg_types:
                 use_type = str_to_bool_or_int
             elif str in arg_types:
@@ -260,7 +261,7 @@ def add_argparse_args(

 def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
     arg_block_indent = None
-    current_arg = None
+    current_arg = ''
     parsed = {}
     for line in docstring.split("\n"):
         stripped = line.lstrip()
@@ -281,21 +282,21 @@ def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
     return parsed


-def _gpus_allowed_type(x) -> Union[int, str]:
+def _gpus_allowed_type(x: str) -> Union[int, str]:
     if ',' in x:
         return str(x)
     else:
         return int(x)


-def _gpus_arg_default(x) -> Union[int, str]:  # pragma: no-cover
+def _gpus_arg_default(x: str) -> None:  # pragma: no-cover
     # unused, but here for backward compatibility with old checkpoints that need to be able to
     # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
     # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
     pass


-def _int_or_float_type(x) -> Union[int, float]:
+def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
     if '.' in str(x):
         return float(x)
     else:
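A rough sketch of the argparse round trip these signatures describe, assuming a recent Trainer; the flags shown are illustrative, not a version-stable list:

```python
from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser()
# Registers the supported Trainer.__init__ arguments (as an argument
# group by default, matching use_argument_group=True above).
parser = Trainer.add_argparse_args(parser)

# _gpus_allowed_type keeps "0,1" as a str but converts "2" to int,
# so both comma-separated device lists and plain counts parse.
args = parser.parse_args(["--max_epochs", "3", "--gpus", "0"])

trainer = Trainer.from_argparse_args(args)  # now annotated to return pl.Trainer
```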
14 changes: 7 additions & 7 deletions pytorch_lightning/utilities/cli.py
@@ -33,7 +33,7 @@
 class LightningArgumentParser(ArgumentParser):
     """Extension of jsonargparse's ArgumentParser for pytorch-lightning"""

-    def __init__(self, *args, parse_as_dict: bool = True, **kwargs) -> None:
+    def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
         """Initialize argument parser that supports configuration file input

         For full details of accepted arguments see `ArgumentParser.__init__
@@ -53,7 +53,7 @@ def add_lightning_class_args(
         self,
         lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule]],
         nested_key: str,
-        subclass_mode: bool = False
+        subclass_mode: bool = False,
     ) -> None:
         """
         Adds arguments from a lightning class to a nested key of the parser
@@ -94,17 +94,17 @@ class LightningCLI:
     def __init__(
         self,
         model_class: Type[LightningModule],
-        datamodule_class: Type[LightningDataModule] = None,
+        datamodule_class: Optional[Type[LightningDataModule]] = None,
         save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback,
         trainer_class: Type[Trainer] = Trainer,
-        trainer_defaults: Dict[str, Any] = None,
-        seed_everything_default: int = None,
+        trainer_defaults: Optional[Dict[str, Any]] = None,
+        seed_everything_default: Optional[int] = None,
         description: str = 'pytorch-lightning trainer command line tool',
         env_prefix: str = 'PL',
         env_parse: bool = False,
-        parser_kwargs: Dict[str, Any] = None,
+        parser_kwargs: Optional[Dict[str, Any]] = None,
         subclass_mode_model: bool = False,
-        subclass_mode_data: bool = False
+        subclass_mode_data: bool = False,
     ) -> None:
         """
         Receives as input pytorch-lightning classes, which are instantiated
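For orientation, a minimal sketch of how the now-Optional keyword arguments get exercised; MyModel is a hypothetical user module, not something defined in this diff:

```python
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class MyModel(LightningModule):
    """Hypothetical placeholder; a real module defines training_step, configure_optimizers, etc."""


if __name__ == "__main__":
    # Parses argv (plus optional config files), instantiates MyModel and a
    # Trainer, and calls trainer.fit; datamodule_class can be omitted now
    # that it is annotated Optional.
    LightningCLI(MyModel, seed_everything_default=42)
```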
13 changes: 8 additions & 5 deletions pytorch_lightning/utilities/cloud_io.py
@@ -14,15 +14,18 @@

 import io
 from pathlib import Path
-from typing import IO, Union
+from typing import Any, Dict, IO, Optional, Union

 import fsspec
 import torch
-from fsspec.implementations.local import LocalFileSystem
+from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem
 from packaging.version import Version


-def load(path_or_url: Union[str, IO, Path], map_location=None):
+def load(
+    path_or_url: Union[str, IO, Path],
+    map_location: Optional[Union[str, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]] = None,
+) -> Any:
     if not isinstance(path_or_url, (str, Path)):
         # any sort of BytesIO or similar
         return torch.load(path_or_url, map_location=map_location)
@@ -33,7 +36,7 @@ def load(path_or_url: Union[str, IO, Path], map_location=None):
         return torch.load(f, map_location=map_location)


-def get_filesystem(path: Union[str, Path]):
+def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem:
     path = str(path)
     if "://" in path:
         # use the filesystem from the protocol specified
@@ -43,7 +46,7 @@ def get_filesystem(path: Union[str, Path]):
         return LocalFileSystem()


-def atomic_save(checkpoint, filepath: str):
+def atomic_save(checkpoint: Any, filepath: str) -> None:
     """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.

     Args:
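A short sketch of the typed IO round trip; the local path is illustrative, and a remote URL such as s3://... would resolve to the matching fsspec filesystem instead:

```python
import torch

from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem, load

ckpt_path = "example.ckpt"  # hypothetical local path

# Avoids leaving a partially written checkpoint behind on failure.
atomic_save({"state_dict": {"w": torch.zeros(2)}}, ckpt_path)

fs = get_filesystem(ckpt_path)  # a LocalFileSystem for plain paths
assert fs.exists(ckpt_path)

state = load(ckpt_path, map_location="cpu")  # map_location is now fully typed
```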
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/data.py
@@ -19,7 +19,7 @@
 from pytorch_lightning.utilities import rank_zero_warn


-def has_iterable_dataset(dataloader: DataLoader):
+def has_iterable_dataset(dataloader: DataLoader) -> bool:
     return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset)


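And a quick check of the now bool-returning helper, with a hypothetical iterable dataset:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

from pytorch_lightning.utilities.data import has_iterable_dataset


class Stream(IterableDataset):
    """Hypothetical iterable-style dataset for illustration."""

    def __iter__(self):
        yield torch.rand(3)


print(has_iterable_dataset(DataLoader(Stream())))  # True
print(has_iterable_dataset(DataLoader(TensorDataset(torch.rand(4, 3)))))  # False
```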