@@ -16,39 +16,45 @@
 import time
 from collections import Counter
 from functools import wraps
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

+import torch
+from torch.utils.data import DataLoader

-def enabled_only(fn: Callable):
+import pytorch_lightning as pl
+
+
+def enabled_only(fn: Callable) -> Optional[Callable]:
     """Decorate a logger method to run it only on the process with rank 0.

     Args:
         fn: Function to decorate
     """

     @wraps(fn)
-    def wrapped_fn(self, *args, **kwargs):
+    def wrapped_fn(self: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
         if self.enabled:
             fn(self, *args, **kwargs)
+        return None

     return wrapped_fn


 class InternalDebugger:
-    def __init__(self, trainer):
+    def __init__(self, trainer: "pl.Trainer") -> None:
         self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
         self.trainer = trainer
-        self.saved_train_losses = []
-        self.saved_val_losses = []
-        self.saved_test_losses = []
-        self.early_stopping_history = []
-        self.checkpoint_callback_history = []
-        self.events = []
-        self.saved_lr_scheduler_updates = []
-        self.train_dataloader_calls = []
-        self.val_dataloader_calls = []
-        self.test_dataloader_calls = []
-        self.dataloader_sequence_calls = []
+        self.saved_train_losses: List[Dict[str, Any]] = []
+        self.saved_val_losses: List[Dict[str, Any]] = []
+        self.saved_test_losses: List[Dict[str, Any]] = []
+        self.early_stopping_history: List[Dict[str, Any]] = []
+        self.checkpoint_callback_history: List[Dict[str, Any]] = []
+        self.events: List[Dict[str, Any]] = []
+        self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
+        self.train_dataloader_calls: List[Dict[str, Any]] = []
+        self.val_dataloader_calls: List[Dict[str, Any]] = []
+        self.test_dataloader_calls: List[Dict[str, Any]] = []
+        self.dataloader_sequence_calls: List[Dict[str, Any]] = []

     @enabled_only
     def track_event(
@@ -71,7 +77,7 @@ def track_event(
         )

     @enabled_only
-    def track_load_dataloader_call(self, name, dataloaders):
+    def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -> None:
         loader_counts = len(dataloaders)

         lengths = []
@@ -102,14 +108,21 @@ def track_load_dataloader_call(self, name, dataloaders):
             self.test_dataloader_calls.append(values)

     @enabled_only
-    def track_train_loss_history(self, batch_idx, loss):
+    def track_train_loss_history(self, batch_idx: int, loss: torch.Tensor) -> None:
         loss_dict = {"batch_idx": batch_idx, "epoch": self.trainer.current_epoch, "loss": loss.detach()}
         self.saved_train_losses.append(loss_dict)

     @enabled_only
     def track_lr_schedulers_update(
-        self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
-    ):
+        self,
+        batch_idx: int,
+        interval: int,
+        scheduler_idx: int,
+        old_lr: float,
+        new_lr: float,
+        monitor_key: Optional[str] = None,
+        monitor_val: Optional[torch.Tensor] = None,
+    ) -> None:
         loss_dict = {
             "batch_idx": batch_idx,
             "interval": interval,
@@ -123,7 +136,7 @@ def track_lr_schedulers_update(
         self.saved_lr_scheduler_updates.append(loss_dict)

     @enabled_only
-    def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
+    def track_eval_loss_history(self, batch_idx: int, dataloader_idx: int, output: torch.Tensor) -> None:
         loss_dict = {
             "sanity_check": self.trainer.sanity_checking,
             "dataloader_idx": dataloader_idx,
@@ -138,7 +151,9 @@ def track_eval_loss_history(self, batch_idx, dataloader_idx, output):
             self.saved_val_losses.append(loss_dict)

     @enabled_only
-    def track_early_stopping_history(self, callback, current):
+    def track_early_stopping_history(
+        self, callback: "pl.callbacks.early_stopping.EarlyStopping", current: torch.Tensor
+    ) -> None:
         debug_dict = {
             "epoch": self.trainer.current_epoch,
             "global_step": self.trainer.global_step,
@@ -150,33 +165,33 @@ def track_early_stopping_history(self, callback, current):
         self.early_stopping_history.append(debug_dict)

     @enabled_only
-    def track_checkpointing_history(self, filepath):
+    def track_checkpointing_history(self, filepath: str) -> None:
         cb = self.trainer.checkpoint_callback
         debug_dict = {
             "epoch": self.trainer.current_epoch,
             "global_step": self.trainer.global_step,
-            "monitor": cb.monitor,
+            "monitor": cb.monitor if cb is not None else None,
             "rank": self.trainer.global_rank,
             "filepath": filepath,
         }
         self.checkpoint_callback_history.append(debug_dict)

     @property
-    def num_seen_sanity_check_batches(self):
+    def num_seen_sanity_check_batches(self) -> int:
         count = sum(1 for x in self.saved_val_losses if x["sanity_check"])
         return count

     @property
-    def num_seen_val_check_batches(self):
-        counts = Counter()
+    def num_seen_val_check_batches(self) -> Counter:
+        counts: Counter = Counter()
         for x in self.saved_val_losses:
             if not x["sanity_check"]:
                 counts.update({x["dataloader_idx"]: 1})
         return counts

     @property
-    def num_seen_test_check_batches(self):
-        counts = Counter()
+    def num_seen_test_check_batches(self) -> Counter:
+        counts: Counter = Counter()
         for x in self.saved_test_losses:
             if not x["sanity_check"]:
                 counts.update({x["dataloader_idx"]: 1})
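Note (not part of the patch): every tracking method above is gated on a single `enabled` flag that is read once, at construction time, from the PL_DEV_DEBUG environment variable. Below is a minimal, self-contained sketch of that gating pattern; the `Tracker` class and its `track` method are hypothetical stand-ins for InternalDebugger, used only for illustration.

import os
from functools import wraps
from typing import Any, Callable, Optional


def enabled_only(fn: Callable) -> Callable:
    """Run the wrapped method only when ``self.enabled`` is True; otherwise no-op."""

    @wraps(fn)
    def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Optional[Any]:
        if self.enabled:
            fn(self, *args, **kwargs)
        return None

    return wrapped_fn


class Tracker:
    """Hypothetical stand-in for InternalDebugger."""

    def __init__(self) -> None:
        # The flag is read once here, so PL_DEV_DEBUG must be set in the
        # environment before the object is constructed.
        self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
        self.events: list = []

    @enabled_only
    def track(self, name: str) -> None:
        self.events.append(name)


os.environ["PL_DEV_DEBUG"] = "1"
tracker = Tracker()
tracker.track("setup")
print(tracker.events)  # ['setup']; with the variable unset, the list stays empty

Because the decorated methods return None either way, callers can leave tracking calls in hot paths unconditionally; when debugging is off, each call is a cheap attribute check.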