From 95a1155b667b716038f0389c44301e9a330ebc0d Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Mon, 2 Aug 2021 08:04:19 +0200
Subject: [PATCH 1/4] Fix a majority of typing for utilities.debugging

---
 pyproject.toml                           |  1 +
 pytorch_lightning/utilities/debugging.py | 71 +++++++++++++++---------
 2 files changed, 45 insertions(+), 27 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index d0d9aaf383ebe..8ccc2157099e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,7 @@ module = [
     "pytorch_lightning.trainer.connectors.logger_connector",
     "pytorch_lightning.utilities.argparse",
     "pytorch_lightning.utilities.cli",
+    "pytorch_lightning.utilities.debugging",
     "pytorch_lightning.utilities.device_dtype_mixin",
     "pytorch_lightning.utilities.device_parser",
     "pytorch_lightning.utilities.parsing",
diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index 3585480d028a9..ae603d5a80a0b 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -16,10 +16,15 @@
 import time
 from collections import Counter
 from functools import wraps
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
+import torch
+from torch.utils.data import DataLoader
 
-def enabled_only(fn: Callable):
+import pytorch_lightning as pl
+
+
+def enabled_only(fn: Callable) -> Optional[Callable]:
     """Decorate a logger method to run it only on the process with rank 0.
 
     Args:
@@ -27,28 +32,29 @@ def enabled_only(fn: Callable):
     """
 
     @wraps(fn)
-    def wrapped_fn(self, *args, **kwargs):
+    def wrapped_fn(self: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
         if self.enabled:
             fn(self, *args, **kwargs)
+        return None
 
     return wrapped_fn
 
 
 class InternalDebugger:
-    def __init__(self, trainer):
+    def __init__(self, trainer: 'pl.Trainer') -> None:
         self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
         self.trainer = trainer
-        self.saved_train_losses = []
-        self.saved_val_losses = []
-        self.saved_test_losses = []
-        self.early_stopping_history = []
-        self.checkpoint_callback_history = []
-        self.events = []
-        self.saved_lr_scheduler_updates = []
-        self.train_dataloader_calls = []
-        self.val_dataloader_calls = []
-        self.test_dataloader_calls = []
-        self.dataloader_sequence_calls = []
+        self.saved_train_losses: List[Dict[str, Union[float, torch.Tensor, object]]] = []
+        self.saved_val_losses: List[Dict[str, Union[float, torch.Tensor, object]]] = []
+        self.saved_test_losses: List[Dict[str, Union[float, torch.Tensor, object]]] = []
+        self.early_stopping_history: List[Dict[str, Union[float, torch.Tensor, object]]] = []
+        self.checkpoint_callback_history: List[Dict[str, Union[float, torch.Tensor, object]]] = []
+        self.events: List[Dict[str, Any]] = []
+        self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
+        self.train_dataloader_calls: List[Dict[str, Union[int, str, object]]] = []
+        self.val_dataloader_calls: List[Dict[str, Union[int, str, object]]] = []
+        self.test_dataloader_calls: List[Dict[str, Union[int, str, object]]] = []
+        self.dataloader_sequence_calls: List[Dict[str, Union[int, str, object]]] = []
 
     @enabled_only
     def track_event(
@@ -71,7 +77,7 @@
         )
 
     @enabled_only
-    def track_load_dataloader_call(self, name, dataloaders):
+    def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -> None:
         loader_counts = len(dataloaders)
 
         lengths = []
@@ -102,14 +108,21 @@ def track_load_dataloader_call(self, name, dataloaders):
             self.test_dataloader_calls.append(values)
 
     @enabled_only
-    def track_train_loss_history(self, batch_idx, loss):
+    def track_train_loss_history(self, batch_idx: int, loss: torch.Tensor) -> None:
         loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach()}
         self.saved_train_losses.append(loss_dict)
 
     @enabled_only
     def track_lr_schedulers_update(
-        self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
-    ):
+        self,
+        batch_idx: int,
+        interval: int,
+        scheduler_idx: int,
+        old_lr: float,
+        new_lr: float,
+        monitor_key: Optional[str] = None,
+        monitor_val: Optional[torch.Tensor] = None,
+    ) -> None:
         loss_dict = {
             "batch_idx": batch_idx,
             "interval": interval,
@@ -123,7 +136,7 @@ def track_lr_schedulers_update(
         self.saved_lr_scheduler_updates.append(loss_dict)
 
     @enabled_only
-    def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
+    def track_eval_loss_history(self, batch_idx: int, dataloader_idx: int, output: torch.Tensor) -> None:
         loss_dict = {
             "sanity_check": self.trainer.sanity_checking,
             "dataloader_idx": dataloader_idx,
@@ -138,7 +151,11 @@ def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
             self.saved_val_losses.append(loss_dict)
 
     @enabled_only
-    def track_early_stopping_history(self, callback, current):
+    def track_early_stopping_history(
+        self,
+        callback: 'pl.callbacks.early_stopping.EarlyStopping',
+        current: torch.Tensor,
+    ) -> None:
         debug_dict = {
             "epoch": self.trainer.current_epoch,
             "global_step": self.trainer.global_step,
@@ -150,7 +167,7 @@ def track_early_stopping_history(self, callback, current):
         self.early_stopping_history.append(debug_dict)
 
     @enabled_only
-    def track_checkpointing_history(self, filepath):
+    def track_checkpointing_history(self, filepath: str) -> None:
         cb = self.trainer.checkpoint_callback
         debug_dict = {
             "epoch": self.trainer.current_epoch,
@@ -162,21 +179,21 @@ def track_checkpointing_history(self, filepath):
         self.checkpoint_callback_history.append(debug_dict)
 
     @property
-    def num_seen_sanity_check_batches(self):
+    def num_seen_sanity_check_batches(self) -> int:
         count = sum(1 for x in self.saved_val_losses if x["sanity_check"])
         return count
 
     @property
-    def num_seen_val_check_batches(self):
-        counts = Counter()
+    def num_seen_val_check_batches(self) -> Counter:
+        counts: Counter = Counter()
         for x in self.saved_val_losses:
             if not x["sanity_check"]:
                 counts.update({x["dataloader_idx"]: 1})
         return counts
 
     @property
-    def num_seen_test_check_batches(self):
-        counts = Counter()
+    def num_seen_test_check_batches(self) -> Counter:
+        counts: Counter = Counter()
         for x in self.saved_test_losses:
             if not x["sanity_check"]:
                 counts.update({x["dataloader_idx"]: 1})

From a1cd666c081f9cac27e2d008baee6c7ba60ada2d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 2 Aug 2021 06:37:33 +0000
Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/utilities/debugging.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index ae603d5a80a0b..7b316e8b8b4c4 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -41,7 +41,7 @@ def wrapped_fn(self: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
 
 
 class InternalDebugger:
-    def __init__(self, trainer: 'pl.Trainer') -> None:
+    def __init__(self, trainer: "pl.Trainer") -> None:
         self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
         self.trainer = trainer
         self.saved_train_losses: List[Dict[str, Union[float, torch.Tensor, object]]] = []
@@ -152,9 +152,7 @@ def track_eval_loss_history(self, batch_idx: int, dataloader_idx: int, output: t
 
     @enabled_only
     def track_early_stopping_history(
-        self,
-        callback: 'pl.callbacks.early_stopping.EarlyStopping',
-        current: torch.Tensor,
+        self, callback: "pl.callbacks.early_stopping.EarlyStopping", current: torch.Tensor
     ) -> None:
         debug_dict = {
             "epoch": self.trainer.current_epoch,

From 59403a0e3956da603190eab71a0c939daf52609a Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Mon, 2 Aug 2021 10:22:54 +0200
Subject: [PATCH 3/4] Replace object annotations with Any

---
 pytorch_lightning/utilities/debugging.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index 7b316e8b8b4c4..2307f255a3bc7 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -44,17 +44,17 @@ class InternalDebugger:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
         self.trainer = trainer
-        self.saved_train_losses: List[Dict[str, Union[float, torch.Tensor, object]]] = []
-        self.saved_val_losses: List[Dict[str, Union[float, torch.Tensor, object]]] = []
-        self.saved_test_losses: List[Dict[str, Union[float, torch.Tensor, object]]] = []
-        self.early_stopping_history: List[Dict[str, Union[float, torch.Tensor, object]]] = []
-        self.checkpoint_callback_history: List[Dict[str, Union[float, torch.Tensor, object]]] = []
+        self.saved_train_losses: List[Dict[str, Any]] = []
+        self.saved_val_losses: List[Dict[str, Any]] = []
+        self.saved_test_losses: List[Dict[str, Any]] = []
+        self.early_stopping_history: List[Dict[str, Any]] = []
+        self.checkpoint_callback_history: List[Dict[str, Any]] = []
         self.events: List[Dict[str, Any]] = []
         self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
-        self.train_dataloader_calls: List[Dict[str, Union[int, str, object]]] = []
-        self.val_dataloader_calls: List[Dict[str, Union[int, str, object]]] = []
-        self.test_dataloader_calls: List[Dict[str, Union[int, str, object]]] = []
-        self.dataloader_sequence_calls: List[Dict[str, Union[int, str, object]]] = []
+        self.train_dataloader_calls: List[Dict[str, Any]] = []
+        self.val_dataloader_calls: List[Dict[str, Any]] = []
+        self.test_dataloader_calls: List[Dict[str, Any]] = []
+        self.dataloader_sequence_calls: List[Dict[str, Any]] = []
 
     @enabled_only
     def track_event(

From 9bf97dc46caa02ad5f1fffe76594997d9a168174 Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Wed, 4 Aug 2021 08:53:20 +0200
Subject: [PATCH 4/4] Fix the last mypy issue

---
 pytorch_lightning/utilities/debugging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index 2307f255a3bc7..ee2b58be106b5 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -170,7 +170,7 @@ def track_checkpointing_history(self, filepath: str) -> None:
         debug_dict = {
             "epoch": self.trainer.current_epoch,
             "global_step": self.trainer.global_step,
-            "monitor": cb.monitor,
+            "monitor": cb.monitor if cb is not None else None,
             "rank": self.trainer.global_rank,
             "filepath": filepath,
         }
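
Note for reviewers: the pattern these commits annotate is an environment-variable-gated method decorator. Below is a minimal, self-contained sketch of that pattern in isolation. It is not part of the patch: MyDebugger and its track_event body are illustrative stand-ins, and it deliberately annotates with self: Any / -> Callable where the patch writes self: Callable / -> Optional[Callable].

    import os
    from functools import wraps
    from typing import Any, Callable, Dict, List


    def enabled_only(fn: Callable) -> Callable:
        """Run the decorated method only when the owning object has debugging enabled."""

        @wraps(fn)
        def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> None:
            # As in the patch, the tracker is called for its side effects
            # and its return value is discarded.
            if self.enabled:
                fn(self, *args, **kwargs)

        return wrapped_fn


    class MyDebugger:
        def __init__(self) -> None:
            # Same opt-in convention as InternalDebugger: a "0"/"1" env var.
            self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
            self.events: List[Dict[str, Any]] = []

        @enabled_only
        def track_event(self, name: str) -> None:
            self.events.append({"name": name})


    if __name__ == "__main__":
        dbg = MyDebugger()
        dbg.track_event("setup")  # no-op unless PL_DEV_DEBUG=1
        print(dbg.events)

The last commit's `cb.monitor if cb is not None else None` is the usual way to satisfy mypy when an attribute such as `trainer.checkpoint_callback` is Optional: narrow the value before accessing its attributes instead of asserting it is set.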