
Commit 038c151

awaelchli and carmocca authored
Improve typing for plugins (#10742)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 81a0a44 commit 038c151

4 files changed: +10 -11 lines changed


pyproject.toml

Lines changed: 0 additions & 3 deletions
@@ -86,9 +86,6 @@ module = [
     "pytorch_lightning.plugins.environments.lsf_environment",
     "pytorch_lightning.plugins.environments.slurm_environment",
     "pytorch_lightning.plugins.environments.torchelastic_environment",
-    "pytorch_lightning.plugins.precision.deepspeed",
-    "pytorch_lightning.plugins.precision.native_amp",
-    "pytorch_lightning.plugins.precision.precision_plugin",
     "pytorch_lightning.plugins.training_type.ddp",
     "pytorch_lightning.plugins.training_type.ddp2",
     "pytorch_lightning.plugins.training_type.ddp_spawn",

pytorch_lightning/plugins/precision/deepspeed.py

Lines changed: 3 additions & 1 deletion
@@ -49,7 +49,9 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
         deepspeed_engine: DeepSpeedEngine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
 
-    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+    def _run_backward(self, tensor: Tensor, model: Optional["DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
+        if model is None:
+            raise ValueError("Please provide the model as input to `backward`.")
         model.backward(tensor, *args, **kwargs)
 
     def optimizer_step(
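
Widening the parameter to Optional["DeepSpeedEngine"] and raising early when it is None lets mypy narrow the type before model.backward(...) is reached. A minimal sketch of the same narrowing pattern, with a hypothetical Engine class standing in for the real DeepSpeedEngine:

# Minimal sketch of Optional narrowing; Engine is a hypothetical stand-in
# for DeepSpeedEngine.
from typing import Optional


class Engine:
    def backward(self, loss: float) -> None:
        print(f"backward pass on loss={loss}")


def run_backward(loss: float, engine: Optional[Engine]) -> None:
    # Without this explicit check mypy would reject `engine.backward(...)`,
    # because the annotation says `engine` may be None.
    if engine is None:
        raise ValueError("Please provide the engine to run the backward pass.")
    engine.backward(loss)


run_backward(0.5, Engine())  # prints: backward pass on loss=0.5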

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 6 additions & 6 deletions
@@ -25,9 +25,9 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _TORCH_GREATER_EQUAL_1_10:
-    from torch import autocast
+    from torch import autocast as new_autocast
 else:
-    from torch.cuda.amp import autocast
+    from torch.cuda.amp import autocast as old_autocast
 
 
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
@@ -62,7 +62,7 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor)
         closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
-    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+    def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = self.scaler.scale(tensor)
         super()._run_backward(tensor, model, *args, **kwargs)
@@ -93,12 +93,12 @@ def optimizer_step(
         self.scaler.step(optimizer, **kwargs)
         self.scaler.update()
 
-    def autocast_context_manager(self) -> autocast:
+    def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
         if _TORCH_GREATER_EQUAL_1_10:
             # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
             # https://github.com/pytorch/pytorch/issues/67233
-            return autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
-        return autocast()
+            return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
+        return old_autocast()
 
     @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
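
Importing the two autocast variants under distinct aliases lets the return annotation of autocast_context_manager name both implementations explicitly, instead of referring to whichever one happened to be imported. A minimal sketch of the same version-gated aliasing, using a hypothetical flag and standard-library context managers in place of _TORCH_GREATER_EQUAL_1_10 and torch's autocast:

# Minimal sketch of version-gated import aliasing. The flag and the two
# context managers are hypothetical stand-ins; distinct alias names let the
# annotation mention both variants.
from typing import Union

_NEW_API_AVAILABLE = True  # stand-in for the real version check

if _NEW_API_AVAILABLE:
    from contextlib import nullcontext as new_ctx
else:
    from contextlib import suppress as old_ctx


def make_context() -> Union["old_ctx", "new_ctx"]:
    # Mirrors autocast_context_manager(): the implementation is chosen at
    # runtime, while the string annotation covers both branches.
    if _NEW_API_AVAILABLE:
        return new_ctx()
    return old_ctx(ValueError)


with make_context():
    pass  # no-op context when the "new" API is available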

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def optimizer_step(
         """Hook to run the optimizer step."""
         if isinstance(model, pl.LightningModule):
             closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        optimizer.step(closure=closure, **kwargs)
+        optimizer.step(closure=closure, **kwargs)  # type: ignore[call-arg]
 
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
         if trainer.track_grad_norm == -1:
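
Scoping the suppression to the call-arg error code keeps every other mypy check active on that line; the error presumably arises because the statically declared signature of the optimizer's step does not include the extra keyword arguments forwarded through **kwargs, even though concrete optimizers accept them. A minimal sketch of the same situation with hypothetical classes:

# Minimal sketch with hypothetical classes: the variable's static type only
# declares `closure`, while the concrete subclass accepts an extra keyword,
# so mypy reports [call-arg] even though the call works at runtime. Only that
# error code is silenced on the flagged line.
from typing import Callable, Optional


class BaseOptimizer:
    def step(self, closure: Optional[Callable[[], float]] = None) -> None:
        if closure is not None:
            closure()


class FancyOptimizer(BaseOptimizer):
    # Widening an override with an extra optional keyword is permitted.
    def step(
        self, closure: Optional[Callable[[], float]] = None, scale: float = 1.0
    ) -> None:
        super().step(closure)


opt: BaseOptimizer = FancyOptimizer()
opt.step(closure=lambda: 1.0, scale=2.0)  # type: ignore[call-arg]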
