@@ -25,9 +25,9 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _TORCH_GREATER_EQUAL_1_10:
-    from torch import autocast
+    from torch import autocast as new_autocast
 else:
-    from torch.cuda.amp import autocast
+    from torch.cuda.amp import autocast as old_autocast
 
 
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
@@ -62,7 +62,7 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor)
             closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
-    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+    def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = self.scaler.scale(tensor)
         super()._run_backward(tensor, model, *args, **kwargs)
@@ -93,12 +93,12 @@ def optimizer_step(
             self.scaler.step(optimizer, **kwargs)
             self.scaler.update()
 
-    def autocast_context_manager(self) -> autocast:
+    def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
         if _TORCH_GREATER_EQUAL_1_10:
             # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
             # https://github.com/pytorch/pytorch/issues/67233
-            return autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
-        return autocast()
+            return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
+        return old_autocast()
 
     @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
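For context on why the import aliases matter, below is a minimal standalone sketch of the version-gated autocast selection that the diff introduces. It is illustrative only: the inline version parse stands in for Lightning's _TORCH_GREATER_EQUAL_1_10 constant, and the free function autocast_context_manager mirrors the plugin method rather than reproducing it.

    # Minimal sketch of the version-gated autocast selection (illustrative only;
    # the inline version check stands in for Lightning's _TORCH_GREATER_EQUAL_1_10).
    import torch

    _TORCH_GREATER_EQUAL_1_10 = tuple(int(p) for p in torch.__version__.split(".")[:2]) >= (1, 10)

    if _TORCH_GREATER_EQUAL_1_10:
        from torch import autocast as new_autocast  # torch>=1.10: device-agnostic autocast
    else:
        from torch.cuda.amp import autocast as old_autocast  # torch<1.10: CUDA-only autocast


    def autocast_context_manager(device: str, precision: str):
        if _TORCH_GREATER_EQUAL_1_10:
            # dtype is passed explicitly because of https://github.com/pytorch/pytorch/issues/67233
            return new_autocast(device, dtype=torch.bfloat16 if precision == "bf16" else torch.half)
        return old_autocast()


    device = "cuda" if torch.cuda.is_available() else "cpu"
    with autocast_context_manager(device, "bf16"):
        pass  # a forward pass would run here under the selected autocast context

On torch >= 1.10 the dtype has to be passed explicitly because of the upstream issue referenced in the diff; the pre-1.10 torch.cuda.amp.autocast covers CUDA ops only and always casts to float16.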
|