Commit 69cd927

Fix mypy typing for utilities.debugging (#8672)

Parent: aacd131

File tree

2 files changed: +44 / -28 lines

  pyproject.toml
  pytorch_lightning/utilities/debugging.py


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ module = [
     "pytorch_lightning.utilities.argparse",
     "pytorch_lightning.utilities.cli",
     "pytorch_lightning.utilities.cloud_io",
+    "pytorch_lightning.utilities.debugging",
     "pytorch_lightning.utilities.device_dtype_mixin",
     "pytorch_lightning.utilities.device_parser",
     "pytorch_lightning.utilities.distributed",

pytorch_lightning/utilities/debugging.py

Lines changed: 43 additions & 28 deletions
@@ -16,39 +16,45 @@
 import time
 from collections import Counter
 from functools import wraps
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

+import torch
+from torch.utils.data import DataLoader

-def enabled_only(fn: Callable):
+import pytorch_lightning as pl
+
+
+def enabled_only(fn: Callable) -> Optional[Callable]:
     """Decorate a logger method to run it only on the process with rank 0.

     Args:
         fn: Function to decorate
     """

     @wraps(fn)
-    def wrapped_fn(self, *args, **kwargs):
+    def wrapped_fn(self: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
         if self.enabled:
             fn(self, *args, **kwargs)
+        return None

     return wrapped_fn


 class InternalDebugger:
-    def __init__(self, trainer):
+    def __init__(self, trainer: "pl.Trainer") -> None:
         self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
         self.trainer = trainer
-        self.saved_train_losses = []
-        self.saved_val_losses = []
-        self.saved_test_losses = []
-        self.early_stopping_history = []
-        self.checkpoint_callback_history = []
-        self.events = []
-        self.saved_lr_scheduler_updates = []
-        self.train_dataloader_calls = []
-        self.val_dataloader_calls = []
-        self.test_dataloader_calls = []
-        self.dataloader_sequence_calls = []
+        self.saved_train_losses: List[Dict[str, Any]] = []
+        self.saved_val_losses: List[Dict[str, Any]] = []
+        self.saved_test_losses: List[Dict[str, Any]] = []
+        self.early_stopping_history: List[Dict[str, Any]] = []
+        self.checkpoint_callback_history: List[Dict[str, Any]] = []
+        self.events: List[Dict[str, Any]] = []
+        self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
+        self.train_dataloader_calls: List[Dict[str, Any]] = []
+        self.val_dataloader_calls: List[Dict[str, Any]] = []
+        self.test_dataloader_calls: List[Dict[str, Any]] = []
+        self.dataloader_sequence_calls: List[Dict[str, Any]] = []

     @enabled_only
     def track_event(
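
Most of this hunk types the enabled_only decorator and the env-gated constructor. The same pattern in isolation, as a hedged sketch (class and method names are illustrative, not Lightning API; the commit's self: Callable annotation is replaced with Any here, since the first argument is the instance, not a callable):

    import os
    from functools import wraps
    from typing import Any, Callable, Optional

    def enabled_only(fn: Callable) -> Optional[Callable]:
        """Call the wrapped method only when the owner has debugging enabled."""

        @wraps(fn)
        def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Optional[Any]:
            if self.enabled:
                fn(self, *args, **kwargs)
            return None

        return wrapped_fn

    class Debugger:
        def __init__(self) -> None:
            # Same opt-in gate as InternalDebugger: export PL_DEV_DEBUG=1.
            self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"

        @enabled_only
        def track(self, message: str) -> None:
            print(message)

With PL_DEV_DEBUG unset, every decorated method is a silent no-op; the wrapper discards the wrapped method's return value and returns None, matching the explicit `return None` added in the commit.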

@@ -71,7 +77,7 @@ def track_event(
         )

     @enabled_only
-    def track_load_dataloader_call(self, name, dataloaders):
+    def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -> None:
         loader_counts = len(dataloaders)

         lengths = []

@@ -102,14 +108,21 @@ def track_load_dataloader_call(self, name, dataloaders):
             self.test_dataloader_calls.append(values)

     @enabled_only
-    def track_train_loss_history(self, batch_idx, loss):
+    def track_train_loss_history(self, batch_idx: int, loss: torch.Tensor) -> None:
         loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach()}
         self.saved_train_losses.append(loss_dict)

     @enabled_only
     def track_lr_schedulers_update(
-        self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
-    ):
+        self,
+        batch_idx: int,
+        interval: int,
+        scheduler_idx: int,
+        old_lr: float,
+        new_lr: float,
+        monitor_key: Optional[str] = None,
+        monitor_val: Optional[torch.Tensor] = None,
+    ) -> None:
         loss_dict = {
             "batch_idx": batch_idx,
             "interval": interval,

@@ -123,7 +136,7 @@ def track_lr_schedulers_update(
         self.saved_lr_scheduler_updates.append(loss_dict)

     @enabled_only
-    def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
+    def track_eval_loss_history(self, batch_idx: int, dataloader_idx: int, output: torch.Tensor) -> None:
         loss_dict = {
             "sanity_check": self.trainer.sanity_checking,
             "dataloader_idx": dataloader_idx,

@@ -138,7 +151,9 @@ def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
             self.saved_val_losses.append(loss_dict)

     @enabled_only
-    def track_early_stopping_history(self, callback, current):
+    def track_early_stopping_history(
+        self, callback: "pl.callbacks.early_stopping.EarlyStopping", current: torch.Tensor
+    ) -> None:
         debug_dict = {
             "epoch": self.trainer.current_epoch,
             "global_step": self.trainer.global_step,

@@ -150,33 +165,33 @@ def track_early_stopping_history(self, callback, current):
         self.early_stopping_history.append(debug_dict)

     @enabled_only
-    def track_checkpointing_history(self, filepath):
+    def track_checkpointing_history(self, filepath: str) -> None:
         cb = self.trainer.checkpoint_callback
         debug_dict = {
             "epoch": self.trainer.current_epoch,
             "global_step": self.trainer.global_step,
-            "monitor": cb.monitor,
+            "monitor": cb.monitor if cb is not None else None,
             "rank": self.trainer.global_rank,
             "filepath": filepath,
         }
         self.checkpoint_callback_history.append(debug_dict)

     @property
-    def num_seen_sanity_check_batches(self):
+    def num_seen_sanity_check_batches(self) -> int:
         count = sum(1 for x in self.saved_val_losses if x["sanity_check"])
         return count

     @property
-    def num_seen_val_check_batches(self):
-        counts = Counter()
+    def num_seen_val_check_batches(self) -> Counter:
+        counts: Counter = Counter()
         for x in self.saved_val_losses:
             if not x["sanity_check"]:
                 counts.update({x["dataloader_idx"]: 1})
         return counts

     @property
-    def num_seen_test_check_batches(self):
-        counts = Counter()
+    def num_seen_test_check_batches(self) -> Counter:
+        counts: Counter = Counter()
         for x in self.saved_test_losses:
             if not x["sanity_check"]:
                 counts.update({x["dataloader_idx"]: 1})
