@@ -124,7 +124,6 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
         # -------------------
         # calculate loss (train step + train step end)
         # -------------------
-
         # automatic_optimization=True: perform ddp sync only when performing optimizer_step
         # automatic_optimization=False: don't block synchronization here
         with self.block_ddp_sync_behaviour():
@@ -137,6 +136,9 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi
         else:
             if self.trainer.lightning_module.automatic_optimization:
                 self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+                if len(self.trainer.optimizers) > 1:
+                    # revert back to previous state
+                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
                 result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
 
@@ -448,10 +450,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
448450 "training_step returned None. If this was on purpose, ignore this warning..."
449451 )
450452
451- if len (self .trainer .optimizers ) > 1 :
452- # revert back to previous state
453- self .trainer .lightning_module .untoggle_optimizer (opt_idx )
454-
455453 return result
456454
457455 def _check_finite (self , loss : torch .Tensor ) -> None :
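
For context, the change moves the `untoggle_optimizer(opt_idx)` call so it runs immediately after `optimizer_step` inside `_run_optimization`, instead of at the end of `training_step_and_backward`. Below is a minimal, dependency-free sketch of the toggle/untoggle idea that call supports when more than one optimizer is configured; `Param`, `toggle`, and `untoggle` are illustrative stand-ins and not Lightning API, and the real hooks operate on torch parameters rather than this toy class.

# Hypothetical sketch of the multi-optimizer toggle/untoggle pattern.
# Parameters owned by the other optimizers are frozen for the current
# step and their previous state is restored afterwards, mirroring the
# "revert back to previous state" comment in the diff above.

class Param:
    def __init__(self, name):
        self.name = name
        self.requires_grad = True

def toggle(optimizer_params, opt_idx):
    """Freeze every parameter not owned by optimizer `opt_idx` and
    remember the previous requires_grad flags so they can be restored."""
    previous = {}
    for i, params in enumerate(optimizer_params):
        for p in params:
            previous[p.name] = p.requires_grad
            if i != opt_idx:
                p.requires_grad = False
    return previous

def untoggle(optimizer_params, previous):
    """Revert every parameter back to the state captured before toggling."""
    for params in optimizer_params:
        for p in params:
            p.requires_grad = previous[p.name]

if __name__ == "__main__":
    generator = [Param("g1"), Param("g2")]
    discriminator = [Param("d1")]
    optimizer_params = [generator, discriminator]

    for opt_idx in range(len(optimizer_params)):
        saved = toggle(optimizer_params, opt_idx)
        # ... training step and optimizer step for `opt_idx` would run here;
        # untoggling only after the step (as in the first hunk above) keeps
        # the intended parameters trainable while the update is applied.
        untoggle(optimizer_params, saved)

    assert all(p.requires_grad for p in generator + discriminator)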