@@ -22,7 +22,7 @@
 import weakref
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, overload, Sequence, Tuple, Union
 
 import torch
 from torch import ScriptModule, Tensor
@@ -47,12 +47,20 @@
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
+from pytorch_lightning.utilities.types import (
+    _METRIC_COLLECTION,
+    EPOCH_OUTPUT,
+    LRSchedulerPLType,
+    LRSchedulerTypeUnion,
+    STEP_OUTPUT,
+)
 from pytorch_lightning.utilities.warnings import WarningCache
 
 warning_cache = WarningCache()
 log = logging.getLogger(__name__)
 
+MODULE_OPTIMIZERS = Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]
+
 
 class LightningModule(
     DeviceDtypeModuleMixin,
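Note: the new `MODULE_OPTIMIZERS` alias names every shape `optimizers()` may return, so the `@overload` stubs below can reference it instead of repeating the four-way `Union`. A minimal, hypothetical sketch of how the overloads narrow the return type for a caller (the module name is invented; only the typing behaviour is the point):

    import pytorch_lightning as pl

    class ManualOptModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            # a literal False matches the `Literal[False]` overload:
            # mypy sees Union[Optimizer, List[Optimizer]]
            raw_opt = self.optimizers(use_pl_optimizer=False)
            # a plain bool falls through to the MODULE_OPTIMIZERS overload
            wrap: bool = True
            any_opt = self.optimizers(use_pl_optimizer=wrap)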
@@ -104,7 +112,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._current_fx_name: Optional[str] = None
         self._automatic_optimization: bool = True
         self._truncated_bptt_steps: int = 0
-        self._param_requires_grad_state = {}
+        self._param_requires_grad_state: Dict[str, bool] = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
         self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
         # TODO: remove in 1.8
@@ -121,14 +129,10 @@ def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[
         ...
 
     @overload
-    def optimizers(
-        self, use_pl_optimizer: bool
-    ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
+    def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS:
         ...
 
-    def optimizers(
-        self, use_pl_optimizer: bool = True
-    ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
+    def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
         """Returns the optimizer(s) that are being used during training. Useful for manual optimization.
 
         Args:
@@ -140,7 +144,7 @@ def optimizers(
             A single optimizer, or a list of optimizers in case multiple ones are present.
         """
         if use_pl_optimizer:
-            opts = list(self.trainer.strategy._lightning_optimizers.values())
+            opts: MODULE_OPTIMIZERS = list(self.trainer.strategy._lightning_optimizers.values())
         else:
             opts = self.trainer.optimizers
 
@@ -150,7 +154,7 @@ def optimizers(
         # multiple opts
         return opts
 
-    def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRSchedulerTypeUnion]]]:
+    def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]:
         """Returns the learning rate scheduler(s) that are being used during training. Useful for manual
         optimization.
 
@@ -162,7 +166,7 @@ def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRScheduler
             return None
 
         # ignore other keys "interval", "frequency", etc.
-        lr_schedulers = [config.scheduler for config in self.trainer.lr_scheduler_configs]
+        lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs]
 
         # single scheduler
         if len(lr_schedulers) == 1:
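Note: with these annotations, `optimizers()` and `lr_schedulers()` advertise concrete Lightning types to manual-optimization code. A rough usage sketch for the single-optimizer, single-scheduler case, written inside your LightningModule (`compute_loss` is an assumed helper, not part of the API):

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        sch = self.lr_schedulers()
        loss = self.compute_loss(batch)  # hypothetical helper returning a Tensor
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        # assumes a scheduler whose step() takes no metric argument
        if sch is not None and not isinstance(sch, list):
            sch.step()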
@@ -175,13 +179,13 @@ def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRScheduler
     def trainer(self) -> "pl.Trainer":
         if not self._running_torchscript and self._trainer is None:
             raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
-        return self._trainer
+        return self._trainer  # type: ignore[return-value]
 
     @trainer.setter
     def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
         for v in self.children():
             if isinstance(v, LightningModule):
-                v.trainer = trainer
+                v.trainer = trainer  # type: ignore[assignment]
         if trainer is not None and not isinstance(trainer, weakref.ProxyTypes):
             trainer = weakref.proxy(trainer)
         self._trainer = trainer
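Note: both `# type: ignore`s appear to stem from the same mismatch: internally `_trainer` is stored as an optional (and usually weakref-proxied) trainer, while the public property promises a plain `pl.Trainer`. A standalone sketch of the proxy behaviour the setter relies on (the `Owner` class is invented for the example):

    import weakref

    class Owner:
        pass

    owner = Owner()
    proxy = weakref.proxy(owner)  # behaves like `owner` but holds no strong reference
    assert isinstance(proxy, weakref.ProxyTypes)
    del owner
    # any attribute access on `proxy` now raises ReferenceError (in CPython)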
@@ -228,7 +232,7 @@ def local_rank(self) -> int:
         return self.trainer.local_rank if self._trainer else 0
 
     @property
-    def on_gpu(self):
+    def on_gpu(self) -> bool:
         """Returns ``True`` if this model is currently located on a GPU.
 
         Useful to set flags around the LightningModule for different CPU vs GPU behavior.
@@ -264,7 +268,7 @@ def logger(self) -> Optional[Logger]:
         # this should match the implementation of `trainer.logger`
         # we don't reuse it so we can properly set the deprecation stacklevel
         if self._trainer is None:
-            return
+            return None
         loggers = self.trainer.loggers
         if len(loggers) == 0:
             return None
@@ -287,15 +291,15 @@ def loggers(self) -> List[Logger]:
         """Reference to the list of loggers in the Trainer."""
         return self.trainer.loggers if self._trainer else []
 
-    def _call_batch_hook(self, hook_name, *args) -> Any:
+    def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
             obj = datahook_selector.get_instance(hook_name)
-            trainer_method = (
-                self._trainer._call_lightning_module_hook
-                if isinstance(obj, self.__class__)
-                else self._trainer._call_lightning_datamodule_hook
-            )
+            if isinstance(obj, self.__class__):
+                trainer_method = self._trainer._call_lightning_module_hook
+            else:
+                trainer_method = self._trainer._call_lightning_datamodule_hook
+
             return trainer_method(hook_name, *args)
         else:
             hook = getattr(self, hook_name)
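Note: rewriting the conditional expression as an explicit `if`/`else` is a common mypy-friendly pattern: each branch becomes an ordinary assignment the checker can type on its own, rather than it having to compute a single type for the whole ternary. A generic sketch of the pattern (the dispatcher and both callables are invented for illustration):

    from typing import Any, Callable

    def from_module(name: str, *args: Any) -> Any:
        ...

    def from_datamodule(name: str, *args: Any) -> Any:
        ...

    def dispatch(use_module: bool, name: str, *args: Any) -> Any:
        method: Callable[..., Any]
        if use_module:
            method = from_module
        else:
            method = from_datamodule
        return method(name, *args)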
@@ -312,7 +316,7 @@ def _apply_batch_transfer_handler(
             batch = self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)
         return batch
 
-    def print(self, *args, **kwargs) -> None:
+    def print(self, *args: Any, **kwargs: Any) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
 
@@ -463,7 +467,7 @@ def log(
             logger=logger,
             on_step=on_step,
             on_epoch=on_epoch,
-            reduce_fx=reduce_fx,
+            reduce_fx=reduce_fx,  # type: ignore[arg-type]
             enable_graph=enable_graph,
             add_dataloader_idx=add_dataloader_idx,
             batch_size=batch_size,
@@ -578,7 +582,9 @@ def log_grad_norm(self, grad_norm_dict):
         """
         self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=False, logger=True)
 
-    def all_gather(self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False):
+    def all_gather(
+        self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
+    ) -> Union[Tensor, Dict, List, Tuple]:
         r"""
         Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ``all_gather`` operation
         accelerator agnostic. ``all_gather`` is a function provided by accelerators to gather a tensor from several
@@ -598,7 +604,7 @@ def all_gather(self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any
             data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
 
-    def forward(self, *args, **kwargs) -> Any:
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         r"""
         Same as :meth:`torch.nn.Module.forward()`.
 
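Note: the added return annotation documents that `all_gather` hands back the same collection structure it was given, with every tensor gaining a leading world-size dimension. A hedged usage sketch inside a LightningModule (`compute_metric` is an assumed helper, not part of the API):

    def validation_step(self, batch, batch_idx):
        local_metric = self.compute_metric(batch)  # hypothetical helper returning a Tensor
        gathered = self.all_gather(local_metric)   # shape: (world_size, *local_metric.shape)
        if self.trainer.is_global_zero:
            self.log("metric_mean", gathered.mean(), rank_zero_only=True)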
@@ -611,7 +617,7 @@ def forward(self, *args, **kwargs) -> Any:
         """
         return super().forward(*args, **kwargs)
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         r"""
         Here you compute and return the training loss and some additional metrics for e.g.
         the progress bar or logger.
@@ -769,7 +775,7 @@ def training_epoch_end(self, training_step_outputs):
             ...
         """
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         r"""
         Operates on a single batch of data from the validation set.
         In this step you'd might generate examples or calculate anything of interest like accuracy.
@@ -858,7 +864,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         the model goes back to training mode and gradients are enabled.
         """
 
-    def validation_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """Use this when validating with dp because :meth:`validation_step` will operate on only part of the batch.
         However, this is still optional and only needed for things like softmax or NCE loss.
 
@@ -955,7 +961,7 @@ def validation_epoch_end(self, outputs):
             self.log("final_metric", final_value)
         """
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         r"""
         Operates on a single batch of data from the test set.
         In this step you'd normally generate examples or calculate anything of interest
@@ -1035,7 +1041,7 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):
         to training mode and gradients are enabled.
         """
 
-    def test_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """Use this when testing with DP because :meth:`test_step` will operate on only part of the batch. However,
         this is still optional and only needed for things like softmax or NCE loss.
 
@@ -1200,7 +1206,7 @@ def configure_callbacks(self):
         """
         return []
 
-    def configure_optimizers(self):
+    def configure_optimizers(self) -> Any:
         r"""
         Choose what optimizers and learning-rate schedulers to use in your optimization.
         Normally you'd need one. But in the case of GANs or similar you might have multiple.
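Note: `configure_optimizers` is annotated `-> Any` because the hook accepts several return shapes (a single optimizer, lists of optimizers and schedulers, a configuration dict, or ``None``) rather than one concrete type. One commonly used shape, as a sketch inside a LightningModule:

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}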
@@ -1374,7 +1380,7 @@ def configure_optimizers(self):
         """
         rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")
 
-    def manual_backward(self, loss: Tensor, *args, **kwargs) -> None:
+    def manual_backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
         """Call this directly from your :meth:`training_step` when doing optimizations manually. By using this,
         Lightning can ensure that all the proper scaling gets applied when using mixed precision.
 
@@ -1399,7 +1405,7 @@ def training_step(...):
         self.trainer.strategy.backward(loss, None, None, *args, **kwargs)
 
     def backward(
-        self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
+        self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
     ) -> None:
         """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your
         own implementation if you need to.
@@ -1442,7 +1448,7 @@ def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer], opti
 
         # Then iterate over the current optimizer's parameters and set its `requires_grad`
         # properties accordingly
-        for group in optimizer.param_groups:
+        for group in optimizer.param_groups:  # type: ignore[union-attr]
             for param in group["params"]:
                 param.requires_grad = param_requires_grad_state[param]
         self._param_requires_grad_state = param_requires_grad_state
@@ -1469,7 +1475,7 @@ def clip_gradients(
         optimizer: Optimizer,
         gradient_clip_val: Optional[Union[int, float]] = None,
         gradient_clip_algorithm: Optional[str] = None,
-    ):
+    ) -> None:
         """Handles gradient clipping internally.
 
         Note:
@@ -1523,7 +1529,7 @@ def configure_gradient_clipping(
         optimizer_idx: int,
         gradient_clip_val: Optional[Union[int, float]] = None,
         gradient_clip_algorithm: Optional[str] = None,
-    ):
+    ) -> None:
         """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.
 
         Args:
@@ -1584,7 +1590,7 @@ def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
 
         """
         if metric is None:
-            scheduler.step()
+            scheduler.step()  # type: ignore[call-arg]
         else:
             scheduler.step(metric)
 
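Note: the `call-arg` ignore is likely needed because the scheduler argument is a union that includes schedulers such as `ReduceLROnPlateau`, whose `step()` requires a metrics value, so a bare `step()` cannot be verified even though it only runs when `metric is None`. Overriding this hook is how non-standard schedulers get stepped; a hedged sketch (the `epoch` keyword is only an example and assumes a scheduler that accepts it):

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        if metric is None:
            scheduler.step(epoch=self.current_epoch)
        else:
            scheduler.step(metric, epoch=self.current_epoch)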
@@ -1672,7 +1678,7 @@ def optimizer_step(
         """
         optimizer.step(closure=optimizer_closure)
 
-    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
+    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None:
         """Override this method to change the default behaviour of ``optimizer.zero_grad()``.
 
         Args:
@@ -1741,12 +1747,11 @@ def tbptt_split_batch(self, batch, split_size):
         for t in range(0, time_dims[0], split_size):
             batch_split = []
             for i, x in enumerate(batch):
+                split_x: Union[Tensor, List[Tensor]]
                 if isinstance(x, Tensor):
                     split_x = x[:, t : t + split_size]
-                elif isinstance(x, collections.abc.Sequence):
-                    split_x = [None] * len(x)
-                    for batch_idx in range(len(x)):
-                        split_x[batch_idx] = x[batch_idx][t : t + split_size]
+                elif isinstance(x, collections.Sequence):
+                    split_x = [x[batch_idx][t : t + split_size] for batch_idx in range(len(x))]
 
                 batch_split.append(split_x)
 
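Note: the up-front `split_x: Union[Tensor, List[Tensor]]` declaration lets mypy accept both branches. Be aware that the new `elif` spells the check as `collections.Sequence`, an alias that was removed in Python 3.10; `collections.abc.Sequence` is the portable spelling. A tiny sketch of the tensor case being split along the time dimension (shapes invented for the example):

    import torch

    x = torch.randn(4, 20, 8)  # (batch, time, features)
    split_size = 5
    chunks = [x[:, t : t + split_size] for t in range(0, x.size(1), split_size)]
    assert len(chunks) == 4 and chunks[0].shape == (4, 5, 8)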
@@ -1782,15 +1787,15 @@ def unfreeze(self) -> None:
 
         self.train()
 
-    def _verify_is_manual_optimization(self, fn_name):
+    def _verify_is_manual_optimization(self, fn_name: str) -> None:
         if self.automatic_optimization:
             raise MisconfigurationException(
                 f"to use {fn_name}, please disable automatic optimization:"
                 " set model property `automatic_optimization` as False"
             )
 
     @torch.no_grad()
-    def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs):
+    def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
         """Saves the model in ONNX format.
 
         Args:
@@ -1829,7 +1834,7 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non
 
         if not _TORCH_GREATER_EQUAL_1_10 and "example_outputs" not in kwargs:
             self.eval()
-            if isinstance(input_sample, Tuple):
+            if isinstance(input_sample, tuple):
                 kwargs["example_outputs"] = self(*input_sample)
             else:
                 kwargs["example_outputs"] = self(input_sample)
@@ -1843,7 +1848,7 @@ def to_torchscript(
         file_path: Optional[Union[str, Path]] = None,
         method: Optional[str] = "script",
         example_inputs: Optional[Any] = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
         """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing,
         please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is
@@ -1953,7 +1958,7 @@ def use_amp(self, use_amp: bool) -> None:
         self._use_amp = use_amp
 
     @contextmanager
-    def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
+    def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
         self._should_prevent_trainer_and_dataloaders_deepcopy = True
         yield
         self._should_prevent_trainer_and_dataloaders_deepcopy = False
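Note: a function decorated with `@contextmanager` is a generator, so `Generator[None, None, None]` (yield type, send type, return type) is the accurate annotation; the old `-> None` matched the value yielded to the `with` block, not the generator the function actually returns. A generic sketch of the pattern (the function and flag names are invented):

    from contextlib import contextmanager
    from typing import Generator

    @contextmanager
    def temporarily_enabled(flags: dict) -> Generator[None, None, None]:
        # set a flag for the duration of the `with` block, then restore it
        flags["enabled"] = True
        try:
            yield
        finally:
            flags["enabled"] = False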
@@ -1988,4 +1993,6 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
             self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
         else:
             # We need to make sure the self inside the method is a weakref proxy
-            self.__class__._register_load_state_dict_pre_hook(weakref.proxy(self), pre_load_state_dict_hook, True)
+            self.__class__._register_load_state_dict_pre_hook(
+                weakref.proxy(self), pre_load_state_dict_hook, True  # type: ignore[arg-type]
+            )