
Commit 28aab26

jxtngx authored and committed

Fix mypy errors attributed to pytorch_lightning.core.module.py (#13603)

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 583429a commit 28aab26

File tree

5 files changed: +60 −53 lines changed


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -58,7 +58,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.core.module",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.profilers.base",

src/pytorch_lightning/core/mixins/device_dtype_mixin.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def dtype(self, new_dtype: Union[str, torch.dtype]) -> None:
         raise RuntimeError("Cannot set the dtype explicitly. Please use module.to(new_dtype).")
 
     @property
-    def device(self) -> Union[str, torch.device]:
+    def device(self) -> torch.device:
         device = self._device
 
         # make this more explicit to always include the index
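Narrowing the declared return type to `torch.device` lets callers use the property anywhere a `torch.device` is expected without a `cast` (see the `data_parallel.py` change further down). A minimal standalone sketch of the idea, with a hypothetical class name rather than the real mixin:

import torch


class DeviceMixinSketch:
    """Hypothetical stand-in for DeviceDtypeModuleMixin."""

    def __init__(self) -> None:
        self._device = torch.device("cpu")

    @property
    def device(self) -> torch.device:
        # Declaring torch.device (instead of Union[str, torch.device]) means
        # mypy accepts the value wherever a torch.device is required.
        return self._device


m = DeviceMixinSketch()
x = torch.zeros(2, device=m.device)  # no cast() needed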

src/pytorch_lightning/core/module.py

Lines changed: 56 additions & 49 deletions
@@ -22,7 +22,7 @@
 import weakref
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, overload, Sequence, Tuple, Union
 
 import torch
 from torch import ScriptModule, Tensor
@@ -47,12 +47,20 @@
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
+from pytorch_lightning.utilities.types import (
+    _METRIC_COLLECTION,
+    EPOCH_OUTPUT,
+    LRSchedulerPLType,
+    LRSchedulerTypeUnion,
+    STEP_OUTPUT,
+)
 from pytorch_lightning.utilities.warnings import WarningCache
 
 warning_cache = WarningCache()
 log = logging.getLogger(__name__)
 
+MODULE_OPTIMIZERS = Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]
+
 
 class LightningModule(
     DeviceDtypeModuleMixin,
@@ -104,7 +112,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._current_fx_name: Optional[str] = None
         self._automatic_optimization: bool = True
         self._truncated_bptt_steps: int = 0
-        self._param_requires_grad_state = {}
+        self._param_requires_grad_state: Dict[str, bool] = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
         self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
         # TODO: remove in 1.8
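The `_param_requires_grad_state` change above is the standard fix for mypy's var-annotated error: an empty literal gives the checker no element types to infer. A minimal sketch with a hypothetical attribute name:

from typing import Dict


class StateSketch:
    def __init__(self) -> None:
        # self._state = {}            # mypy: Need type annotation for "_state"  [var-annotated]
        self._state: Dict[str, bool] = {}  # annotated, so later writes are checked

    def freeze(self, name: str) -> None:
        self._state[name] = False  # mypy verifies the key and value types here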
@@ -121,14 +129,10 @@ def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[
         ...
 
     @overload
-    def optimizers(
-        self, use_pl_optimizer: bool
-    ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
+    def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS:
         ...
 
-    def optimizers(
-        self, use_pl_optimizer: bool = True
-    ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
+    def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
         """Returns the optimizer(s) that are being used during training. Useful for manual optimization.
 
         Args:
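A module-level alias like `MODULE_OPTIMIZERS` keeps each overload on a single line. A self-contained sketch of the overload pattern using stub classes (the real signatures also include the `Literal[...]` overloads visible in the hunk context above):

from typing import List, Literal, Union, overload


class Optimizer:  # stub standing in for torch.optim.Optimizer
    ...


class LightningOptimizer:  # stub standing in for the Lightning wrapper
    ...


MODULE_OPTIMIZERS = Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]


class OptimizersSketch:
    @overload
    def optimizers(self, use_pl_optimizer: Literal[True] = ...) -> Union[LightningOptimizer, List[LightningOptimizer]]:
        ...

    @overload
    def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]:
        ...

    @overload
    def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS:
        ...

    def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
        # Callers that pass a literal True/False get the narrower return type;
        # a plain bool falls back to the alias.
        return LightningOptimizer() if use_pl_optimizer else Optimizer()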
@@ -140,7 +144,7 @@ def optimizers(
             A single optimizer, or a list of optimizers in case multiple ones are present.
         """
         if use_pl_optimizer:
-            opts = list(self.trainer.strategy._lightning_optimizers.values())
+            opts: MODULE_OPTIMIZERS = list(self.trainer.strategy._lightning_optimizers.values())
         else:
             opts = self.trainer.optimizers
 
@@ -150,7 +154,7 @@ def optimizers(
         # multiple opts
         return opts
 
-    def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRSchedulerTypeUnion]]]:
+    def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]:
         """Returns the learning rate scheduler(s) that are being used during training. Useful for manual
         optimization.
 
@@ -162,7 +166,7 @@ def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRScheduler
             return None
 
         # ignore other keys "interval", "frequency", etc.
-        lr_schedulers = [config.scheduler for config in self.trainer.lr_scheduler_configs]
+        lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs]
 
         # single scheduler
         if len(lr_schedulers) == 1:
@@ -175,13 +179,13 @@ def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRScheduler
     def trainer(self) -> "pl.Trainer":
         if not self._running_torchscript and self._trainer is None:
             raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
-        return self._trainer
+        return self._trainer  # type: ignore[return-value]
 
     @trainer.setter
     def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
         for v in self.children():
             if isinstance(v, LightningModule):
-                v.trainer = trainer
+                v.trainer = trainer  # type: ignore[assignment]
         if trainer is not None and not isinstance(trainer, weakref.ProxyTypes):
             trainer = weakref.proxy(trainer)
         self._trainer = trainer
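The `# type: ignore[return-value]` above is needed because `_trainer` is declared `Optional` and the early-raise guard also depends on `_running_torchscript`, so mypy cannot prove the returned value is non-None. When the guard depends only on the attribute itself, narrowing works without an ignore; a sketch with hypothetical names:

from typing import Optional


class Trainer:  # stub
    ...


class ModuleSketch:
    def __init__(self) -> None:
        self._trainer: Optional[Trainer] = None

    @property
    def trainer(self) -> Trainer:
        if self._trainer is None:
            raise RuntimeError("not attached to a Trainer")
        # mypy narrows Optional[Trainer] to Trainer after the check above,
        # so no "type: ignore[return-value]" is required here.
        return self._trainer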
@@ -228,7 +232,7 @@ def local_rank(self) -> int:
         return self.trainer.local_rank if self._trainer else 0
 
     @property
-    def on_gpu(self):
+    def on_gpu(self) -> bool:
         """Returns ``True`` if this model is currently located on a GPU.
 
         Useful to set flags around the LightningModule for different CPU vs GPU behavior.
@@ -264,7 +268,7 @@ def logger(self) -> Optional[Logger]:
         # this should match the implementation of `trainer.logger`
         # we don't reuse it so we can properly set the deprecation stacklevel
         if self._trainer is None:
-            return
+            return None
         loggers = self.trainer.loggers
         if len(loggers) == 0:
             return None
@@ -287,15 +291,15 @@ def loggers(self) -> List[Logger]:
         """Reference to the list of loggers in the Trainer."""
         return self.trainer.loggers if self._trainer else []
 
-    def _call_batch_hook(self, hook_name, *args) -> Any:
+    def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
             obj = datahook_selector.get_instance(hook_name)
-            trainer_method = (
-                self._trainer._call_lightning_module_hook
-                if isinstance(obj, self.__class__)
-                else self._trainer._call_lightning_datamodule_hook
-            )
+            if isinstance(obj, self.__class__):
+                trainer_method = self._trainer._call_lightning_module_hook
+            else:
+                trainer_method = self._trainer._call_lightning_datamodule_hook
+
             return trainer_method(hook_name, *args)
         else:
             hook = getattr(self, hook_name)
@@ -312,7 +316,7 @@ def _apply_batch_transfer_handler(
         batch = self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)
         return batch
 
-    def print(self, *args, **kwargs) -> None:
+    def print(self, *args: Any, **kwargs: Any) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
 
@@ -463,7 +467,7 @@ def log(
             logger=logger,
             on_step=on_step,
             on_epoch=on_epoch,
-            reduce_fx=reduce_fx,
+            reduce_fx=reduce_fx,  # type: ignore[arg-type]
             enable_graph=enable_graph,
             add_dataloader_idx=add_dataloader_idx,
             batch_size=batch_size,
@@ -578,7 +582,9 @@ def log_grad_norm(self, grad_norm_dict):
         """
         self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=False, logger=True)
 
-    def all_gather(self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False):
+    def all_gather(
+        self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
+    ) -> Union[Tensor, Dict, List, Tuple]:
         r"""
         Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ``all_gather`` operation
         accelerator agnostic. ``all_gather`` is a function provided by accelerators to gather a tensor from several
@@ -598,7 +604,7 @@ def all_gather(self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
 
-    def forward(self, *args, **kwargs) -> Any:
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         r"""
         Same as :meth:`torch.nn.Module.forward()`.
 
@@ -611,7 +617,7 @@ def forward(self, *args, **kwargs) -> Any:
         """
         return super().forward(*args, **kwargs)
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         r"""
         Here you compute and return the training loss and some additional metrics for e.g.
         the progress bar or logger.
@@ -769,7 +775,7 @@ def training_epoch_end(self, training_step_outputs):
             ...
         """
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         r"""
         Operates on a single batch of data from the validation set.
         In this step you'd might generate examples or calculate anything of interest like accuracy.
@@ -858,7 +864,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         the model goes back to training mode and gradients are enabled.
         """
 
-    def validation_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """Use this when validating with dp because :meth:`validation_step` will operate on only part of the batch.
         However, this is still optional and only needed for things like softmax or NCE loss.
 
@@ -955,7 +961,7 @@ def validation_epoch_end(self, outputs):
             self.log("final_metric", final_value)
         """
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         r"""
         Operates on a single batch of data from the test set.
         In this step you'd normally generate examples or calculate anything of interest
@@ -1035,7 +1041,7 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):
         to training mode and gradients are enabled.
         """
 
-    def test_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """Use this when testing with DP because :meth:`test_step` will operate on only part of the batch. However,
         this is still optional and only needed for things like softmax or NCE loss.
 
@@ -1200,7 +1206,7 @@ def configure_callbacks(self):
         """
         return []
 
-    def configure_optimizers(self):
+    def configure_optimizers(self) -> Any:
         r"""
         Choose what optimizers and learning-rate schedulers to use in your optimization.
         Normally you'd need one. But in the case of GANs or similar you might have multiple.
@@ -1374,7 +1380,7 @@ def configure_optimizers(self):
         """
         rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")
 
-    def manual_backward(self, loss: Tensor, *args, **kwargs) -> None:
+    def manual_backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
         """Call this directly from your :meth:`training_step` when doing optimizations manually. By using this,
         Lightning can ensure that all the proper scaling gets applied when using mixed precision.
 
@@ -1399,7 +1405,7 @@ def training_step(...):
         self.trainer.strategy.backward(loss, None, None, *args, **kwargs)
 
     def backward(
-        self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
+        self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
     ) -> None:
         """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your
         own implementation if you need to.
@@ -1442,7 +1448,7 @@ def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer], opti
 
         # Then iterate over the current optimizer's parameters and set its `requires_grad`
         # properties accordingly
-        for group in optimizer.param_groups:
+        for group in optimizer.param_groups:  # type: ignore[union-attr]
             for param in group["params"]:
                 param.requires_grad = param_requires_grad_state[param]
         self._param_requires_grad_state = param_requires_grad_state
@@ -1469,7 +1475,7 @@ def clip_gradients(
         optimizer: Optimizer,
         gradient_clip_val: Optional[Union[int, float]] = None,
         gradient_clip_algorithm: Optional[str] = None,
-    ):
+    ) -> None:
        """Handles gradient clipping internally.
 
         Note:
@@ -1523,7 +1529,7 @@ def configure_gradient_clipping(
         optimizer_idx: int,
         gradient_clip_val: Optional[Union[int, float]] = None,
         gradient_clip_algorithm: Optional[str] = None,
-    ):
+    ) -> None:
         """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.
 
         Args:
@@ -1584,7 +1590,7 @@ def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
 
         """
         if metric is None:
-            scheduler.step()
+            scheduler.step()  # type: ignore[call-arg]
         else:
            scheduler.step(metric)
 
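The `# type: ignore[call-arg]` above exists because the `scheduler` argument is typed as a union that includes `ReduceLROnPlateau`, whose `step()` requires a metrics value, so a bare `step()` is flagged even on the branch where no metric is passed. A small runnable illustration of how an isinstance check narrows such a union instead (hypothetical variable names):

from typing import Union

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler: Union[StepLR, ReduceLROnPlateau] = StepLR(optimizer, step_size=1)

if isinstance(scheduler, ReduceLROnPlateau):
    scheduler.step(0.0)  # ReduceLROnPlateau.step() requires a metric
else:
    scheduler.step()  # narrowed to StepLR, so the no-argument call type-checks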
@@ -1672,7 +1678,7 @@ def optimizer_step(
         """
         optimizer.step(closure=optimizer_closure)
 
-    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
+    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None:
         """Override this method to change the default behaviour of ``optimizer.zero_grad()``.
 
         Args:
@@ -1741,12 +1747,11 @@ def tbptt_split_batch(self, batch, split_size):
         for t in range(0, time_dims[0], split_size):
             batch_split = []
             for i, x in enumerate(batch):
+                split_x: Union[Tensor, List[Tensor]]
                 if isinstance(x, Tensor):
                     split_x = x[:, t : t + split_size]
-                elif isinstance(x, collections.abc.Sequence):
-                    split_x = [None] * len(x)
-                    for batch_idx in range(len(x)):
-                        split_x[batch_idx] = x[batch_idx][t : t + split_size]
+                elif isinstance(x, collections.Sequence):
+                    split_x = [x[batch_idx][t : t + split_size] for batch_idx in range(len(x))]
 
                 batch_split.append(split_x)
 
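Two things make the rewritten branch type-check: `split_x` is declared with its union type before the branches, and the list is built with a comprehension (starting from `[None] * len(x)` is inferred as List[None] and rejects the later item assignments). A standalone sketch of the pattern, with a hypothetical function name:

from typing import List, Union

import torch
from torch import Tensor


def split_chunk(x: Union[Tensor, List[Tensor]], t: int, size: int) -> Union[Tensor, List[Tensor]]:
    # Declaring the union up front lets each branch assign a different member type.
    split_x: Union[Tensor, List[Tensor]]
    if isinstance(x, Tensor):
        split_x = x[:, t : t + size]
    else:
        # The comprehension is inferred as List[Tensor] directly.
        split_x = [xi[t : t + size] for xi in x]
    return split_x


out = split_chunk(torch.zeros(2, 8), t=0, size=4)
assert isinstance(out, Tensor) and out.shape == (2, 4)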
@@ -1782,15 +1787,15 @@ def unfreeze(self) -> None:
 
         self.train()
 
-    def _verify_is_manual_optimization(self, fn_name):
+    def _verify_is_manual_optimization(self, fn_name: str) -> None:
         if self.automatic_optimization:
             raise MisconfigurationException(
                 f"to use {fn_name}, please disable automatic optimization:"
                 " set model property `automatic_optimization` as False"
             )
 
     @torch.no_grad()
-    def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs):
+    def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
         """Saves the model in ONNX format.
 
         Args:
@@ -1829,7 +1834,7 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non
 
         if not _TORCH_GREATER_EQUAL_1_10 and "example_outputs" not in kwargs:
             self.eval()
-            if isinstance(input_sample, Tuple):
+            if isinstance(input_sample, tuple):
                 kwargs["example_outputs"] = self(*input_sample)
             else:
                 kwargs["example_outputs"] = self(input_sample)
@@ -1843,7 +1848,7 @@ def to_torchscript(
         file_path: Optional[Union[str, Path]] = None,
         method: Optional[str] = "script",
         example_inputs: Optional[Any] = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
         """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing,
         please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is
@@ -1953,7 +1958,7 @@ def use_amp(self, use_amp: bool) -> None:
         self._use_amp = use_amp
 
     @contextmanager
-    def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
+    def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
         self._should_prevent_trainer_and_dataloaders_deepcopy = True
         yield
         self._should_prevent_trainer_and_dataloaders_deepcopy = False
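Annotating a `@contextmanager` function: the wrapped function is a generator, so its return type is `Generator[YieldType, SendType, ReturnType]`; one that yields once with no value is `Generator[None, None, None]`. A minimal sketch with a hypothetical function name:

from contextlib import contextmanager
from typing import Generator


@contextmanager
def prevent_deepcopy() -> Generator[None, None, None]:
    # Set a flag for the duration of the `with` block, then restore it.
    flag = {"active": True}
    try:
        yield
    finally:
        flag["active"] = False


with prevent_deepcopy():
    pass  # the flag is "active" only inside this block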
@@ -1988,4 +1993,6 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
             self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
         else:
             # We need to make sure the self inside the method is a weakref proxy
-            self.__class__._register_load_state_dict_pre_hook(weakref.proxy(self), pre_load_state_dict_hook, True)
+            self.__class__._register_load_state_dict_pre_hook(
+                weakref.proxy(self), pre_load_state_dict_hook, True  # type: ignore[arg-type]
+            )

src/pytorch_lightning/overrides/data_parallel.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import numbers
 import warnings
-from typing import Any, cast, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import Tensor
@@ -77,7 +77,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         output = super().forward(*inputs, **kwargs)
 
         def output_transform(data: Any) -> Any:
-            device = cast(torch.device, self.lightning_module.device)
+            device = self.lightning_module.device
             data = python_scalar_to_tensor(data, device)
             data = unsqueeze_scalar_tensor(data)
             return data
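Dropping the `cast` works because `device` (via the mixin change above) is now declared to return `torch.device`, so there is nothing left to narrow. It is worth noting that `cast` never converted anything at runtime anyway; a small illustration with a hypothetical function name:

from typing import Union, cast

import torch


def as_device(d: Union[str, torch.device]) -> torch.device:
    # cast() has no runtime effect; it only tells the type checker to trust us.
    return cast(torch.device, d)


print(as_device("cpu"))  # prints 'cpu' -- still a plain str at runtime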
