diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 4c839f3a6c906..044dd95f3b8c6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -105,6 +105,7 @@ def __init__(self, *args, **kwargs): self._current_hook_fx_name = None self._current_dataloader_idx = None self._automatic_optimization: bool = True + self._param_requires_grad_state = dict() def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -1311,7 +1312,7 @@ def untoggle_optimizer(self, optimizer_idx: int): if param in self._param_requires_grad_state: param.requires_grad = self._param_requires_grad_state[param] # save memory - del self._param_requires_grad_state + self._param_requires_grad_state = dict() def optimizer_step( self,