From ca39eb50bd55215491291466a3558f84f180d9ed Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 13:06:36 +0200 Subject: [PATCH 01/51] Add typing for apply_func --- pytorch_lightning/utilities/apply_func.py | 30 +++++++++++++---------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index e100a803bcd00..d384c0edbd3bf 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -16,7 +16,7 @@ from collections.abc import Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -33,13 +33,17 @@ Batch = type(None) -def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None): +def to_dtype_tensor( + value: Union[int, float, List[Union[int, float]]], + dtype: Optional[torch.dtype] = None, + device: Union[str, torch.device] = None +) -> torch.Tensor: if device is None: raise MisconfigurationException("device (torch.device) should be provided.") return torch.tensor(value, dtype=dtype, device=device) -def from_numpy(value, device: torch.device = None): +def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor: if device is None: raise MisconfigurationException("device (torch.device) should be provided.") return torch.from_numpy(value).to(device) @@ -56,11 +60,11 @@ def from_numpy(value, device: torch.device = None): def apply_to_collection( data: Any, - dtype: Union[type, tuple], + dtype: Union[torch.dtype, Tuple[torch.dtype]], function: Callable, - *args, - wrong_dtype: Optional[Union[type, tuple]] = None, - **kwargs + *args: Any, + wrong_dtype: Optional[Union[torch.dtype, Tuple[torch.dtype]]] = None, + **kwargs: Any ) -> Any: """ Recursively applies a function to all elements of a certain dtype. @@ -87,7 +91,7 @@ def apply_to_collection( if isinstance(data, Mapping): return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}) - if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple + if isinstance(data, Tuple) and hasattr(data, '_fields'): # named tuple return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) if isinstance(data, Sequence) and not isinstance(data, str): @@ -116,14 +120,14 @@ class TransferableDataType(ABC): """ @classmethod - def __subclasshook__(cls, subclass): + def __subclasshook__(cls, subclass: Any) -> Union[bool, type(NotImplemented)]: if cls is TransferableDataType: to = getattr(subclass, "to", None) return callable(to) return NotImplemented -def move_data_to_device(batch: Any, device: torch.device): +def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: """ Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be moved and all other objects in the collection will be left untouched. 
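As a usage sketch (not part of the patch), this is how the `apply_to_collection` signature typed above is typically called; the nested `batch` dict and the lambda are hypothetical:

    import torch
    from pytorch_lightning.utilities.apply_func import apply_to_collection

    batch = {"pixels": torch.zeros(2, 3), "meta": {"length": torch.ones(4)}, "name": "sample"}
    # Only values that are instances of `dtype` are handed to `function`;
    # everything else (here the string) is returned unchanged.
    halved = apply_to_collection(batch, dtype=torch.Tensor, function=lambda t: t.half())

`move_data_to_device` in the following hunk recurses through the same helper, so the tightened device annotations apply to every element it visits.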
@@ -141,7 +145,7 @@ def move_data_to_device(batch: Any, device: torch.device): - :class:`torch.device` """ - def batch_to(data): + def batch_to(data: Any) -> Any: # try to move torchtext data first if _TORCHTEXT_AVAILABLE and isinstance(data, Batch): @@ -161,14 +165,14 @@ def batch_to(data): return apply_to_collection(batch, dtype=dtype, function=batch_to) -def convert_to_tensors(data, device: torch.device = None): +def convert_to_tensors(data: Any, device: Union[str, torch.device] = None) -> Any: if device is None: raise MisconfigurationException("device (torch.device) should be provided.") for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) - def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device): + def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor: return t.to(device).contiguous() data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_make_contiguous, device=device)) From 4146be6cdac5883cc4067429d6fcd3591fef7ac0 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 13:13:58 +0200 Subject: [PATCH 02/51] Add typing for argparse' --- pytorch_lightning/utilities/argparse.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index dc99b923c6702..4a7f96f4447b3 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -17,10 +17,11 @@ from contextlib import suppress from typing import Any, Dict, List, Tuple, Union +import pytorch_lightning as pl from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str -def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): +def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs: Any) -> 'pl.Trainer': """Create an instance from CLI arguments. Eventually use varibles from OS environement which are defined as "PL__" @@ -134,7 +135,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: return name_type_default -def _get_abbrev_qualified_cls_name(cls): +def _get_abbrev_qualified_cls_name(cls) -> str: assert isinstance(cls, type), repr(cls) if cls.__module__.startswith("pytorch_lightning."): # Abbreviate. @@ -148,7 +149,7 @@ def add_argparse_args( cls, parent_parser: ArgumentParser, *, - use_argument_group=True, + use_argument_group: bool = True, ) -> ArgumentParser: r"""Extends existing argparse by default attributes for ``cls``. @@ -279,21 +280,21 @@ def _parse_args_from_docstring(docstring: str) -> Dict[str, str]: return parsed -def _gpus_allowed_type(x) -> Union[int, str]: +def _gpus_allowed_type(x: str) -> Union[int, str]: if ',' in x: return str(x) else: return int(x) -def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover +def _gpus_arg_default(x: str) -> None: # pragma: no-cover # unused, but here for backward compatibility with old checkpoints that need to be able to # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8 # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 pass -def _int_or_float_type(x) -> Union[int, float]: +def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]: if '.' 
in str(x): return float(x) else: From 5582fad548794719ee4c2ef09fc976f8c7b58781 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 13:23:04 +0200 Subject: [PATCH 03/51] Add typing for cli --- pytorch_lightning/utilities/cli.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 93729a53db25d..730fe2798ef24 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -13,8 +13,9 @@ # limitations under the License. import os from argparse import Namespace -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Union +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule @@ -32,7 +33,7 @@ class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning""" - def __init__(self, *args, parse_as_dict: bool = True, **kwargs) -> None: + def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: """Initialize argument parser that supports configuration file input For full details of accepted arguments see `ArgumentParser.__init__ @@ -50,7 +51,7 @@ def __init__(self, *args, parse_as_dict: bool = True, **kwargs) -> None: def add_lightning_class_args( self, - lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule]], + lightning_class: Union['pl.Trainer', 'pl.LightningModule', 'pl.LightningDataModule'], nested_key: str, subclass_mode: bool = False ) -> None: @@ -92,16 +93,16 @@ class LightningCLI: def __init__( self, - model_class: Type[LightningModule], - datamodule_class: Type[LightningDataModule] = None, - save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback, - trainer_class: Type[Trainer] = Trainer, + model_class: 'pl.LightningModule', + datamodule_class: Optional['pl.LightningDataModule'] = None, + save_config_callback: SaveConfigCallback = SaveConfigCallback, + trainer_class: 'pl.Trainer' = Trainer, trainer_defaults: Dict[str, Any] = None, seed_everything_default: int = None, description: str = 'pytorch-lightning trainer command line tool', env_prefix: str = 'PL', env_parse: bool = False, - parser_kwargs: Dict[str, Any] = None, + parser_kwargs: Optional[Dict[str, Any]] = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False ) -> None: From 5a432ddd431d8b4d27bc5298f5d1251e4bcdb875 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 13:48:54 +0200 Subject: [PATCH 04/51] Add typing for cloud_io --- pytorch_lightning/utilities/cloud_io.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 8b65ca143976e..07d52d88d88c5 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -15,7 +15,7 @@ import io import os from pathlib import Path -from typing import IO, Union +from typing import Any, Dict, IO, Optional, Union import fsspec import torch @@ -34,7 +34,10 @@ def isdir(self, path: str) -> bool: return os.path.isdir(path) # follows symlinks -def load(path_or_url: Union[str, IO, Path], map_location=None): +def load( + path_or_url: Union[str, IO, Path], + map_location: Optional[Union[str, torch.device, Dict[Union[str, torch.device], Union[str, 
torch.device]]]] = None +) -> Any: if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similiar return torch.load(path_or_url, map_location=map_location) @@ -45,7 +48,7 @@ def load(path_or_url: Union[str, IO, Path], map_location=None): return torch.load(f, map_location=map_location) -def get_filesystem(path: Union[str, Path]): +def get_filesystem(path: Union[str, Path]) -> LocalFileSystem: path = str(path) if "://" in path: # use the fileystem from the protocol specified @@ -55,7 +58,7 @@ def get_filesystem(path: Union[str, Path]): return _LightningLocalFileSystem() -def atomic_save(checkpoint, filepath: str): +def atomic_save(checkpoint: Any, filepath: str) -> None: """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. Args: From 71da6853a47267d6cb2e294c183c4bfa340b8638 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 15:18:17 +0200 Subject: [PATCH 05/51] Add typing for data and debugging * [WIP] Need to verify if logged metrics are of a type torch.Tensor or whether Any can be passed in --- pytorch_lightning/utilities/data.py | 2 +- pytorch_lightning/utilities/debugging.py | 46 +++++++++++++++--------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 27345fda3b110..8f22d1668d831 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -19,7 +19,7 @@ from pytorch_lightning.utilities import rank_zero_warn -def has_iterable_dataset(dataloader: DataLoader): +def has_iterable_dataset(dataloader: DataLoader) -> bool: return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset) diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 56833fd03735a..58fb1aeffc2ef 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -16,10 +16,15 @@ import time from collections import Counter from functools import wraps -from typing import Any, Callable, Optional +from typing import Any, Dict, Callable, List, Optional +import torch +from torch.utils.data import DataLoader -def enabled_only(fn: Callable): +import pytorch_lightning as pl + + +def enabled_only(fn: Callable) -> Optional[Callable]: """Decorate a logger method to run it only on the process with rank 0. 
Args: @@ -27,7 +32,7 @@ def enabled_only(fn: Callable): """ @wraps(fn) - def wrapped_fn(self, *args, **kwargs): + def wrapped_fn(self, *args: Any, **kwargs: Any) -> Optional[Any]: if self.enabled: fn(self, *args, **kwargs) @@ -36,7 +41,7 @@ def wrapped_fn(self, *args, **kwargs): class InternalDebugger(object): - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer') -> None: self.enabled = os.environ.get('PL_DEV_DEBUG', '0') == '1' self.trainer = trainer self.logged_metrics = [] @@ -56,7 +61,7 @@ def __init__(self, trainer): def track_event( self, evt_type: str, - evt_value: Any = None, + evt_value: Optional[Any] = None, global_rank: Optional[int] = None, local_rank: Optional[int] = None, comment: str = '' @@ -70,7 +75,7 @@ def track_event( "comment": comment, }) - def count_events(self, evt_type: str, strict=False) -> int: + def count_events(self, evt_type: str, strict: bool = False) -> int: count = 0 for evt in self.events: if strict and evt["event"] == evt_type: @@ -80,7 +85,7 @@ def count_events(self, evt_type: str, strict=False) -> int: return count @enabled_only - def track_load_dataloader_call(self, name, dataloaders): + def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]): loader_counts = len(dataloaders) lengths = [] @@ -111,18 +116,25 @@ def track_load_dataloader_call(self, name, dataloaders): self.test_dataloader_calls.append(values) @enabled_only - def track_logged_metrics_history(self, scalar_metrics): + def track_logged_metrics_history(self, scalar_metrics: Dict[str, torch.Tensor]) -> None: scalar_metrics['global_step'] = self.trainer.global_step self.logged_metrics.append(scalar_metrics) @enabled_only - def track_train_loss_history(self, batch_idx, loss): + def track_train_loss_history(self, batch_idx: int, loss: torch.Tensor) -> None: loss_dict = {'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, 'loss': loss.detach()} self.saved_train_losses.append(loss_dict) @enabled_only def track_lr_schedulers_update( - self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None + self, + batch_idx: int, + interval: int, + scheduler_idx: int, + old_lr: float, + new_lr: float, + monitor_key: Optional[str] = None, + monitor_val: Optional[torch.Tensor] = None ): loss_dict = { 'batch_idx': batch_idx, @@ -137,7 +149,7 @@ def track_lr_schedulers_update( self.saved_lr_scheduler_updates.append(loss_dict) @enabled_only - def track_eval_loss_history(self, batch_idx, dataloader_idx, output): + def track_eval_loss_history(self, batch_idx: int, dataloader_idx: int, output: torch.Tensor): loss_dict = { 'sanity_check': self.trainer.sanity_checking, 'dataloader_idx': dataloader_idx, @@ -152,12 +164,12 @@ def track_eval_loss_history(self, batch_idx, dataloader_idx, output): self.saved_val_losses.append(loss_dict) @enabled_only - def track_pbar_metrics_history(self, metrics): + def track_pbar_metrics_history(self, metrics: Dict[str, torch.Tensor]) -> None: metrics['debug_epoch'] = self.trainer.current_epoch self.pbar_added_metrics.append(metrics) @enabled_only - def track_early_stopping_history(self, callback, current): + def track_early_stopping_history(self, callback: 'pl.Callback', current: torch.Tensor) -> None: debug_dict = { 'epoch': self.trainer.current_epoch, 'global_step': self.trainer.global_step, @@ -169,7 +181,7 @@ def track_early_stopping_history(self, callback, current): self.early_stopping_history.append(debug_dict) @enabled_only - def track_checkpointing_history(self, filepath): + def 
track_checkpointing_history(self, filepath: str) -> None: cb = self.trainer.checkpoint_callback debug_dict = { 'epoch': self.trainer.current_epoch, @@ -181,12 +193,12 @@ def track_checkpointing_history(self, filepath): self.checkpoint_callback_history.append(debug_dict) @property - def num_seen_sanity_check_batches(self): + def num_seen_sanity_check_batches(self) -> int: count = len([x for x in self.saved_val_losses if x['sanity_check']]) return count @property - def num_seen_val_check_batches(self): + def num_seen_val_check_batches(self) -> Counter: counts = Counter() for x in self.saved_val_losses: if not x['sanity_check']: @@ -194,7 +206,7 @@ def num_seen_val_check_batches(self): return counts @property - def num_seen_test_check_batches(self): + def num_seen_test_check_batches(self) -> Counter: counts = Counter() for x in self.saved_test_losses: if not x['sanity_check']: From 31a8315bad4deadadb361e1723b2e178ac421520 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:01:34 +0200 Subject: [PATCH 06/51] Add typing for device_dtype_mixin.py --- pytorch_lightning/utilities/device_dtype_mixin.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index f08a421ce1844..77c48d36d5cfd 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Any, Optional, Union import torch from torch.nn import Module @@ -22,7 +22,7 @@ class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ['device', 'dtype'] - def __init__(self): + def __init__(self) -> None: super().__init__() self._dtype = torch.get_default_dtype() self._device = torch.device('cpu') @@ -47,7 +47,7 @@ def device(self) -> Union[str, torch.device]: return device @parameter_validation - def to(self, *args, **kwargs) -> Module: + def to(self, *args: Any, **kwargs: Any) -> Module: """Moves and/or casts the parameters and buffers. This can be called as @@ -178,7 +178,7 @@ def half(self) -> Module: self.__update_properties(dtype=torch.half) return super().half() - def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): + def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: def apply_fn(module): if not isinstance(module, DeviceDtypeModuleMixin): From cf1240cf8f3040341455f97f52002ccc13d9d350 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:04:39 +0200 Subject: [PATCH 07/51] Add typing for device_parser --- pytorch_lightning/utilities/device_parser.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index f81a4ece1c6d0..dee559b717f60 100644 --- a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, List, MutableSequence, Optional, Tuple, Union +from typing import Any, List, MutableSequence, Optional, Set, Tuple, Union import torch @@ -171,13 +171,13 @@ def _check_data_type(device_ids: Any) -> None: raise MisconfigurationException("Device ID's (GPU/TPU) must be int, string or sequence of ints or None.") -def _tpu_cores_valid(tpu_cores): +def _tpu_cores_valid(tpu_cores: Optional[Union[int, List[int], Tuple[int], Set[int]]]) -> bool: # allow 1 or 8 cores if tpu_cores in (1, 8, None): return True # allow picking 1 of 8 indexes - if isinstance(tpu_cores, (list, tuple, set)): + if isinstance(tpu_cores, (List, Tuple, Set)): has_1_tpu_idx = len(tpu_cores) == 1 is_valid_tpu_idx = tpu_cores[0] in range(1, 9) @@ -187,7 +187,7 @@ def _tpu_cores_valid(tpu_cores): return False -def _parse_tpu_cores_str(tpu_cores): +def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]: if tpu_cores in ('1', '8'): tpu_cores = int(tpu_cores) else: From 65d3bd5d9d46d639abcb9d1969da02501aa02d4e Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:14:12 +0200 Subject: [PATCH 08/51] Add typing for distributed --- pytorch_lightning/utilities/distributed.py | 28 +++++++++++----------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index a54d00a983d9e..1df011f80e38b 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -16,7 +16,7 @@ import os import warnings from functools import partial, wraps -from typing import Any, Optional, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel @@ -41,10 +41,10 @@ class group: log = logging.getLogger(__name__) -def rank_zero_only(fn): +def rank_zero_only(fn: Callable) -> Optional[Callable]: @wraps(fn) - def wrapped_fn(*args, **kwargs): + def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: if rank_zero_only.rank == 0: return fn(*args, **kwargs) @@ -65,15 +65,15 @@ def _get_rank() -> int: rank_zero_only.rank = getattr(rank_zero_only, 'rank', _get_rank()) -def _warn(*args, **kwargs): +def _warn(*args: Any, **kwargs: Any) -> None: warnings.warn(*args, **kwargs) -def _info(*args, **kwargs): +def _info(*args: Any, **kwargs: Any) -> None: log.info(*args, **kwargs) -def _debug(*args, **kwargs): +def _debug(*args: Any, **kwargs: Any) -> None: log.debug(*args, **kwargs) @@ -83,7 +83,7 @@ def _debug(*args, **kwargs): rank_zero_deprecation = partial(rank_zero_warn, category=DeprecationWarning) -def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): +def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None) -> List[torch.Tensor]: """ Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes @@ -114,7 +114,7 @@ def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None) def sync_ddp_if_available( - result: Union[torch.Tensor], + result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> torch.Tensor: @@ -135,7 +135,7 @@ def sync_ddp_if_available( def sync_ddp( - result: Union[torch.Tensor], + result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> torch.Tensor: @@ -174,7 +174,7 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def 
forward(ctx, tensor, group=group.WORLD): + def forward(ctx: object, tensor: torch.Tensor, group=group.WORLD) -> torch.Tensor: ctx.group = group gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] @@ -185,7 +185,7 @@ def forward(ctx, tensor, group=group.WORLD): return gathered_tensor @staticmethod - def backward(ctx, *grad_output): + def backward(ctx: object, *grad_output: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, None]: grad_output = torch.cat(grad_output) torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) @@ -194,7 +194,7 @@ def backward(ctx, *grad_output): def all_gather_ddp_if_available( - tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False + tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False ) -> torch.Tensor: """ Function to gather a tensor from several distributed processes @@ -220,8 +220,8 @@ def all_gather_ddp_if_available( def register_ddp_comm_hook( model: DistributedDataParallel, ddp_comm_state: Optional[object] = None, - ddp_comm_hook: Optional[callable] = None, - ddp_comm_wrapper: Optional[callable] = None, + ddp_comm_hook: Optional[Callable] = None, + ddp_comm_wrapper: Optional[Callable] = None, ) -> None: """ Function to register communication hook for DDP model From e67d4c771a19fb8b71262b7e069b00c1172595a0 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:18:29 +0200 Subject: [PATCH 09/51] Add typing for imports --- pytorch_lightning/utilities/imports.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 791cef7ff2665..4b92a277e0ca7 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -17,6 +17,7 @@ import platform import sys from importlib.util import find_spec +from typing import Callable import torch from packaging.version import Version @@ -42,7 +43,7 @@ def _module_available(module_path: str) -> bool: return False -def _compare_version(package: str, op, version) -> bool: +def _compare_version(package: str, op: Callable, version: str) -> bool: """ Compare package version with some requirements From a854722e090f2e5f828421cc009fab94a3da1282 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:24:11 +0200 Subject: [PATCH 10/51] Add typing for memory --- pytorch_lightning/utilities/memory.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 6c01390a8c81e..0ee5e4ac4c9ce 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -13,11 +13,15 @@ # limitations under the License. import gc +from typing import Any, Dict, Union import torch -def recursive_detach(in_dict: dict, to_cpu: bool = False) -> dict: +def recursive_detach( + in_dict: Dict[Any, Union[torch.Tensor, Dict[Any, torch.Tensor]]], + to_cpu: bool = False +) -> Dict[Any, torch.Tensor]: """Detach all tensors in `in_dict`. 
May operate recursively if some of the values in `in_dict` are dictionaries @@ -43,14 +47,14 @@ def recursive_detach(in_dict: dict, to_cpu: bool = False) -> dict: return out_dict -def is_oom_error(exception): +def is_oom_error(exception: Any) -> bool: return is_cuda_out_of_memory(exception) \ or is_cudnn_snafu(exception) \ or is_out_of_cpu_memory(exception) # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py -def is_cuda_out_of_memory(exception): +def is_cuda_out_of_memory(exception: Any) -> bool: return isinstance(exception, RuntimeError) \ and len(exception.args) == 1 \ and "CUDA" in exception.args[0] \ @@ -58,7 +62,7 @@ def is_cuda_out_of_memory(exception): # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py -def is_cudnn_snafu(exception): +def is_cudnn_snafu(exception: Any) -> bool: # For/because of https://github.com/pytorch/pytorch/issues/4107 return isinstance(exception, RuntimeError) \ and len(exception.args) == 1 \ @@ -66,14 +70,14 @@ def is_cudnn_snafu(exception): # based on https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py -def is_out_of_cpu_memory(exception): +def is_out_of_cpu_memory(exception: Any) -> bool: return isinstance(exception, RuntimeError) \ and len(exception.args) == 1 \ and "DefaultCPUAllocator: can't allocate memory" in exception.args[0] # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py -def garbage_collection_cuda(): +def garbage_collection_cuda() -> None: """Garbage collection Torch (CUDA) memory.""" gc.collect() if torch.cuda.is_available(): From ff7c24a7f94ed836ef9f52829e332d8313a770be Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:28:40 +0200 Subject: [PATCH 11/51] Add typing for metrics --- pytorch_lightning/utilities/metrics.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py index bd57470dc270e..e547e69894a05 100644 --- a/pytorch_lightning/utilities/metrics.py +++ b/pytorch_lightning/utilities/metrics.py @@ -13,12 +13,14 @@ # limitations under the License. """Helper functions to operate on metric values. """ +from typing import Dict, Union + import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -def metrics_to_scalars(metrics: dict) -> dict: +def metrics_to_scalars(metrics: Dict[str, Union[torch.Tensor, dict]]) -> Dict[str, float]: """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """ # TODO: this is duplicated in MetricsHolder. 
should be unified From bbb01c4a7f4987ae7faaf76b0a544c0ea6adf020 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:30:36 +0200 Subject: [PATCH 12/51] Add typing for model_helpers --- pytorch_lightning/utilities/model_helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 87bd9e6c4545d..6aaeb96dec599 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -14,11 +14,12 @@ from typing import Union +import pytorch_lightning as pl from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule -def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool: +def is_overridden(method_name: str, model: Union['pl.LightningModule', 'pl.LightningDataModule']) -> bool: # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super # TODO - refector this function to accept model_name, instance, parent so it makes more sense super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule From b2cdfc1987c84723d16e3b6c434d548a29d9c826 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:45:55 +0200 Subject: [PATCH 13/51] Add typing for parsing --- pytorch_lightning/utilities/parsing.py | 31 +++++++++++++------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index ae83ba15a9c52..9fdc5f20432ab 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -16,8 +16,9 @@ import pickle import types from argparse import Namespace -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_warn @@ -66,7 +67,7 @@ def is_picklable(obj: object) -> bool: return False -def clean_namespace(hparams): +def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None: """Removes all unpicklable entries from hparams""" hparams_dict = hparams @@ -96,7 +97,7 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]: # self is always first n_self = init_params[0].name - def _get_first_if_any(params, param_type): + def _get_first_if_any(params: Sequence[Namespace], param_type: Type) -> str: for p in params: if p.kind == param_type: return p.name @@ -107,7 +108,7 @@ def _get_first_if_any(params, param_type): return n_self, n_args, n_kwargs -def get_init_args(frame) -> dict: +def get_init_args(frame) -> Dict[str, Any]: _, _, _, local_vars = inspect.getargvalues(frame) if '__class__' not in local_vars: return {} @@ -123,7 +124,7 @@ def get_init_args(frame) -> dict: return local_args -def collect_init_args(frame, path_args: list, inside: bool = False) -> list: +def collect_init_args(frame, path_args: List[Dict[str, Any]], inside: bool = False) -> List[Dict[str, Any]]: """ Recursively collects the arguments passed to the child constructors in the inheritance tree. 
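A hedged sketch of how the `is_overridden` helper retyped in the model_helpers hunk above is used; `BoringModel` is a hypothetical stand-in for any user subclass:

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.model_helpers import is_overridden

    class BoringModel(LightningModule):
        def training_step(self, batch, batch_idx):
            ...

    # True when the subclass provides its own hook instead of the inherited default.
    print(is_overridden("training_step", BoringModel()))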
@@ -149,7 +150,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list: return path_args -def flatten_dict(source, result=None): +def flatten_dict(source: Dict[str, Any], result=None) -> Dict[str, Any]: if result is None: result = {} @@ -164,7 +165,7 @@ def flatten_dict(source, result=None): def save_hyperparameters( obj: Any, - *args, + *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None ) -> None: @@ -219,16 +220,16 @@ class AttributeDict(Dict): "my-key": 3.14 """ - def __getattr__(self, key): + def __getattr__(self, key: str) -> Optional[Any]: try: return self[key] except KeyError as exp: raise AttributeError(f'Missing attribute "{key}"') from exp - def __setattr__(self, key, val): + def __setattr__(self, key: str, val: Any) -> None: self[key] = val - def __repr__(self): + def __repr__(self) -> str: if not len(self): return "" max_key_length = max([len(str(k)) for k in self]) @@ -238,7 +239,7 @@ def __repr__(self): return out -def _lightning_get_all_attr_holders(model, attribute): +def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> Any: """ Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -263,7 +264,7 @@ def _lightning_get_all_attr_holders(model, attribute): return holders -def _lightning_get_first_attr_holder(model, attribute): +def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]: """ Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, @@ -276,7 +277,7 @@ def _lightning_get_first_attr_holder(model, attribute): return holders[-1] -def lightning_hasattr(model, attribute): +def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool: """ Special hasattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -284,7 +285,7 @@ def lightning_hasattr(model, attribute): return _lightning_get_first_attr_holder(model, attribute) is not None -def lightning_getattr(model, attribute): +def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Any: """ Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -306,7 +307,7 @@ def lightning_getattr(model, attribute): return getattr(holder, attribute) -def lightning_setattr(model, attribute, value): +def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None: """ Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. 
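To make the retyped dunder methods of `AttributeDict` concrete, a minimal usage sketch with made-up hyperparameters:

    from pytorch_lightning.utilities.parsing import AttributeDict

    hparams = AttributeDict(lr=1e-3, batch_size=32)
    hparams.max_epochs = 10            # __setattr__ stores the value as a plain dict entry
    print(hparams.lr, hparams["batch_size"])
    print(hparams)                     # __repr__ renders the keys in an aligned block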
From 900d9949af7561eac5ad51a8927e97edd6f76936 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:46:08 +0200 Subject: [PATCH 14/51] Add typing for parsing --- pytorch_lightning/utilities/parsing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 9fdc5f20432ab..12d9f68597ae7 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -108,7 +108,7 @@ def _get_first_if_any(params: Sequence[Namespace], param_type: Type) -> str: return n_self, n_args, n_kwargs -def get_init_args(frame) -> Dict[str, Any]: +def get_init_args(frame: object) -> Dict[str, Any]: _, _, _, local_vars = inspect.getargvalues(frame) if '__class__' not in local_vars: return {} @@ -124,7 +124,7 @@ def get_init_args(frame) -> Dict[str, Any]: return local_args -def collect_init_args(frame, path_args: List[Dict[str, Any]], inside: bool = False) -> List[Dict[str, Any]]: +def collect_init_args(frame:object, path_args: List[Dict[str, Any]], inside: bool = False) -> List[Dict[str, Any]]: """ Recursively collects the arguments passed to the child constructors in the inheritance tree. @@ -150,7 +150,7 @@ def collect_init_args(frame, path_args: List[Dict[str, Any]], inside: bool = Fal return path_args -def flatten_dict(source: Dict[str, Any], result=None) -> Dict[str, Any]: +def flatten_dict(source: Dict[str, Any], result: Optioanl[Dict[str, Any]] = None) -> Dict[str, Any]: if result is None: result = {} From 24b0125ff1c524813eddf7fd24ad111443fcc66a Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 21:56:52 +0200 Subject: [PATCH 15/51] Add typing to multiple files and fix a typo * Files: seed, upgrade_checkpoint, warnings * Fix typo in parsing --- pytorch_lightning/utilities/parsing.py | 2 +- pytorch_lightning/utilities/seed.py | 2 +- pytorch_lightning/utilities/upgrade_checkpoint.py | 2 +- pytorch_lightning/utilities/warnings.py | 8 +++++--- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 12d9f68597ae7..4224e1c7f6630 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -150,7 +150,7 @@ def collect_init_args(frame:object, path_args: List[Dict[str, Any]], inside: boo return path_args -def flatten_dict(source: Dict[str, Any], result: Optioanl[Dict[str, Any]] = None) -> Dict[str, Any]: +def flatten_dict(source: Dict[str, Any], result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: if result is None: result = {} diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index b7eaba72c1b02..0fc149d0b7d04 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -78,7 +78,7 @@ def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> return random.randint(min_seed_value, max_seed_value) -def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover +def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover """ The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed with ``seed_everything(seed, workers=True)``. 
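The seeding utilities annotated above are normally driven through `seed_everything`; a short sketch of that entry point (the seed value is arbitrary):

    from pytorch_lightning import seed_everything

    # Seeds the python, numpy and torch RNGs; with workers=True, Lightning attaches
    # pl_worker_init_function (typed above) to DataLoader workers so each worker
    # derives its own seed.
    seed_everything(42, workers=True)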
diff --git a/pytorch_lightning/utilities/upgrade_checkpoint.py b/pytorch_lightning/utilities/upgrade_checkpoint.py index 4896845f10263..d83a682511188 100644 --- a/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -30,7 +30,7 @@ log = logging.getLogger(__name__) -def upgrade_checkpoint(filepath): +def upgrade_checkpoint(filepath: str) -> None: checkpoint = torch.load(filepath) checkpoint["callbacks"] = checkpoint.get("callbacks") or {} diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index a3dde95fa928f..6dfb67133cf19 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -11,18 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + from pytorch_lightning.utilities.distributed import rank_zero_warn class WarningCache: - def __init__(self): + def __init__(self) -> None: self.warnings = set() - def warn(self, m, *args, **kwargs): + def warn(self, m: Any, *args: Any, **kwargs: Any) -> None: if m not in self.warnings: self.warnings.add(m) rank_zero_warn(m, *args, **kwargs) - def clear(self): + def clear(self) -> None: self.warnings.clear() From 7ca60247c253540e67747595ec3a13d0dd4be6d9 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 22:03:45 +0200 Subject: [PATCH 16/51] Add typing for xla_device --- pytorch_lightning/utilities/xla_device.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py index 49ec176d4cdbb..8e2836c59bdae 100644 --- a/pytorch_lightning/utilities/xla_device.py +++ b/pytorch_lightning/utilities/xla_device.py @@ -16,6 +16,7 @@ import queue as q import traceback from multiprocessing import Process, Queue +from typing import Any, Callable, Union from pytorch_lightning.utilities.imports import _XLA_AVAILABLE @@ -26,7 +27,7 @@ TPU_CHECK_TIMEOUT = 25 -def inner_f(queue, func, *args, **kwargs): # pragma: no cover +def inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover try: queue.put(func(*args, **kwargs)) # todo: specify the possible exception @@ -35,10 +36,10 @@ def inner_f(queue, func, *args, **kwargs): # pragma: no cover queue.put(None) -def pl_multi_process(func): +def pl_multi_process(func: Callable) -> Callable: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Union[Any, bool]: queue = Queue() proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) proc.start() From 7576c450146d98a83c78d30a4c0f430088e82331 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 26 Apr 2021 22:25:08 +0200 Subject: [PATCH 17/51] Add missing whitespace afer ':' in parsing.py --- pytorch_lightning/utilities/parsing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 4224e1c7f6630..765b548932675 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -124,7 +124,7 @@ def get_init_args(frame: object) -> Dict[str, Any]: return local_args -def collect_init_args(frame:object, path_args: List[Dict[str, Any]], inside: bool = False) -> List[Dict[str, Any]]: +def collect_init_args(frame: object, path_args: List[Dict[str, 
Any]], inside: bool = False) -> List[Dict[str, Any]]: """ Recursively collects the arguments passed to the child constructors in the inheritance tree. From 2cee78f75af4d941ff53fb1c17a6fe813637cd18 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Wed, 28 Apr 2021 16:37:53 +0200 Subject: [PATCH 18/51] Add some missing typing --- pytorch_lightning/utilities/cli.py | 4 ++-- pytorch_lightning/utilities/debugging.py | 6 +++--- pytorch_lightning/utilities/device_dtype_mixin.py | 2 +- pytorch_lightning/utilities/distributed.py | 2 +- pytorch_lightning/utilities/finite_checks.py | 6 +++--- pytorch_lightning/utilities/memory.py | 2 +- pytorch_lightning/utilities/seed.py | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 730fe2798ef24..7b306d58dc4ed 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -97,8 +97,8 @@ def __init__( datamodule_class: Optional['pl.LightningDataModule'] = None, save_config_callback: SaveConfigCallback = SaveConfigCallback, trainer_class: 'pl.Trainer' = Trainer, - trainer_defaults: Dict[str, Any] = None, - seed_everything_default: int = None, + trainer_defaults: Optional[Dict[str, Any]] = None, + seed_everything_default: Optional[int] = None, description: str = 'pytorch-lightning trainer command line tool', env_prefix: str = 'PL', env_parse: bool = False, diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 58fb1aeffc2ef..d441b025388fe 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -85,7 +85,7 @@ def count_events(self, evt_type: str, strict: bool = False) -> int: return count @enabled_only - def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]): + def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -> None: loader_counts = len(dataloaders) lengths = [] @@ -135,7 +135,7 @@ def track_lr_schedulers_update( new_lr: float, monitor_key: Optional[str] = None, monitor_val: Optional[torch.Tensor] = None - ): + ) -> None: loss_dict = { 'batch_idx': batch_idx, 'interval': interval, @@ -149,7 +149,7 @@ def track_lr_schedulers_update( self.saved_lr_scheduler_updates.append(loss_dict) @enabled_only - def track_eval_loss_history(self, batch_idx: int, dataloader_idx: int, output: torch.Tensor): + def track_eval_loss_history(self, batch_idx: int, dataloader_idx: int, output: torch.Tensor) -> None: loss_dict = { 'sanity_check': self.trainer.sanity_checking, 'dataloader_idx': dataloader_idx, diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index 77c48d36d5cfd..27d2b352bf4ff 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -32,7 +32,7 @@ def dtype(self) -> Union[str, torch.dtype]: return self._dtype @dtype.setter - def dtype(self, new_dtype: Union[str, torch.dtype]): + def dtype(self, new_dtype: Union[str, torch.dtype]) -> None: # necessary to avoid infinite recursion raise RuntimeError('Cannot set the dtype explicitly. 
Please use module.to(new_dtype).') diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 1df011f80e38b..7cee9efa2c052 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -83,7 +83,7 @@ def _debug(*args: Any, **kwargs: Any) -> None: rank_zero_deprecation = partial(rank_zero_warn, category=DeprecationWarning) -def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None) -> List[torch.Tensor]: +def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]: """ Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes diff --git a/pytorch_lightning/utilities/finite_checks.py b/pytorch_lightning/utilities/finite_checks.py index 770ea7a2276f0..c67958ad74f67 100644 --- a/pytorch_lightning/utilities/finite_checks.py +++ b/pytorch_lightning/utilities/finite_checks.py @@ -16,19 +16,19 @@ import logging import torch -import torch.nn as nn +from torch.nn import Module log = logging.getLogger(__name__) -def print_nan_gradients(model: nn.Module) -> None: +def print_nan_gradients(model: Module) -> None: """ Iterates over model parameters and prints out parameter + gradient information if NaN. """ for param in model.parameters(): if (param.grad is not None) and torch.isnan(param.grad.float()).any(): log.info(param, param.grad) -def detect_nan_parameters(model: nn.Module) -> None: +def detect_nan_parameters(model: Module) -> None: """ Iterates over model parameters and prints gradients if any parameter is not finite. """ for name, param in model.named_parameters(): if not torch.isfinite(param).all(): diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 0ee5e4ac4c9ce..5658a28bd0036 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -21,7 +21,7 @@ def recursive_detach( in_dict: Dict[Any, Union[torch.Tensor, Dict[Any, torch.Tensor]]], to_cpu: bool = False -) -> Dict[Any, torch.Tensor]: +) -> Dict[str, torch.Tensor]: """Detach all tensors in `in_dict`. May operate recursively if some of the values in `in_dict` are dictionaries diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 51547d5576e74..78abd06e3b289 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -88,7 +88,7 @@ def reset_seed() -> None: seed_everything(int(seed)) -def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover +def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover """ The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed with ``seed_everything(seed, workers=True)``. 
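A small, hedged sketch of the finite-check helpers retyped above; they accept any `nn.Module`, and the `Linear` layer here is just a placeholder:

    import torch
    from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients

    model = torch.nn.Linear(4, 2)
    print_nan_gradients(model)    # logs parameters whose gradients contain NaNs (no-op here)
    detect_nan_parameters(model)  # complains if any parameter itself is not finite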
From 9b42400404f06b5f3726cb3c6a5b5c10995bedcc Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Tue, 4 May 2021 11:22:55 +0200 Subject: [PATCH 19/51] Do some fixes after reviews * In apply_func.py device arg is marked as Optional since torch.tensor() and other function can handle device=None * In cloud_io.py LocalFileSystem replaced with AbstractFileSystem * In cli.py: return Type to typing --- pytorch_lightning/utilities/apply_func.py | 15 ++++----------- pytorch_lightning/utilities/cli.py | 10 +++++----- pytorch_lightning/utilities/cloud_io.py | 4 ++-- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index d384c0edbd3bf..f501fe6284e49 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -36,16 +36,12 @@ def to_dtype_tensor( value: Union[int, float, List[Union[int, float]]], dtype: Optional[torch.dtype] = None, - device: Union[str, torch.device] = None + device: Optional[Union[str, torch.device]] = None ) -> torch.Tensor: - if device is None: - raise MisconfigurationException("device (torch.device) should be provided.") return torch.tensor(value, dtype=dtype, device=device) -def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor: - if device is None: - raise MisconfigurationException("device (torch.device) should be provided.") +def from_numpy(value: np.ndarray, device: Optional[Union[str, torch.device]] = None) -> torch.Tensor: return torch.from_numpy(value).to(device) @@ -64,7 +60,7 @@ def apply_to_collection( function: Callable, *args: Any, wrong_dtype: Optional[Union[torch.dtype, Tuple[torch.dtype]]] = None, - **kwargs: Any + **kwargs: Any, ) -> Any: """ Recursively applies a function to all elements of a certain dtype. @@ -165,10 +161,7 @@ def batch_to(data: Any) -> Any: return apply_to_collection(batch, dtype=dtype, function=batch_to) -def convert_to_tensors(data: Any, device: Union[str, torch.device] = None) -> Any: - if device is None: - raise MisconfigurationException("device (torch.device) should be provided.") - +def convert_to_tensors(data: Any, device: Optional[Union[str, torch.device]] = None) -> Any: for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6714a36ca6122..214a41a16d407 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Type, Union import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback @@ -93,10 +93,10 @@ class LightningCLI: def __init__( self, - model_class: 'pl.LightningModule', - datamodule_class: Optional['pl.LightningDataModule'] = None, - save_config_callback: SaveConfigCallback = SaveConfigCallback, - trainer_class: 'pl.Trainer' = Trainer, + model_class: Type['pl.LightningModule'], + datamodule_class: Optional[Type['pl.LightningDataModule']] = None, + save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback, + trainer_class: Type['pl.Trainer'] = Trainer, trainer_defaults: Optional[Dict[str, Any]] = None, seed_everything_default: Optional[int] = None, description: str = 'pytorch-lightning trainer command line tool', diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index c233fbf35483f..d77c2f8745d31 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -18,7 +18,7 @@ import fsspec import torch -from fsspec.implementations.local import LocalFileSystem +from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem from packaging.version import Version @@ -36,7 +36,7 @@ def load( return torch.load(f, map_location=map_location) -def get_filesystem(path: Union[str, Path]) -> LocalFileSystem: +def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem: path = str(path) if "://" in path: # use the fileystem from the protocol specified From 9cd7d19bf9473351ac798ef40c3341c58b1cda09 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Tue, 4 May 2021 11:28:11 +0200 Subject: [PATCH 20/51] Add some missing commas' --- pytorch_lightning/utilities/apply_func.py | 2 +- pytorch_lightning/utilities/cli.py | 4 ++-- pytorch_lightning/utilities/cloud_io.py | 2 +- pytorch_lightning/utilities/debugging.py | 4 ++-- pytorch_lightning/utilities/distributed.py | 4 ++-- pytorch_lightning/utilities/memory.py | 2 +- pytorch_lightning/utilities/parsing.py | 2 +- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index f501fe6284e49..e000df7831b9b 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -36,7 +36,7 @@ def to_dtype_tensor( value: Union[int, float, List[Union[int, float]]], dtype: Optional[torch.dtype] = None, - device: Optional[Union[str, torch.device]] = None + device: Optional[Union[str, torch.device]] = None, ) -> torch.Tensor: return torch.tensor(value, dtype=dtype, device=device) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 214a41a16d407..7778915fac9fc 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -53,7 +53,7 @@ def add_lightning_class_args( self, lightning_class: Union['pl.Trainer', 'pl.LightningModule', 'pl.LightningDataModule'], nested_key: str, - subclass_mode: bool = False + subclass_mode: bool = False, ) -> None: """ Adds arguments from a lightning class to a nested key of the parser @@ -104,7 +104,7 @@ def __init__( env_parse: bool = False, parser_kwargs: Optional[Dict[str, Any]] = None, subclass_mode_model: bool = False, - subclass_mode_data: bool = False + subclass_mode_data: bool = False, ) -> None: """ Receives as input pytorch-lightning classes, which are instantiated diff --git 
a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index d77c2f8745d31..93a328bba462b 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -24,7 +24,7 @@ def load( path_or_url: Union[str, IO, Path], - map_location: Optional[Union[str, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]] = None + map_location: Optional[Union[str, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]] = None, ) -> Any: if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similiar diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index d441b025388fe..3eeb28e2b5c3c 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -64,7 +64,7 @@ def track_event( evt_value: Optional[Any] = None, global_rank: Optional[int] = None, local_rank: Optional[int] = None, - comment: str = '' + comment: str = '', ) -> None: self.events.append({ "timestamp": time.time(), @@ -134,7 +134,7 @@ def track_lr_schedulers_update( old_lr: float, new_lr: float, monitor_key: Optional[str] = None, - monitor_val: Optional[torch.Tensor] = None + monitor_val: Optional[torch.Tensor] = None, ) -> None: loss_dict = { 'batch_idx': batch_idx, diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 7cee9efa2c052..1e7a89dcc27c9 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -116,7 +116,7 @@ def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> Lis def sync_ddp_if_available( result: torch.Tensor, group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None + reduce_op: Optional[Union[ReduceOp, str]] = None, ) -> torch.Tensor: """ Function to reduce a tensor across worker processes during distributed training @@ -137,7 +137,7 @@ def sync_ddp_if_available( def sync_ddp( result: torch.Tensor, group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None + reduce_op: Optional[Union[ReduceOp, str]] = None, ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 5658a28bd0036..27e4ca33022dd 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -20,7 +20,7 @@ def recursive_detach( in_dict: Dict[Any, Union[torch.Tensor, Dict[Any, torch.Tensor]]], - to_cpu: bool = False + to_cpu: bool = False, ) -> Dict[str, torch.Tensor]: """Detach all tensors in `in_dict`. 
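For reference, a usage sketch of `recursive_detach` as retyped above; the `outputs` dict is invented for illustration:

    import torch
    from pytorch_lightning.utilities.memory import recursive_detach

    outputs = {"loss": torch.tensor(1.0, requires_grad=True), "log": {"acc": torch.tensor(0.9)}}
    # Same nested structure back, with gradients cut and tensors moved to CPU.
    detached = recursive_detach(outputs, to_cpu=True)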
diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 4331a149eb2d3..f11141d801284 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -188,7 +188,7 @@ def save_hyperparameters( obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, - frame: Optional[types.FrameType] = None + frame: Optional[types.FrameType] = None, ) -> None: """See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`""" From 99811aaacfd7691c3a9ce5c3a3e6f9575113a6a9 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Thu, 6 May 2021 12:46:51 +0200 Subject: [PATCH 21/51] Add import Set to device_parser.py --- pytorch_lightning/utilities/device_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index 8e41900c9d196..64a781e7b9bac 100644 --- a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import operator -from typing import Any, List, MutableSequence, Optional, Tuple, Union +from typing import Any, List, MutableSequence, Optional, Set, Tuple, Union import torch From 7908c39ee12e21e4326214e6daaba4771ae42fcb Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 7 May 2021 21:41:05 +0200 Subject: [PATCH 22/51] Remove unused import to pass PEP8 --- pytorch_lightning/utilities/apply_func.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index e000df7831b9b..c7cc3c3c173d8 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -21,7 +21,6 @@ import numpy as np import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE if _TORCHTEXT_AVAILABLE: From 198aa413750e50568f27c922454d4e998d4ffe22 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 May 2021 19:42:00 +0000 Subject: [PATCH 23/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/governance.rst | 2 -- pytorch_lightning/utilities/argparse.py | 2 +- pytorch_lightning/utilities/debugging.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/source/governance.rst b/docs/source/governance.rst index fac8b68e1df53..5b1f9bd1916c1 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -38,5 +38,3 @@ Alumni - Jeff Ling (`jeffling `_) - Teddy Koker (`teddykoker `_) - Nate Raw (`nateraw `_) - - diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 75a8ff02ceb12..8f71084d77216 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List, Tuple, Union import pytorch_lightning as pl -from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str, str_to_bool_or_int +from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs: Any) -> 'pl.Trainer': diff --git 
a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 3eeb28e2b5c3c..3ff5b09e0bc11 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -16,7 +16,7 @@ import time from collections import Counter from functools import wraps -from typing import Any, Dict, Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch from torch.utils.data import DataLoader From 7fcbf96eff3c0ff7f3b8832c06f5b330013e657d Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 8 May 2021 10:45:52 +0200 Subject: [PATCH 24/51] Fix some issues after review * Add Type to apply_func.py to emhpasize a class is expected to be passed in * Replace torch.dtype with type in some functions in cli.py --- pytorch_lightning/utilities/apply_func.py | 2 +- pytorch_lightning/utilities/cli.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index c7cc3c3c173d8..8675ed0eb9d32 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -55,7 +55,7 @@ def from_numpy(value: np.ndarray, device: Optional[Union[str, torch.device]] = N def apply_to_collection( data: Any, - dtype: Union[torch.dtype, Tuple[torch.dtype]], + dtype: Union[type, Tuple[type]], function: Callable, *args: Any, wrong_dtype: Optional[Union[torch.dtype, Tuple[torch.dtype]]] = None, diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 7778915fac9fc..d7a916e1139ac 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -51,7 +51,7 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non def add_lightning_class_args( self, - lightning_class: Union['pl.Trainer', 'pl.LightningModule', 'pl.LightningDataModule'], + lightning_class: Union[Type('pl.Trainer'), Type('pl.LightningModule'), Type('pl.LightningDataModule')], nested_key: str, subclass_mode: bool = False, ) -> None: From fef29070ab65f082a82a7bad780cbffecffcc04b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 May 2021 08:47:52 +0000 Subject: [PATCH 25/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/cli.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index d7a916e1139ac..033259c785258 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -51,7 +51,9 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non def add_lightning_class_args( self, - lightning_class: Union[Type('pl.Trainer'), Type('pl.LightningModule'), Type('pl.LightningDataModule')], + lightning_class: Union[Type('pl.Trainer'), + Type('pl.LightningModule'), + Type('pl.LightningDataModule')], nested_key: str, subclass_mode: bool = False, ) -> None: From 60d615da7571b9b76382515501d868728570042c Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 8 May 2021 11:00:57 +0200 Subject: [PATCH 26/51] Fix typo a make consistent typing in cli.py --- pytorch_lightning/utilities/cli.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 033259c785258..6db1e94889628 100644 --- 
a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -51,9 +51,9 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non def add_lightning_class_args( self, - lightning_class: Union[Type('pl.Trainer'), - Type('pl.LightningModule'), - Type('pl.LightningDataModule')], + lightning_class: Union[Type['pl.Trainer'], + Type['pl.LightningModule'], + Type['pl.LightningDataModule']], nested_key: str, subclass_mode: bool = False, ) -> None: @@ -84,7 +84,7 @@ def __init__( self.config = config self.config_filename = config_filename - def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: log_dir = trainer.log_dir or trainer.default_root_dir config_path = os.path.join(log_dir, self.config_filename) self.parser.save(self.config, config_path, skip_none=False) From 0854986da248cd39b17e3a2c8273e7dc3601c803 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 May 2021 09:01:45 +0000 Subject: [PATCH 27/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/cli.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6db1e94889628..c2c3a8f0bd7ae 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -51,9 +51,7 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non def add_lightning_class_args( self, - lightning_class: Union[Type['pl.Trainer'], - Type['pl.LightningModule'], - Type['pl.LightningDataModule']], + lightning_class: Union[Type['pl.Trainer'], Type['pl.LightningModule'], Type['pl.LightningDataModule']], nested_key: str, subclass_mode: bool = False, ) -> None: From 14b172d93702f9a75523f0aac2e1bad03e43e469 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 8 May 2021 11:10:26 +0200 Subject: [PATCH 28/51] * Reflect typing for recursive functions * Introduce new typing class: recursive_dict_with_tensors = Dict[ str, Union[torch.Tensor, 'recursive_dict_with_tensors'] ] --- pytorch_lightning/utilities/memory.py | 7 +++---- pytorch_lightning/utilities/metrics.py | 4 +++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 27e4ca33022dd..0273e86177ce3 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -17,11 +17,10 @@ import torch +recursive_dict_with_tensors = Dict[Any, Union[torch.Tensor, 'recursive_dict_with_tensors']] -def recursive_detach( - in_dict: Dict[Any, Union[torch.Tensor, Dict[Any, torch.Tensor]]], - to_cpu: bool = False, -) -> Dict[str, torch.Tensor]: + +def recursive_detach(in_dict: recursive_dict_with_tensors, to_cpu: bool = False) -> Dict[str, torch.Tensor]: """Detach all tensors in `in_dict`. 
May operate recursively if some of the values in `in_dict` are dictionaries diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py index e547e69894a05..d1e7e6a23200e 100644 --- a/pytorch_lightning/utilities/metrics.py +++ b/pytorch_lightning/utilities/metrics.py @@ -19,8 +19,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException +recursive_dict_with_tensors = Dict[str, Union[torch.Tensor, 'recursive_dict_with_tensors']] -def metrics_to_scalars(metrics: Dict[str, Union[torch.Tensor, dict]]) -> Dict[str, float]: + +def metrics_to_scalars(metrics: recursive_dict_with_tensors) -> Dict[str, float]: """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """ # TODO: this is duplicated in MetricsHolder. should be unified From 6690bb9aecead774649c129188ffb88892f61331 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 8 May 2021 12:56:19 +0200 Subject: [PATCH 29/51] Capitalize the name of recursive-dict type --- pytorch_lightning/utilities/memory.py | 4 ++-- pytorch_lightning/utilities/metrics.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 0273e86177ce3..1a0dbcd04dff3 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -17,10 +17,10 @@ import torch -recursive_dict_with_tensors = Dict[Any, Union[torch.Tensor, 'recursive_dict_with_tensors']] +RECURSIVE_DICT_WITH_TENSORS = Dict[Any, Union[torch.Tensor, 'RECURSIVE_DICT_WITH_TENSORS']] -def recursive_detach(in_dict: recursive_dict_with_tensors, to_cpu: bool = False) -> Dict[str, torch.Tensor]: +def recursive_detach(in_dict: RECURSIVE_DICT_WITH_TENSORS, to_cpu: bool = False) -> Dict[str, torch.Tensor]: """Detach all tensors in `in_dict`. May operate recursively if some of the values in `in_dict` are dictionaries diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py index d1e7e6a23200e..1fd97eea0fc85 100644 --- a/pytorch_lightning/utilities/metrics.py +++ b/pytorch_lightning/utilities/metrics.py @@ -19,10 +19,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException -recursive_dict_with_tensors = Dict[str, Union[torch.Tensor, 'recursive_dict_with_tensors']] +RECURSIVE_DICT_WITH_TENSORS = Dict[str, Union[torch.Tensor, 'RECURSIVE_DICT_WITH_TENSORS']] -def metrics_to_scalars(metrics: recursive_dict_with_tensors) -> Dict[str, float]: +def metrics_to_scalars(metrics: RECURSIVE_DICT_WITH_TENSORS) -> Dict[str, float]: """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """ # TODO: this is duplicated in MetricsHolder. should be unified From 644776a530e6d81ff43c01a455e66337644b48fe Mon Sep 17 00:00:00 2001 From: jirka Date: Tue, 11 May 2021 10:40:01 +0200 Subject: [PATCH 30/51] mypy --- setup.cfg | 4 ---- 1 file changed, 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index 3fa6e39076725..f3238d9b34cef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -170,10 +170,6 @@ ignore_errors = True [mypy-pytorch_lightning.tuner.*] ignore_errors = True -# todo: add proper typing to this module... -[mypy-pytorch_lightning.utilities.*] -ignore_errors = True - # todo: add proper typing to this module... 
[mypy-pl_examples.*] ignore_errors = True From e924379826df127a8a3e2162e7bd9f55a4d2712e Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 11:50:56 +0200 Subject: [PATCH 31/51] Fix typing for enums.py and xla_device.py --- pytorch_lightning/utilities/enums.py | 2 +- pytorch_lightning/utilities/xla_device.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index e01f8862486d3..766863819a97e 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -27,7 +27,7 @@ def from_str(cls, value: str) -> Optional['LightningEnum']: return getattr(cls, st) return None - def __eq__(self, other: Union[str, Enum]) -> bool: + def __eq__(self, other: Union[object, Enum]) -> bool: other = other.value if isinstance(other, Enum) else str(other) return self.value.lower() == other.lower() diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py index 5f82cff215ab5..91c21681308f1 100644 --- a/pytorch_lightning/utilities/xla_device.py +++ b/pytorch_lightning/utilities/xla_device.py @@ -40,7 +40,7 @@ def pl_multi_process(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Union[Any, bool]: - queue = Queue() + queue: Queue = Queue() proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) proc.start() proc.join(TPU_CHECK_TIMEOUT) From 2659a92035fff40dbba3ea3d2a65c969bab52588 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 12:09:45 +0200 Subject: [PATCH 32/51] Remove string types where not neede in 2 files Files: cli.py, model_helpers.py --- pytorch_lightning/utilities/cli.py | 10 +++++----- pytorch_lightning/utilities/model_helpers.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 2dcb23bd3a504..4a8f62795a977 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -51,7 +51,7 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non def add_lightning_class_args( self, - lightning_class: Union[Type['pl.Trainer'], Type['pl.LightningModule'], Type['pl.LightningDataModule']], + lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule]], nested_key: str, subclass_mode: bool = False, ) -> None: @@ -82,7 +82,7 @@ def __init__( self.config = config self.config_filename = config_filename - def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: log_dir = trainer.log_dir or trainer.default_root_dir config_path = os.path.join(log_dir, self.config_filename) self.parser.save(self.config, config_path, skip_none=False) @@ -93,10 +93,10 @@ class LightningCLI: def __init__( self, - model_class: Type['pl.LightningModule'], - datamodule_class: Optional[Type['pl.LightningDataModule']] = None, + model_class: Type[LightningModule], + datamodule_class: Optional[Type[LightningDataModule]] = None, save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback, - trainer_class: Type['pl.Trainer'] = Trainer, + trainer_class: Type[Trainer] = Trainer, trainer_defaults: Optional[Dict[str, Any]] = None, seed_everything_default: Optional[int] = None, description: str = 'pytorch-lightning trainer command line tool', diff --git 
a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 6aaeb96dec599..88f0cc647c3b4 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -19,7 +19,7 @@ from pytorch_lightning.core.lightning import LightningModule -def is_overridden(method_name: str, model: Union['pl.LightningModule', 'pl.LightningDataModule']) -> bool: +def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool: # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super # TODO - refector this function to accept model_name, instance, parent so it makes more sense super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule From 0a16daeb651bd84ae6a0016722d74206206069ab Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 12:53:51 +0200 Subject: [PATCH 33/51] Fix typing for utilities/argparse.py --- pytorch_lightning/utilities/argparse.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 8f71084d77216..dcd940ee346fb 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -15,7 +15,7 @@ import os from argparse import _ArgumentGroup, ArgumentParser, Namespace from contextlib import suppress -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union import pytorch_lightning as pl from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str @@ -150,7 +150,7 @@ def add_argparse_args( parent_parser: ArgumentParser, *, use_argument_group: bool = True, -) -> ArgumentParser: +) -> Union[_ArgumentGroup, ArgumentParser]: r"""Extends existing argparse by default attributes for ``cls``. 
Args: @@ -190,7 +190,7 @@ def add_argparse_args( raise RuntimeError("Please only pass an ArgumentParser instance.") if use_argument_group: group_name = _get_abbrev_qualified_cls_name(cls) - parser = parent_parser.add_argument_group(group_name) + parser: Union[_ArgumentGroup, ArgumentParser] = parent_parser.add_argument_group(group_name) else: parser = ArgumentParser( parents=[parent_parser], @@ -213,16 +213,16 @@ def add_argparse_args( args_help = _parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "") for arg, arg_types, arg_default in args_and_types: - arg_types = [at for at in allowed_types if at in arg_types] + arg_types = tuple(at for at in allowed_types if at in arg_types) if not arg_types: # skip argument with not supported type continue - arg_kwargs = {} + arg_kwargs: Dict[str, Any] = {} if bool in arg_types: arg_kwargs.update(nargs="?", const=True) # if the only arg type is bool if len(arg_types) == 1: - use_type = str_to_bool + use_type: Callable[[str], Union[bool, int, float, str]] = str_to_bool elif int in arg_types: use_type = str_to_bool_or_int elif str in arg_types: @@ -261,7 +261,7 @@ def add_argparse_args( def _parse_args_from_docstring(docstring: str) -> Dict[str, str]: arg_block_indent = None - current_arg = None + current_arg = '' parsed = {} for line in docstring.split("\n"): stripped = line.lstrip() From db1a7d159132836ce0446a78f98a6f8bd980c155 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 13:40:34 +0200 Subject: [PATCH 34/51] Add missing typing for utilities/debugging.py --- pytorch_lightning/utilities/debugging.py | 38 ++++++++++++------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 3ff5b09e0bc11..d390c899a3f67 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -16,7 +16,7 @@ import time from collections import Counter from functools import wraps -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.utils.data import DataLoader @@ -44,19 +44,19 @@ class InternalDebugger(object): def __init__(self, trainer: 'pl.Trainer') -> None: self.enabled = os.environ.get('PL_DEV_DEBUG', '0') == '1' self.trainer = trainer - self.logged_metrics = [] - self.pbar_added_metrics = [] - self.saved_train_losses = [] - self.saved_val_losses = [] - self.saved_test_losses = [] - self.early_stopping_history = [] - self.checkpoint_callback_history = [] - self.events = [] - self.saved_lr_scheduler_updates = [] - self.train_dataloader_calls = [] - self.val_dataloader_calls = [] - self.test_dataloader_calls = [] - self.dataloader_sequence_calls = [] + self.logged_metrics: List[Dict[str, Union[int, torch.Tensor]]] = [] + self.pbar_added_metrics: List[Dict[str, Union[int, torch.Tensor]]] = [] + self.saved_train_losses: List[Dict[str, Union[int, torch.Tensor, object]]] = [] + self.saved_val_losses: List[Dict[str, Union[int, torch.Tensor, object]]] = [] + self.saved_test_losses: List[Dict[str, Union[int, torch.Tensor, object]]] = [] + self.early_stopping_history: List[Dict[str, Union[int, torch.Tensor, object]]] = [] + self.checkpoint_callback_history: List[Dict[str, Union[int, torch.Tensor, object]]] = [] + self.events: List[Dict[str, Any]] = [] + self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = [] + self.train_dataloader_calls: List[Dict[str, Union[int, str, 
object]]] = [] + self.val_dataloader_calls: List[Dict[str, Union[int, str, object]]] = [] + self.test_dataloader_calls: List[Dict[str, Union[int, str, object]]] = [] + self.dataloader_sequence_calls: List[Dict[str, Union[int, str, object]]] = [] def track_event( self, @@ -116,7 +116,7 @@ def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) - self.test_dataloader_calls.append(values) @enabled_only - def track_logged_metrics_history(self, scalar_metrics: Dict[str, torch.Tensor]) -> None: + def track_logged_metrics_history(self, scalar_metrics: Dict[str, Union[int, torch.Tensor]]) -> None: scalar_metrics['global_step'] = self.trainer.global_step self.logged_metrics.append(scalar_metrics) @@ -164,12 +164,12 @@ def track_eval_loss_history(self, batch_idx: int, dataloader_idx: int, output: t self.saved_val_losses.append(loss_dict) @enabled_only - def track_pbar_metrics_history(self, metrics: Dict[str, torch.Tensor]) -> None: + def track_pbar_metrics_history(self, metrics: Dict[str, Union[int, torch.Tensor]]) -> None: metrics['debug_epoch'] = self.trainer.current_epoch self.pbar_added_metrics.append(metrics) @enabled_only - def track_early_stopping_history(self, callback: 'pl.Callback', current: torch.Tensor) -> None: + def track_early_stopping_history(self, callback: 'pl.callbacks.early_stopping.EarlyStopping', current: torch.Tensor) -> None: debug_dict = { 'epoch': self.trainer.current_epoch, 'global_step': self.trainer.global_step, @@ -199,7 +199,7 @@ def num_seen_sanity_check_batches(self) -> int: @property def num_seen_val_check_batches(self) -> Counter: - counts = Counter() + counts: Counter = Counter() for x in self.saved_val_losses: if not x['sanity_check']: counts.update({x['dataloader_idx']: 1}) @@ -207,7 +207,7 @@ def num_seen_val_check_batches(self) -> Counter: @property def num_seen_test_check_batches(self) -> Counter: - counts = Counter() + counts: Counter = Counter() for x in self.saved_test_losses: if not x['sanity_check']: counts.update({x['dataloader_idx']: 1}) From 1993c79e6463095943357d3c177e03459e355b9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 May 2021 11:41:36 +0000 Subject: [PATCH 35/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/debugging.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index d390c899a3f67..e6d8209b8a9d5 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -169,7 +169,9 @@ def track_pbar_metrics_history(self, metrics: Dict[str, Union[int, torch.Tensor] self.pbar_added_metrics.append(metrics) @enabled_only - def track_early_stopping_history(self, callback: 'pl.callbacks.early_stopping.EarlyStopping', current: torch.Tensor) -> None: + def track_early_stopping_history( + self, callback: 'pl.callbacks.early_stopping.EarlyStopping', current: torch.Tensor + ) -> None: debug_dict = { 'epoch': self.trainer.current_epoch, 'global_step': self.trainer.global_step, From c6f2f4e332ba5f692b9fa5a693ce2743e56b55b3 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 14:20:39 +0200 Subject: [PATCH 36/51] Change type of NotImplemented to Any --- pytorch_lightning/utilities/apply_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/apply_func.py 
b/pytorch_lightning/utilities/apply_func.py index 4db1e1546bfc1..dbdfa074627fc 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -122,7 +122,7 @@ class TransferableDataType(ABC): """ @classmethod - def __subclasshook__(cls, subclass: Any) -> Union[bool, Type[NotImplemented]]: + def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: if cls is TransferableDataType: to = getattr(subclass, "to", None) return callable(to) From 9cb30cc3edaac23cd9cb4e1449101a10ce14bc39 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 15:56:02 +0200 Subject: [PATCH 37/51] Fix mypy compatibility in a few files * device_dtype_mixin.py - 0 issue * device_parser.py - 0 issue * [WIP] distributed.py - still some issues mainly related to rank_zero_only function --- .../utilities/device_dtype_mixin.py | 21 +++++++----- pytorch_lightning/utilities/device_parser.py | 21 +++++++----- pytorch_lightning/utilities/distributed.py | 33 ++++++++++--------- 3 files changed, 44 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index c3a5d24461401..4f285d865c14e 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -114,7 +114,7 @@ def to(self, *args: Any, **kwargs: Any) -> Module: self.__update_properties(device=out[0], dtype=out[1]) return super().to(*args, **kwargs) - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module: + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> 'DeviceDtypeModuleMixin': """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will @@ -127,11 +127,16 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module: Returns: Module: self """ - property_device = device if isinstance(device, torch.device) else torch.device('cuda', index=device) + if isinstance(device, torch.device): + property_device = device + elif isinstance(device, int): + property_device = torch.device('cuda', index=device) + else: + property_device = torch.device('cuda') # use current cuda device if device == None [mypy compatibility] self.__update_properties(device=property_device) return super().cuda(device=device) - def cpu(self) -> Module: + def cpu(self) -> 'DeviceDtypeModuleMixin': """Moves all model parameters and buffers to the CPU. Returns: @@ -140,7 +145,7 @@ def cpu(self) -> Module: self.__update_properties(device=torch.device('cpu')) return super().cpu() - def type(self, dst_type: Union[str, torch.dtype]) -> Module: + def type(self, dst_type: Union[str, torch.dtype]) -> 'DeviceDtypeModuleMixin': """Casts all parameters and buffers to :attr:`dst_type`. Arguments: @@ -152,7 +157,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Module: self.__update_properties(dtype=dst_type) return super().type(dst_type=dst_type) - def float(self) -> Module: + def float(self) -> 'DeviceDtypeModuleMixin': """Casts all floating point parameters and buffers to float datatype. Returns: @@ -161,7 +166,7 @@ def float(self) -> Module: self.__update_properties(dtype=torch.float) return super().float() - def double(self) -> Module: + def double(self) -> 'DeviceDtypeModuleMixin': """Casts all floating point parameters and buffers to ``double`` datatype. 
Returns: @@ -170,7 +175,7 @@ def double(self) -> Module: self.__update_properties(dtype=torch.double) return super().double() - def half(self) -> Module: + def half(self) -> 'DeviceDtypeModuleMixin': """Casts all floating point parameters and buffers to ``half`` datatype. Returns: @@ -179,7 +184,7 @@ def half(self) -> Module: self.__update_properties(dtype=torch.half) return super().half() - def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: + def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None) -> None: def apply_fn(module): if not isinstance(module, DeviceDtypeModuleMixin): diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index 64a781e7b9bac..db60216cd5928 100644 --- a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import operator -from typing import Any, List, MutableSequence, Optional, Set, Tuple, Union +from typing import Any, Iterable, List, MutableSequence, Optional, Set, Tuple, Union import torch @@ -187,15 +187,20 @@ def _check_data_type(device_ids: Any) -> None: raise MisconfigurationException("Device ID's (GPU/TPU) must be int, string or sequence of ints or None.") -def _tpu_cores_valid(tpu_cores: Optional[Union[int, List[int], Tuple[int], Set[int]]]) -> bool: +def _tpu_cores_valid(tpu_cores: Optional[Union[int, Iterable[int], List[int], Tuple[int], Set[int]]] = None) -> bool: # allow 1 or 8 cores if tpu_cores in (1, 8, None): return True + + # First condition is necessary for mypy compatibility; + # list_tpu_cores is required to declare for mypy compatiblity too + if (tpu_cores is not None) and (not isinstance(tpu_cores, int)): + list_tpu_cores: List[int] = list(tpu_cores) # allow picking 1 of 8 indexes - if isinstance(tpu_cores, (List, Tuple, Set)): - has_1_tpu_idx = len(tpu_cores) == 1 - is_valid_tpu_idx = tpu_cores[0] in range(1, 9) + if isinstance(list_tpu_cores, List): + has_1_tpu_idx = len(list_tpu_cores) == 1 + is_valid_tpu_idx = list_tpu_cores[0] in range(1, 9) is_valid_tpu_core_choice = has_1_tpu_idx and is_valid_tpu_idx return is_valid_tpu_core_choice @@ -205,7 +210,7 @@ def _tpu_cores_valid(tpu_cores: Optional[Union[int, List[int], Tuple[int], Set[i def _parse_tpu_cores_str(tpu_cores: str) -> Union[int, List[int]]: if tpu_cores in ('1', '8'): - tpu_cores = int(tpu_cores) + int_tpu_cores: Union[int, List[int]] = int(tpu_cores) else: - tpu_cores = [int(x.strip()) for x in tpu_cores.split(',') if len(x) > 0] - return tpu_cores + int_tpu_cores = [int(x.strip()) for x in tpu_cores.split(',') if len(x) > 0] + return int_tpu_cores diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 1e7a89dcc27c9..45c8e760d4d6e 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -16,7 +16,7 @@ import os import warnings from functools import partial, wraps -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel @@ -41,12 +41,13 @@ class group: log = logging.getLogger(__name__) -def rank_zero_only(fn: Callable) -> Optional[Callable]: +def rank_zero_only(fn: Callable) -> 
Callable: @wraps(fn) def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: if rank_zero_only.rank == 0: return fn(*args, **kwargs) + return None return wrapped_fn @@ -174,23 +175,23 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward(ctx: object, tensor: torch.Tensor, group=group.WORLD) -> torch.Tensor: + def forward(ctx: torch.autograd.Function, tensor: torch.Tensor, group=group.WORLD) -> torch.Tensor: ctx.group = group gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather(gathered_tensor, tensor, group=group) - gathered_tensor = torch.stack(gathered_tensor, dim=0) + stacked_tensor = torch.stack(gathered_tensor, dim=0) - return gathered_tensor + return stacked_tensor @staticmethod - def backward(ctx: object, *grad_output: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, None]: - grad_output = torch.cat(grad_output) + def backward(ctx: torch.autograd.Function, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: + grad_output_cat = torch.cat(grad_output) - torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) + torch.distributed.all_reduce(grad_output_cat, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) - return grad_output[torch.distributed.get_rank()], None + return grad_output_cat[torch.distributed.get_rank()], None def all_gather_ddp_if_available( @@ -307,12 +308,14 @@ def register_ddp_comm_hook( f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})." ) ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook) - - rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.") - model.register_comm_hook( - state=ddp_comm_state, - hook=ddp_comm_hook, - ) + + # If condition required for mypy compatibility + if ddp_comm_hook is not None: + rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.") + model.register_comm_hook( + state=ddp_comm_state, + hook=ddp_comm_hook, + ) def tpu_distributed() -> bool: From 56301122bb9367481e73edad8fe08c4f8871a785 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 May 2021 13:59:38 +0000 Subject: [PATCH 38/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/device_dtype_mixin.py | 4 +++- pytorch_lightning/utilities/device_parser.py | 2 +- pytorch_lightning/utilities/distributed.py | 6 ++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index 4f285d865c14e..bfecd7905a00c 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -184,7 +184,9 @@ def half(self) -> 'DeviceDtypeModuleMixin': self.__update_properties(dtype=torch.half) return super().half() - def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None) -> None: + def __update_properties( + self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None + ) -> None: def apply_fn(module): if not isinstance(module, DeviceDtypeModuleMixin): diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index db60216cd5928..7111e1be7c288 100644 --- 
a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -191,7 +191,7 @@ def _tpu_cores_valid(tpu_cores: Optional[Union[int, Iterable[int], List[int], Tu # allow 1 or 8 cores if tpu_cores in (1, 8, None): return True - + # First condition is necessary for mypy compatibility; # list_tpu_cores is required to declare for mypy compatiblity too if (tpu_cores is not None) and (not isinstance(tpu_cores, int)): diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 45c8e760d4d6e..28723bbbfa291 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -189,7 +189,9 @@ def forward(ctx: torch.autograd.Function, tensor: torch.Tensor, group=group.WORL def backward(ctx: torch.autograd.Function, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: grad_output_cat = torch.cat(grad_output) - torch.distributed.all_reduce(grad_output_cat, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) + torch.distributed.all_reduce( + grad_output_cat, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group + ) return grad_output_cat[torch.distributed.get_rank()], None @@ -308,7 +310,7 @@ def register_ddp_comm_hook( f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})." ) ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook) - + # If condition required for mypy compatibility if ddp_comm_hook is not None: rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.") From 00a145b0fa7b5d0defe89e337d838d639dfb4673 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 16:52:36 +0200 Subject: [PATCH 39/51] [WIP] Fix mypy compatibility for parsing.py * TODO: Resolve one missing return and non-FrameType in save_hyperparameters function. --- pytorch_lightning/utilities/parsing.py | 51 +++++++++++++++----------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index f11141d801284..747c22bae7fc2 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import copy +from enum import Enum import inspect import pickle +from pytorch_lightning.core.lightning import LightningModule import types from argparse import Namespace -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from types import FrameType +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_warn @@ -51,10 +54,10 @@ def str_to_bool(val: str) -> bool: >>> str_to_bool('FALSE') False """ - val = str_to_bool_or_str(val) - if isinstance(val, bool): - return val - raise ValueError(f'invalid truth value {val}') + val_converted = str_to_bool_or_str(val) + if isinstance(val_converted, bool): + return val_converted + raise ValueError(f'invalid truth value {val_converted}') def str_to_bool_or_int(val: str) -> Union[bool, int, str]: @@ -69,13 +72,13 @@ def str_to_bool_or_int(val: str) -> Union[bool, int, str]: >>> str_to_bool_or_int("abc") 'abc' """ - val = str_to_bool_or_str(val) - if isinstance(val, bool): - return val + val_converted = str_to_bool_or_str(val) + if isinstance(val_converted, bool): + return val_converted try: - return int(val) + return int(val_converted) except ValueError: - return val + return val_converted def is_picklable(obj: object) -> bool: @@ -118,7 +121,10 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]: # self is always first n_self = init_params[0].name - def _get_first_if_any(params: Sequence[Namespace], param_type: Type) -> str: + def _get_first_if_any( + params: List[inspect.Parameter], + param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD], + ) -> str: for p in params: if p.kind == param_type: return p.name @@ -129,7 +135,7 @@ def _get_first_if_any(params: Sequence[Namespace], param_type: Type) -> str: return n_self, n_args, n_kwargs -def get_init_args(frame: object) -> Dict[str, Any]: +def get_init_args(frame: FrameType) -> Dict[str, Any]: _, _, _, local_vars = inspect.getargvalues(frame) if '__class__' not in local_vars: return {} @@ -145,7 +151,7 @@ def get_init_args(frame: object) -> Dict[str, Any]: return local_args -def collect_init_args(frame: object, path_args: List[Dict[str, Any]], inside: bool = False) -> List[Dict[str, Any]]: +def collect_init_args(frame: FrameType, path_args: List[Dict[str, Any]], inside: bool = False) -> List[Dict[str, Any]]: """ Recursively collects the arguments passed to the child constructors in the inheritance tree. @@ -160,13 +166,16 @@ def collect_init_args(frame: object, path_args: List[Dict[str, Any]], inside: bo most specific class in the hierarchy. 
""" _, _, _, local_vars = inspect.getargvalues(frame) - if '__class__' in local_vars: - local_args = get_init_args(frame) - # recursive update - path_args.append(local_args) - return collect_init_args(frame.f_back, path_args, inside=True) - elif not inside: - return collect_init_args(frame.f_back, path_args, inside) + if isinstance(frame.f_back, FrameType): + if '__class__' in local_vars: + local_args = get_init_args(frame) + # recursive update + path_args.append(local_args) + return collect_init_args(frame.f_back, path_args, inside=True) + elif not inside: + return collect_init_args(frame.f_back, path_args, inside) + else: + return path_args else: return path_args @@ -272,7 +281,7 @@ def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) """ trainer = getattr(model, 'trainer', None) - holders = [] + holders: List[Union[Dict[Any, Any], Namespace, LightningModule]] = [] # Check if attribute in model if hasattr(model, attribute): From 87be8a08e9da2d1faf216f6993c312db9f211916 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 May 2021 14:55:27 +0000 Subject: [PATCH 40/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/parsing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 747c22bae7fc2..2a8153f42b5ec 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -from enum import Enum import inspect import pickle -from pytorch_lightning.core.lightning import LightningModule import types from argparse import Namespace +from enum import Enum from types import FrameType from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union import pytorch_lightning as pl +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn From 8eda497d121d64191883ae270d34a193e0777a83 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 17:23:40 +0200 Subject: [PATCH 41/51] Import Literal from typing_extensions to support python version < 3.8 --- pytorch_lightning/utilities/parsing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 2a8153f42b5ec..733204b64dc72 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -18,7 +18,8 @@ from argparse import Namespace from enum import Enum from types import FrameType -from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing_extensions import Literal import pytorch_lightning as pl from pytorch_lightning.core.lightning import LightningModule From d198abf07b9c4c5ad9abca83bf6a9ae859d4c5a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 May 2021 15:24:37 +0000 Subject: [PATCH 42/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/parsing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/parsing.py 
b/pytorch_lightning/utilities/parsing.py index 733204b64dc72..336f3b433b611 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -19,6 +19,7 @@ from enum import Enum from types import FrameType from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + from typing_extensions import Literal import pytorch_lightning as pl From f1e0e7edd9262a7fb0d283d72d9c6a919571dc59 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 17:47:07 +0200 Subject: [PATCH 43/51] Remove unusued import and circular import --- pytorch_lightning/utilities/apply_func.py | 2 +- pytorch_lightning/utilities/cli.py | 1 - pytorch_lightning/utilities/model_helpers.py | 1 - pytorch_lightning/utilities/parsing.py | 16 +++++++--------- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index dbdfa074627fc..29d0bdf935466 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -16,7 +16,7 @@ from collections.abc import Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import torch diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 4a8f62795a977..94b969d713f3a 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -15,7 +15,6 @@ from argparse import Namespace from typing import Any, Dict, Optional, Type, Union -import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 88f0cc647c3b4..87bd9e6c4545d 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -14,7 +14,6 @@ from typing import Union -import pytorch_lightning as pl from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 336f3b433b611..fc7b24a6b286b 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -16,14 +16,12 @@ import pickle import types from argparse import Namespace -from enum import Enum from types import FrameType from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing_extensions import Literal import pytorch_lightning as pl -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn @@ -201,7 +199,7 @@ def save_hyperparameters( ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None, ) -> None: - """See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`""" + """See :meth:`~pytorch_lightning.'pl.core.lightning.LightningModule'.save_hyperparameters`""" if len(args) == 1 and not isinstance(args, str) and not args[0]: # args[0] is an empty container @@ -276,14 +274,14 @@ def __repr__(self) -> str: return out -def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> Any: +def _lightning_get_all_attr_holders(model: 
'pl.core.lightning.LightningModule', attribute: str) -> Any: """ Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ trainer = getattr(model, 'trainer', None) - holders: List[Union[Dict[Any, Any], Namespace, LightningModule]] = [] + holders: List[Union[Dict[Any, Any], Namespace, 'pl.core.lightning.LightningModule']] = [] # Check if attribute in model if hasattr(model, attribute): @@ -301,7 +299,7 @@ def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) return holders -def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]: +def _lightning_get_first_attr_holder(model: 'pl.core.lightning.LightningModule', attribute: str) -> Optional[Any]: """ Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, @@ -314,7 +312,7 @@ def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str return holders[-1] -def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool: +def lightning_hasattr(model: 'pl.core.lightning.LightningModule', attribute: str) -> bool: """ Special hasattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -322,7 +320,7 @@ def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool: return _lightning_get_first_attr_holder(model, attribute) is not None -def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Any: +def lightning_getattr(model: 'pl.core.lightning.LightningModule', attribute: str) -> Any: """ Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -344,7 +342,7 @@ def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Any: return getattr(holder, attribute) -def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None: +def lightning_setattr(model: 'pl.core.lightning.LightningModule', attribute: str, value: Any) -> None: """ Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. 
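[Editor's note] Patches 41-43 above lean on two related tricks: importing Literal from typing_extensions so the annotations keep working on Python < 3.8, and writing Lightning classes as quoted annotations (via the pytorch_lightning-as-pl alias) so the utilities no longer need a runtime import of the core modules and the circular import goes away. A minimal, self-contained sketch of that forward-reference idea follows; it uses a TYPE_CHECKING guard rather than the pl alias, and the module and class names are invented, so treat it as a pattern sketch rather than Lightning code:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers, never at runtime,
        # so importing the "heavy" module here cannot create an import cycle.
        from mypackage.core import Model

    def describe(model: "Model") -> str:
        # The quoted annotation is resolved lazily by the type checker.
        return type(model).__name__
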
From 94f7bf74daf6094765e5170549de72ca91abced0 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 15 May 2021 20:36:38 +0200 Subject: [PATCH 44/51] Fix another bunch of mypy issues --- pytorch_lightning/utilities/argparse.py | 14 +++++++------- pytorch_lightning/utilities/device_dtype_mixin.py | 4 ++-- pytorch_lightning/utilities/distributed.py | 10 +++++++--- pytorch_lightning/utilities/memory.py | 6 ++++-- pytorch_lightning/utilities/metrics.py | 6 +++--- pytorch_lightning/utilities/parsing.py | 4 ++-- pytorch_lightning/utilities/warnings.py | 4 ++-- 7 files changed, 27 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index dcd940ee346fb..b4dd145b9da7b 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -15,13 +15,13 @@ import os from argparse import _ArgumentGroup, ArgumentParser, Namespace from contextlib import suppress -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Type, Union import pytorch_lightning as pl from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str -def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs: Any) -> 'pl.Trainer': +def from_argparse_args(cls: Type['pl.Trainer'], args: Union[Namespace, ArgumentParser], **kwargs: Any) -> 'pl.Trainer': """Create an instance from CLI arguments. Eventually use varibles from OS environement which are defined as "PL__" @@ -53,7 +53,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs: An return cls(**trainer_kwargs) -def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: +def parse_argparser(cls: Type['pl.Trainer'], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: """Parse CLI arguments, required for custom bool types.""" args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser @@ -78,7 +78,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp return Namespace(**modified_args) -def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: +def parse_env_variables(cls: Type['pl.Trainer'], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: """Parse environment arguments if they are defined. Example: @@ -107,7 +107,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") return Namespace(**env_args) -def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: +def get_init_arguments_and_types(cls: Type['pl.Trainer']) -> List[Tuple[str, Tuple, Any]]: r"""Scans the class signature and returns argument names, types and default values. Returns: @@ -135,7 +135,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: return name_type_default -def _get_abbrev_qualified_cls_name(cls) -> str: +def _get_abbrev_qualified_cls_name(cls: Type['pl.Trainer']) -> str: assert isinstance(cls, type), repr(cls) if cls.__module__.startswith("pytorch_lightning."): # Abbreviate. 
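[Editor's note] The cls parameters annotated as Type['pl.Trainer'] in this commit belong to module-level helpers that receive a class object rather than an instance; Type[X] is the standard way to say "X itself, or a subclass of it". A small self-contained sketch of that distinction (the classes and helper below are made up for illustration and are not Lightning APIs):

    from typing import Type

    class Base:
        pass

    class Child(Base):
        pass

    def qualified_name(cls: Type[Base]) -> str:
        # Accepts the class object itself (Base or any subclass), not an instance.
        return f"{cls.__module__}.{cls.__qualname__}"

    print(qualified_name(Child))  # -> "__main__.Child" when run as a script
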
@@ -146,7 +146,7 @@ def _get_abbrev_qualified_cls_name(cls) -> str: def add_argparse_args( - cls, + cls: Type['pl.Trainer'], parent_parser: ArgumentParser, *, use_argument_group: bool = True, diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index bfecd7905a00c..6a357a27622dd 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -25,7 +25,7 @@ class DeviceDtypeModuleMixin(Module): def __init__(self) -> None: super().__init__() - self._dtype = torch.get_default_dtype() + self._dtype: Union[str, torch.dtype] = torch.get_default_dtype() self._device = torch.device('cpu') @property @@ -188,7 +188,7 @@ def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: - def apply_fn(module): + def apply_fn(module: Module) -> None: if not isinstance(module, DeviceDtypeModuleMixin): return if device is not None: diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 28723bbbfa291..cd69ab6732c5e 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -16,7 +16,7 @@ import os import warnings from functools import partial, wraps -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel @@ -175,7 +175,11 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward(ctx: torch.autograd.Function, tensor: torch.Tensor, group=group.WORLD) -> torch.Tensor: + def forward( + ctx: torch.autograd.Function, + tensor: torch.Tensor, + group: Optional['torch.distributed.ProcessGroup'] = group.WORLD, + ) -> torch.Tensor: ctx.group = group gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] @@ -197,7 +201,7 @@ def backward(ctx: torch.autograd.Function, *grad_output: torch.Tensor) -> Tuple[ def all_gather_ddp_if_available( - tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False + tensor: torch.Tensor, group: Optional[Type['torch.distributed.ProcessGroup']] = None, sync_grads: bool = False ) -> torch.Tensor: """ Function to gather a tensor from several distributed processes diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 1a0dbcd04dff3..7bbac2d443618 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -17,10 +17,12 @@ import torch -RECURSIVE_DICT_WITH_TENSORS = Dict[Any, Union[torch.Tensor, 'RECURSIVE_DICT_WITH_TENSORS']] +RECURSIVE_DICT_WITH_TENSORS = Union[Dict[str, torch.Tensor], Dict[Any, Any]] -def recursive_detach(in_dict: RECURSIVE_DICT_WITH_TENSORS, to_cpu: bool = False) -> Dict[str, torch.Tensor]: +def recursive_detach( + in_dict: RECURSIVE_DICT_WITH_TENSORS, to_cpu: bool = False +) -> Dict[str, Union[Any, Dict[str, torch.Tensor], torch.Tensor]]: """Detach all tensors in `in_dict`. May operate recursively if some of the values in `in_dict` are dictionaries diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py index 1fd97eea0fc85..aa4a60e08bde4 100644 --- a/pytorch_lightning/utilities/metrics.py +++ b/pytorch_lightning/utilities/metrics.py @@ -13,16 +13,16 @@ # limitations under the License. """Helper functions to operate on metric values. 
""" -from typing import Dict, Union +from typing import Any, Dict, Union import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -RECURSIVE_DICT_WITH_TENSORS = Dict[str, Union[torch.Tensor, 'RECURSIVE_DICT_WITH_TENSORS']] +RECURSIVE_DICT_WITH_TENSORS = Union[Dict[str, torch.Tensor], Dict[Any, Any]] -def metrics_to_scalars(metrics: RECURSIVE_DICT_WITH_TENSORS) -> Dict[str, float]: +def metrics_to_scalars(metrics: RECURSIVE_DICT_WITH_TENSORS) -> Dict[str, Union[Any, Dict[str, float], float]]: """ Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """ # TODO: this is duplicated in MetricsHolder. should be unified diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index fc7b24a6b286b..38d082d378932 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -17,7 +17,7 @@ import types from argparse import Namespace from types import FrameType -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union from typing_extensions import Literal @@ -105,7 +105,7 @@ def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None: del hparams_dict[k] -def parse_class_init_keys(cls) -> Tuple[str, str, str]: +def parse_class_init_keys(cls: Type['pl.LightningModule']) -> Tuple[str, str, str]: """Parse key words for standard self, *args and **kwargs >>> class Model(): diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index 6dfb67133cf19..08b4d9d962f43 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any +from typing import Any, Set from pytorch_lightning.utilities.distributed import rank_zero_warn @@ -19,7 +19,7 @@ class WarningCache: def __init__(self) -> None: - self.warnings = set() + self.warnings: Set[Any] = set() def warn(self, m: Any, *args: Any, **kwargs: Any) -> None: if m not in self.warnings: From 3117438449c792e90a5641e7aafe366bf197059d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 May 2021 18:38:11 +0000 Subject: [PATCH 45/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/distributed.py | 4 +++- pytorch_lightning/utilities/memory.py | 5 ++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index cd69ab6732c5e..8285759eaf473 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -201,7 +201,9 @@ def backward(ctx: torch.autograd.Function, *grad_output: torch.Tensor) -> Tuple[ def all_gather_ddp_if_available( - tensor: torch.Tensor, group: Optional[Type['torch.distributed.ProcessGroup']] = None, sync_grads: bool = False + tensor: torch.Tensor, + group: Optional[Type['torch.distributed.ProcessGroup']] = None, + sync_grads: bool = False ) -> torch.Tensor: """ Function to gather a tensor from several distributed processes diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 7bbac2d443618..bfed10fe871ee 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -20,9 +20,8 @@ RECURSIVE_DICT_WITH_TENSORS = Union[Dict[str, torch.Tensor], Dict[Any, Any]] -def recursive_detach( - in_dict: RECURSIVE_DICT_WITH_TENSORS, to_cpu: bool = False -) -> Dict[str, Union[Any, Dict[str, torch.Tensor], torch.Tensor]]: +def recursive_detach(in_dict: RECURSIVE_DICT_WITH_TENSORS, + to_cpu: bool = False) -> Dict[str, Union[Any, Dict[str, torch.Tensor], torch.Tensor]]: """Detach all tensors in `in_dict`. 
May operate recursively if some of the values in `in_dict` are dictionaries From 1888663f35a0d3a72b714b8500edcbbdc65184c5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 17 May 2021 09:10:17 +0100 Subject: [PATCH 46/51] typo --- docs/source/governance.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/governance.rst b/docs/source/governance.rst index 1f59491781468..60476822cd59c 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -37,7 +37,4 @@ Alumni - Jeff Ling (`jeffling `_) - Teddy Koker (`teddykoker `_) - Nate Raw (`nateraw `_) -<<<<<<< HEAD -======= - Peter Yu (`yukw777 `_) ->>>>>>> upstream/master From 4ae0c144700f429a0c8bf817c1c27e195311ff63 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Tue, 18 May 2021 17:18:19 +0200 Subject: [PATCH 47/51] Fix pl.LightningModule type in parsing.py --- pytorch_lightning/utilities/parsing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 38d082d378932..886a5078c60bf 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -199,7 +199,7 @@ def save_hyperparameters( ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None, ) -> None: - """See :meth:`~pytorch_lightning.'pl.core.lightning.LightningModule'.save_hyperparameters`""" + """See :meth:`~pytorch_lightning.core.lightning.LightningModule'.save_hyperparameters`""" if len(args) == 1 and not isinstance(args, str) and not args[0]: # args[0] is an empty container @@ -274,14 +274,14 @@ def __repr__(self) -> str: return out -def _lightning_get_all_attr_holders(model: 'pl.core.lightning.LightningModule', attribute: str) -> Any: +def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> Any: """ Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ trainer = getattr(model, 'trainer', None) - holders: List[Union[Dict[Any, Any], Namespace, 'pl.core.lightning.LightningModule']] = [] + holders: List[Union[Dict[Any, Any], Namespace, 'pl.LightningModule']] = [] # Check if attribute in model if hasattr(model, attribute): @@ -299,7 +299,7 @@ def _lightning_get_all_attr_holders(model: 'pl.core.lightning.LightningModule', return holders -def _lightning_get_first_attr_holder(model: 'pl.core.lightning.LightningModule', attribute: str) -> Optional[Any]: +def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]: """ Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, @@ -312,7 +312,7 @@ def _lightning_get_first_attr_holder(model: 'pl.core.lightning.LightningModule', return holders[-1] -def lightning_hasattr(model: 'pl.core.lightning.LightningModule', attribute: str) -> bool: +def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool: """ Special hasattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. 
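On the `'pl.LightningModule'` annotations introduced in this patch: they are string forward references, so they only resolve if `pl` is visible to the type checker in that module. A minimal sketch of the usual pattern, assuming a `TYPE_CHECKING`-guarded import (the actual import arrangement in parsing.py is not shown in this excerpt):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # imported for static analysis only; avoids a circular import at runtime
        import pytorch_lightning as pl

    def lightning_hasattr_sketch(model: 'pl.LightningModule', attribute: str) -> bool:
        # hypothetical stand-in mirroring the signatures annotated above
        return hasattr(model, attribute)
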
@@ -320,7 +320,7 @@ def lightning_hasattr(model: 'pl.core.lightning.LightningModule', attribute: str return _lightning_get_first_attr_holder(model, attribute) is not None -def lightning_getattr(model: 'pl.core.lightning.LightningModule', attribute: str) -> Any: +def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Any: """ Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -342,7 +342,7 @@ def lightning_getattr(model: 'pl.core.lightning.LightningModule', attribute: str return getattr(holder, attribute) -def lightning_setattr(model: 'pl.core.lightning.LightningModule', attribute: str, value: Any) -> None: +def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None: """ Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. From 9a71d85dcd6370d9253f197a2e3738b0085e28cf Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 21 May 2021 09:03:02 +0200 Subject: [PATCH 48/51] Add back deleted MisconfigurationException (device) --- pytorch_lightning/utilities/apply_func.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 29d0bdf935466..2d24776a41827 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -21,6 +21,7 @@ import numpy as np import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE if _TORCHTEXT_AVAILABLE: @@ -35,12 +36,16 @@ def to_dtype_tensor( value: Union[int, float, List[Union[int, float]]], dtype: Optional[torch.dtype] = None, - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = None, ) -> torch.Tensor: + if device is None: + raise MisconfigurationException("device (torch.device) should be provided.") return torch.tensor(value, dtype=dtype, device=device) -def from_numpy(value: np.ndarray, device: Optional[Union[str, torch.device]] = None) -> torch.Tensor: +def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor: + if device is None: + raise MisconfigurationException("device (torch.device) should be provided.") return torch.from_numpy(value).to(device) @@ -167,7 +172,9 @@ def batch_to(data: Any) -> Any: return apply_to_collection(batch, dtype=dtype, function=batch_to) -def convert_to_tensors(data: Any, device: Optional[Union[str, torch.device]] = None) -> Any: +def convert_to_tensors(data: Any, device: Union[str, torch.device] = None) -> Any: + if device is None: + raise MisconfigurationException("device (torch.device) should be provided.") for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) From b7f9ca703f35a00d92cca1b77e9f8890ae16b359 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Tue, 25 May 2021 15:13:59 +0200 Subject: [PATCH 49/51] [WIP] Tackle some other mypy issues --- pytorch_lightning/utilities/apply_func.py | 2 +- pytorch_lightning/utilities/device_dtype_mixin.py | 2 +- pytorch_lightning/utilities/distributed.py | 4 ++-- pytorch_lightning/utilities/seed.py | 14 +++++++------- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 
2d24776a41827..2df8c7cad83e9 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -63,7 +63,7 @@ def apply_to_collection( dtype: Union[type, Tuple[type]], function: Callable, *args: Any, - wrong_dtype: Optional[Union[torch.dtype, Tuple[torch.dtype]]] = None, + wrong_dtype: Optional[Union[type, Tuple[type]]] = None, **kwargs: Any, ) -> Any: """ diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index 52c78d2728f84..06c18a7bfedc9 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -45,7 +45,7 @@ def device(self) -> Union[str, torch.device]: return device - def to(self, *args: Any, **kwargs: Any) -> Module: + def to(self, *args: Any, **kwargs: Any) -> 'DeviceDtypeModuleMixin': """Moves and/or casts the parameters and buffers. This can be called as diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 8285759eaf473..9997a68515a71 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -31,10 +31,10 @@ else: - class ReduceOp: + class ReduceOp: # type: ignore # (see https://github.com/python/mypy/issues/1153) SUM = None - class group: + class group: # type: ignore WORLD = None diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 78abd06e3b289..05c307769bb7f 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -48,13 +48,13 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: max_seed_value = np.iinfo(np.uint32).max min_seed_value = np.iinfo(np.uint32).min - try: - if seed is None: - seed = os.environ.get("PL_GLOBAL_SEED") - seed = int(seed) - except (TypeError, ValueError): - seed = _select_seed_randomly(min_seed_value, max_seed_value) - rank_zero_warn(f"No correct seed found, seed set to {seed}") + if seed is None: + global_seed = os.environ.get("PL_GLOBAL_SEED") + if isinstance(global_seed, str): + seed = int(global_seed) + else: + seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"No correct seed found, seed set to {seed}") if not (min_seed_value <= seed <= max_seed_value): rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") From d2e06b3c5cefe8813ba0c55bd2c15c51f48023c3 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Thu, 27 May 2021 23:19:30 +0200 Subject: [PATCH 50/51] Fix some issues after a review --- pytorch_lightning/utilities/device_parser.py | 6 +++--- pytorch_lightning/utilities/memory.py | 10 +++++----- pytorch_lightning/utilities/parsing.py | 1 + 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index 7111e1be7c288..93d7d0731f6a9 100644 --- a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
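For context on the device guards restored in the "Add back deleted MisconfigurationException (device)" patch above, a rough usage sketch (illustrative only, not part of the patch series):

    import torch

    from pytorch_lightning.utilities.apply_func import convert_to_tensors
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        convert_to_tensors({"x": 1.0})  # no device supplied
    except MisconfigurationException:
        pass  # raised by the restored None-check
    batch = convert_to_tensors({"x": 1.0}, device=torch.device("cpu"))
    # values are converted to tensors, moved to the device and made contiguous
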
import operator -from typing import Any, Iterable, List, MutableSequence, Optional, Set, Tuple, Union +from typing import Any, Iterable, List, MutableSequence, Optional, Sequence, Set, Tuple, Union import torch @@ -194,11 +194,11 @@ def _tpu_cores_valid(tpu_cores: Optional[Union[int, Iterable[int], List[int], Tu # First condition is necessary for mypy compatibility; # list_tpu_cores is required to declare for mypy compatiblity too - if (tpu_cores is not None) and (not isinstance(tpu_cores, int)): + if tpu_cores is not None and not isinstance(tpu_cores, int): list_tpu_cores: List[int] = list(tpu_cores) # allow picking 1 of 8 indexes - if isinstance(list_tpu_cores, List): + if isinstance(list_tpu_cores, Sequence): has_1_tpu_idx = len(list_tpu_cores) == 1 is_valid_tpu_idx = list_tpu_cores[0] in range(1, 9) diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index bfed10fe871ee..9f0ed0f9b9c3c 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -21,7 +21,7 @@ def recursive_detach(in_dict: RECURSIVE_DICT_WITH_TENSORS, - to_cpu: bool = False) -> Dict[str, Union[Any, Dict[str, torch.Tensor], torch.Tensor]]: + to_cpu: bool = False,) -> Dict[str, Union[Any, Dict[str, torch.Tensor], torch.Tensor]]: """Detach all tensors in `in_dict`. May operate recursively if some of the values in `in_dict` are dictionaries @@ -47,14 +47,14 @@ def recursive_detach(in_dict: RECURSIVE_DICT_WITH_TENSORS, return out_dict -def is_oom_error(exception: Any) -> bool: +def is_oom_error(exception: Exception) -> bool: return is_cuda_out_of_memory(exception) \ or is_cudnn_snafu(exception) \ or is_out_of_cpu_memory(exception) # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py -def is_cuda_out_of_memory(exception: Any) -> bool: +def is_cuda_out_of_memory(exception: Exception) -> bool: return isinstance(exception, RuntimeError) \ and len(exception.args) == 1 \ and "CUDA" in exception.args[0] \ @@ -62,7 +62,7 @@ def is_cuda_out_of_memory(exception: Any) -> bool: # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py -def is_cudnn_snafu(exception: Any) -> bool: +def is_cudnn_snafu(exception: Exception) -> bool: # For/because of https://github.com/pytorch/pytorch/issues/4107 return isinstance(exception, RuntimeError) \ and len(exception.args) == 1 \ @@ -70,7 +70,7 @@ def is_cudnn_snafu(exception: Any) -> bool: # based on https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py -def is_out_of_cpu_memory(exception: Any) -> bool: +def is_out_of_cpu_memory(exception: Exception) -> bool: return isinstance(exception, RuntimeError) \ and len(exception.args) == 1 \ and "DefaultCPUAllocator: can't allocate memory" in exception.args[0] diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 886a5078c60bf..18df53bbb4531 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -166,6 +166,7 @@ def collect_init_args(frame: FrameType, path_args: List[Dict[str, Any]], inside: most specific class in the hierarchy. 
""" _, _, _, local_vars = inspect.getargvalues(frame) + # frame.f_back must be of a type FrameType for get_init_args/collect_init_args due to mypy if isinstance(frame.f_back, FrameType): if '__class__' in local_vars: local_args = get_init_args(frame) From eb902c7ce0da9e3a1c1f5fb841f04e581e81b3da Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 May 2021 21:20:31 +0000 Subject: [PATCH 51/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/memory.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 9f0ed0f9b9c3c..65858fe8a6d1d 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -20,8 +20,10 @@ RECURSIVE_DICT_WITH_TENSORS = Union[Dict[str, torch.Tensor], Dict[Any, Any]] -def recursive_detach(in_dict: RECURSIVE_DICT_WITH_TENSORS, - to_cpu: bool = False,) -> Dict[str, Union[Any, Dict[str, torch.Tensor], torch.Tensor]]: +def recursive_detach( + in_dict: RECURSIVE_DICT_WITH_TENSORS, + to_cpu: bool = False, +) -> Dict[str, Union[Any, Dict[str, torch.Tensor], torch.Tensor]]: """Detach all tensors in `in_dict`. May operate recursively if some of the values in `in_dict` are dictionaries