From ab4efc7e429a8826d7853cc79202b8e558020088 Mon Sep 17 00:00:00 2001
From: Anshul
Date: Tue, 5 Jan 2021 07:46:05 +0000
Subject: [PATCH] :bug: Fix grad getting enabled for parameters on which it was explicitly disabled in case of multiple optimizers

---
 pytorch_lightning/core/lightning.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index a4330b401936d..ddfe41489291b 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -112,6 +112,8 @@ def __init__(self, *args, **kwargs):
         self._current_hook_fx_name = None
         self._current_dataloader_idx = None
 
+        self.param_grad_dict = {}
+
     def optimizers(self):
         opts = self.trainer.optimizers
 
@@ -1163,11 +1165,13 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
             optimizer_idx:
         """
         for param in self.parameters():
+            if param not in self.param_grad_dict:
+                self.param_grad_dict[param] = param.requires_grad
             param.requires_grad = False
 
         for group in optimizer.param_groups:
             for param in group['params']:
-                param.requires_grad = True
+                param.requires_grad = self.param_grad_dict[param]
 
     def optimizer_step(
         self,
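
A minimal, self-contained sketch (plain torch, outside the patched file) of the behaviour this fix targets; the names TwoOptModule, opt_gen and opt_disc are illustrative only. The toggle_optimizer below mirrors the patched logic: each parameter's original requires_grad is remembered the first time it is seen and restored for the toggled optimizer's parameters instead of being forced to True, so a parameter that was explicitly frozen stays frozen across toggles.

from torch import nn
from torch.optim import SGD


class TwoOptModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(4, 4)
        self.disc = nn.Linear(4, 4)
        # Explicitly frozen parameter: it must never be re-enabled by a toggle.
        self.disc.bias.requires_grad = False
        self.param_grad_dict = {}

    def toggle_optimizer(self, optimizer):
        # Remember each parameter's original requires_grad the first time it
        # is seen, then disable grads for every parameter ...
        for param in self.parameters():
            if param not in self.param_grad_dict:
                self.param_grad_dict[param] = param.requires_grad
            param.requires_grad = False
        # ... and restore the remembered value (not a blanket True) for the
        # parameters owned by the optimizer being toggled on.
        for group in optimizer.param_groups:
            for param in group["params"]:
                param.requires_grad = self.param_grad_dict[param]


model = TwoOptModule()
opt_gen = SGD(model.gen.parameters(), lr=0.1)
opt_disc = SGD(model.disc.parameters(), lr=0.1)

model.toggle_optimizer(opt_disc)
assert model.disc.weight.requires_grad       # active optimizer's params are trainable
assert not model.disc.bias.requires_grad     # explicitly frozen param stays frozen
assert not model.gen.weight.requires_grad    # the other optimizer's params are off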