1 change: 1 addition & 0 deletions pyproject.toml
@@ -65,6 +65,7 @@ module = [
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.cloud_io",
"pytorch_lightning.utilities.debugging",
"pytorch_lightning.utilities.device_dtype_mixin",
"pytorch_lightning.utilities.device_parser",
"pytorch_lightning.utilities.parsing",
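For context, this list sits under mypy's per-module overrides in `pyproject.toml`; adding `"pytorch_lightning.utilities.debugging"` appears to opt the file into the same stricter settings already enforced for the neighbouring utilities, which is what the annotations in the second file are there to satisfy. A minimal sketch of reproducing the check locally, assuming mypy is installed and the command runs from the repository root (`mypy.api.run` is mypy's real programmatic entry point):

```python
# Sketch: type-check just this file the way CI would, using mypy's
# programmatic API. Assumes the working directory is the repo root so the
# pyproject.toml overrides are picked up.
from mypy import api

report, errors, exit_status = api.run(["pytorch_lightning/utilities/debugging.py"])
print(report or errors)
assert exit_status == 0, "debugging.py should pass the stricter per-module settings"
```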
71 changes: 43 additions & 28 deletions pytorch_lightning/utilities/debugging.py
@@ -16,39 +16,45 @@
import time
from collections import Counter
from functools import wraps
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.utils.data import DataLoader

def enabled_only(fn: Callable):
import pytorch_lightning as pl


def enabled_only(fn: Callable) -> Optional[Callable]:
"""Decorate a logger method to run it only on the process with rank 0.

Args:
fn: Function to decorate
"""

@wraps(fn)
def wrapped_fn(self, *args, **kwargs):
def wrapped_fn(self: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
if self.enabled:
fn(self, *args, **kwargs)
return None

return wrapped_fn
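Every tracker below is wrapped in this decorator, so its effect is worth seeing in isolation: the decorated method body runs only when the instance's `enabled` flag is true, and the wrapper always returns `None`. A standalone sketch (the `_Demo` class is hypothetical, not part of Lightning):

```python
import os
from functools import wraps
from typing import Any, Callable, List, Optional


def enabled_only(fn: Callable) -> Optional[Callable]:
    @wraps(fn)
    def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Optional[Any]:
        if self.enabled:  # gate on the instance flag, set from PL_DEV_DEBUG
            fn(self, *args, **kwargs)
        return None

    return wrapped_fn


class _Demo:
    def __init__(self) -> None:
        self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
        self.events: List[str] = []

    @enabled_only
    def track(self, name: str) -> None:
        self.events.append(name)


demo = _Demo()
demo.track("fit_start")  # appended only if PL_DEV_DEBUG=1 was exported first
```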


class InternalDebugger:
def __init__(self, trainer):
def __init__(self, trainer: "pl.Trainer") -> None:
self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
self.trainer = trainer
self.saved_train_losses = []
self.saved_val_losses = []
self.saved_test_losses = []
self.early_stopping_history = []
self.checkpoint_callback_history = []
self.events = []
self.saved_lr_scheduler_updates = []
self.train_dataloader_calls = []
self.val_dataloader_calls = []
self.test_dataloader_calls = []
self.dataloader_sequence_calls = []
self.saved_train_losses: List[Dict[str, Any]] = []
self.saved_val_losses: List[Dict[str, Any]] = []
self.saved_test_losses: List[Dict[str, Any]] = []
self.early_stopping_history: List[Dict[str, Any]] = []
self.checkpoint_callback_history: List[Dict[str, Any]] = []
self.events: List[Dict[str, Any]] = []
self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
self.train_dataloader_calls: List[Dict[str, Any]] = []
self.val_dataloader_calls: List[Dict[str, Any]] = []
self.test_dataloader_calls: List[Dict[str, Any]] = []
self.dataloader_sequence_calls: List[Dict[str, Any]] = []
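Annotating these containers as `List[Dict[str, Any]]` (rather than leaving mypy to demand a type for the bare `[]`) means appends of the wrong shape are now flagged statically. A toy illustration; it runs fine at runtime, and the last line is the kind of mistake mypy would now report:

```python
from typing import Any, Dict, List

saved_train_losses: List[Dict[str, Any]] = []
saved_train_losses.append({"batch_idx": 0, "loss": 0.5})  # accepted by mypy
saved_train_losses.append(0.5)  # mypy: incompatible type "float"; expected "Dict[str, Any]"
```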

@enabled_only
def track_event(
@@ -71,7 +77,7 @@ def track_event(
)

@enabled_only
def track_load_dataloader_call(self, name, dataloaders):
def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -> None:
loader_counts = len(dataloaders)

lengths = []
@@ -102,14 +108,21 @@ def track_load_dataloader_call(self, name, dataloaders):
self.test_dataloader_calls.append(values)

@enabled_only
def track_train_loss_history(self, batch_idx, loss):
def track_train_loss_history(self, batch_idx: int, loss: torch.Tensor) -> None:
loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach()}
self.saved_train_losses.append(loss_dict)
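A hedged end-to-end sketch of how this tracker records entries, assuming a Lightning version from the 1.x series that still ships `InternalDebugger` and exposes it as `trainer.dev_debugger` (the attribute name used by the trainer and its tests in that era); the environment variable must be set before the `Trainer` is constructed:

```python
import os

os.environ["PL_DEV_DEBUG"] = "1"  # read once, in InternalDebugger.__init__

import torch
from pytorch_lightning import Trainer

trainer = Trainer()
trainer.dev_debugger.track_train_loss_history(batch_idx=0, loss=torch.tensor(0.42))
print(trainer.dev_debugger.saved_train_losses)
# roughly: [{'batch_idx': 0, 'epoch': 0, 'loss': tensor(0.4200)}]
```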

@enabled_only
def track_lr_schedulers_update(
self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
):
self,
batch_idx: int,
interval: int,
scheduler_idx: int,
old_lr: float,
new_lr: float,
monitor_key: Optional[str] = None,
monitor_val: Optional[torch.Tensor] = None,
) -> None:
loss_dict = {
"batch_idx": batch_idx,
"interval": interval,
@@ -123,7 +136,7 @@ def track_lr_schedulers_update(
self.saved_lr_scheduler_updates.append(loss_dict)
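The record built here mixes several value types, which is why `saved_lr_scheduler_updates` was annotated `List[Dict[str, Union[int, float, str, torch.Tensor, None]]]` in `__init__`. A more precise (hypothetical) alternative would be a `TypedDict` mirroring the stored fields, at the cost of a slightly larger diff:

```python
# Hypothetical refinement, not what the PR does: per-key types instead of one
# wide Union shared by every value. Mirrors the signature above, including its
# `interval: int` annotation.
from typing import Optional, TypedDict  # TypedDict requires Python 3.8+

import torch


class LRUpdate(TypedDict):
    batch_idx: int
    interval: int
    scheduler_idx: int
    old_lr: float
    new_lr: float
    monitor_key: Optional[str]
    monitor_val: Optional[torch.Tensor]
```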

@enabled_only
def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
def track_eval_loss_history(self, batch_idx: int, dataloader_idx: int, output: torch.Tensor) -> None:
loss_dict = {
"sanity_check": self.trainer.sanity_checking,
"dataloader_idx": dataloader_idx,
@@ -138,7 +151,9 @@ def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
self.saved_val_losses.append(loss_dict)

@enabled_only
def track_early_stopping_history(self, callback, current):
def track_early_stopping_history(
self, callback: "pl.callbacks.early_stopping.EarlyStopping", current: torch.Tensor
) -> None:
debug_dict = {
"epoch": self.trainer.current_epoch,
"global_step": self.trainer.global_step,
@@ -150,33 +165,33 @@ def track_early_stopping_history(self, callback, current):
self.early_stopping_history.append(debug_dict)

@enabled_only
def track_checkpointing_history(self, filepath):
def track_checkpointing_history(self, filepath: str) -> None:
cb = self.trainer.checkpoint_callback
debug_dict = {
"epoch": self.trainer.current_epoch,
"global_step": self.trainer.global_step,
"monitor": cb.monitor,
"monitor": cb.monitor if cb is not None else None,
"rank": self.trainer.global_rank,
"filepath": filepath,
}
self.checkpoint_callback_history.append(debug_dict)
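The `cb.monitor if cb is not None else None` guard above is the one substantive change in this hunk: `trainer.checkpoint_callback` can be `None`, and strict mypy rejects attribute access on an `Optional` without narrowing it first. A minimal standalone illustration (the classes are stand-ins, not Lightning's):

```python
from typing import Optional


class Checkpoint:
    monitor: Optional[str] = "val_loss"


def monitor_of(cb: Optional[Checkpoint]) -> Optional[str]:
    # Without the guard, mypy errors with: Item "None" of "Optional[Checkpoint]"
    # has no attribute "monitor".
    return cb.monitor if cb is not None else None


print(monitor_of(Checkpoint()), monitor_of(None))  # val_loss None
```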

@property
def num_seen_sanity_check_batches(self):
def num_seen_sanity_check_batches(self) -> int:
count = sum(1 for x in self.saved_val_losses if x["sanity_check"])
return count

@property
def num_seen_val_check_batches(self):
counts = Counter()
def num_seen_val_check_batches(self) -> Counter:
counts: Counter = Counter()
for x in self.saved_val_losses:
if not x["sanity_check"]:
counts.update({x["dataloader_idx"]: 1})
return counts

@property
def num_seen_test_check_batches(self):
counts = Counter()
def num_seen_test_check_batches(self) -> Counter:
counts: Counter = Counter()
for x in self.saved_test_losses:
if not x["sanity_check"]:
counts.update({x["dataloader_idx"]: 1})
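The two `Counter`-typed properties tally evaluation batches per dataloader index, skipping sanity-check entries. A standalone sketch with fabricated records, for illustration only:

```python
from collections import Counter
from typing import Any, Dict, List

saved_val_losses: List[Dict[str, Any]] = [
    {"sanity_check": True, "dataloader_idx": 0},   # ignored: sanity check
    {"sanity_check": False, "dataloader_idx": 0},
    {"sanity_check": False, "dataloader_idx": 1},
    {"sanity_check": False, "dataloader_idx": 1},
]

counts: Counter = Counter()
for x in saved_val_losses:
    if not x["sanity_check"]:
        counts.update({x["dataloader_idx"]: 1})
print(counts)  # Counter({1: 2, 0: 1})
```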