@@ -303,6 +303,12 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         # when in dev debugging track the losses
         self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
 
+    def _check_training_step_output(self, training_step_output):
+        if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
+            if training_step_output.grad_fn is None:
+                # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
+                raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
+
     def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         # give the PL module a result for logging
         model = self.trainer.get_model()
@@ -312,6 +318,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         with self.trainer.profiler.profile("model_forward"):
             args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
             training_step_output = self.trainer.accelerator_backend.training_step(args)
+            self._check_training_step_output(training_step_output)
+
             training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
 
             training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
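A rough standalone sketch of what the new guard enforces, not part of the commit: a plain function stands in for the trainer plumbing and ValueError stands in for MisconfigurationException. In manual optimization, a Tensor returned from `training_step` without a `grad_fn` (i.e. detached from the graph) is rejected.

import torch


def check_training_step_output(training_step_output, automatic_optimization):
    # mirrors _check_training_step_output above: reject detached Tensors in manual optimization
    if isinstance(training_step_output, torch.Tensor) and not automatic_optimization:
        if training_step_output.grad_fn is None:
            raise ValueError("In manual optimization, `training_step` should not return a Tensor")


check_training_step_output({"loss": torch.tensor(1.0)}, automatic_optimization=False)  # dict passes
check_training_step_output(torch.tensor(1.0), automatic_optimization=True)             # automatic mode passes
# check_training_step_output(torch.tensor(1.0), automatic_optimization=False)          # would raise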
@@ -723,6 +731,8 @@ def train_step_and_backward_closure():
 
                 if self._curr_step_result is None:
                     # user decided to skip optimization
+                    # make sure to zero grad.
+                    self.zero_grad_handler(batch_idx, optimizer, opt_idx)
                     continue
 
                 batch_outputs = self._process_closure_result(
@@ -735,11 +745,8 @@ def train_step_and_backward_closure():
                 grad_norm_dic = self._cur_grad_norm_dict
                 self._cur_grad_norm_dict = None
 
-                # hook
-                self.on_before_zero_grad(optimizer)
-
-                # clear gradients
-                self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+                # hook + clear gradients
+                self.zero_grad_handler(batch_idx, optimizer, opt_idx)
 
                 accumulated_loss = self.accumulated_loss.mean()
 
@@ -949,3 +956,44 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             epoch_end_outputs.append(optimizer_idx_outputs)
 
         return epoch_end_outputs
+
+    def prepare_optimizers(self):
+        # in manual optimization we loop over all optimizers at once
+        optimizers = self.get_optimizers_iterable()
+        if not self.automatic_optimization:
+            optimizers = [optimizers[0]]
+        return optimizers
+
+    def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
+        # set split_idx to trainer for tracking
+        self.trainer.split_idx = split_idx
+
+        # make sure only the gradients of the current optimizer's parameters are calculated
+        # in the training step to prevent dangling gradients in multiple-optimizer setup.
+        if self.automatic_optimization and len(self.trainer.optimizers) > 1:
+            model = self.trainer.get_model()
+            model.toggle_optimizer(optimizer, opt_idx)
+
+        # use to track metrics internally
+        self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch)
+
+    def update_running_loss(self):
+        accumulated_loss = self.accumulated_loss.mean()
+
+        if accumulated_loss is not None:
+            # calculate running loss for display
+            self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
+
+        # reset for next set of accumulated grads
+        self.accumulated_loss.reset()
+
+    def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
+        if self.automatic_optimization:
+            # hook
+            self.on_before_zero_grad(optimizer)
+            optimizers = enumerate([optimizer])
+        else:
+            optimizers = self.get_optimizers_iterable()
+
+        for idx, optimizer in optimizers:
+            self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
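A rough standalone sketch of the handler's two paths, not part of the commit: plain torch.optim optimizers stand in for the trainer's configured optimizers, and the on_before_zero_grad hook is only noted in a comment. Automatic optimization clears only the optimizer that just stepped; manual optimization clears every optimizer.

import torch

params = [torch.nn.Parameter(torch.randn(2, 2))]
opt_a = torch.optim.SGD(params, lr=0.1)
opt_b = torch.optim.SGD(params, lr=0.1)


def zero_grad_handler(automatic_optimization, current_optimizer, all_optimizers):
    if automatic_optimization:
        # the real loop fires the on_before_zero_grad hook here, then clears
        # only the optimizer that just performed a step
        optimizers = [current_optimizer]
    else:
        # manual optimization: clear the gradients of every configured optimizer
        optimizers = all_optimizers
    for optimizer in optimizers:
        optimizer.zero_grad()


params[0].grad = torch.ones(2, 2)
zero_grad_handler(automatic_optimization=False, current_optimizer=opt_a, all_optimizers=[opt_a, opt_b])
print(params[0].grad)  # gradient has been cleared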