@@ -303,6 +303,12 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
303303 # when in dev debugging track the losses
304304 self .trainer .dev_debugger .track_train_loss_history (batch_idx , untouched_loss .detach ())
305305
306+ def _check_training_step_output (self , training_step_output ):
307+ if isinstance (training_step_output , torch .Tensor ) and not self .automatic_optimization :
308+ if training_step_output .grad_fn is None :
309+ # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
310+ raise MisconfigurationException ("In manual optimization, `training_step` should not return a Tensor" )
311+
306312 def training_step (self , split_batch , batch_idx , opt_idx , hiddens ):
307313 # give the PL module a result for logging
308314 model = self .trainer .get_model ()
@@ -312,6 +318,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
312318 with self .trainer .profiler .profile ("model_forward" ):
313319 args = self .build_train_args (split_batch , batch_idx , opt_idx , hiddens )
314320 training_step_output = self .trainer .accelerator_backend .training_step (args )
321+ self ._check_training_step_output (training_step_output )
322+
315323 training_step_output = self .trainer .call_hook ("training_step_end" , training_step_output )
316324
317325 training_step_output_for_epoch_end , training_step_output = self ._process_training_step_output (
@@ -612,6 +620,9 @@ def run_training_epoch(self):
612620 # progress global step according to grads progress
613621 self .increment_accumulated_grad_global_step ()
614622
623+ # epoch end hook
624+ self .run_on_epoch_end_hook (epoch_output )
625+
615626 # log epoch metrics
616627 self .trainer .logger_connector .log_train_epoch_end_metrics (
617628 epoch_output , self .checkpoint_accumulator , self .early_stopping_accumulator , self .num_optimizers
@@ -723,6 +734,8 @@ def train_step_and_backward_closure():
723734
724735 if self ._curr_step_result is None :
725736 # user decided to skip optimization
737+ # make sure to zero grad.
738+ self .zero_grad_handler (batch_idx , optimizer , opt_idx )
726739 continue
727740
728741 batch_outputs = self ._process_closure_result (
@@ -735,20 +748,11 @@ def train_step_and_backward_closure():
735748 grad_norm_dic = self ._cur_grad_norm_dict
736749 self ._cur_grad_norm_dict = None
737750
738- # hook
739- self .on_before_zero_grad (optimizer )
740-
741- # clear gradients
742- self .optimizer_zero_grad (batch_idx , optimizer , opt_idx )
751+ # hook + clear gradients
752+ self .zero_grad_handler (batch_idx , optimizer , opt_idx )
743753
744- accumulated_loss = self .accumulated_loss .mean ()
745-
746- if accumulated_loss is not None :
747- # calculate running loss for display
748- self .running_loss .append (self .accumulated_loss .mean () * self .trainer .accumulate_grad_batches )
749-
750- # reset for next set of accumulated grads
751- self .accumulated_loss .reset ()
754+ # update running loss + reset accumulated loss
755+ self .update_running_loss ()
752756
753757 # collapse all metrics into one dict
754758 batch_log_metrics = {k : v for d in batch_log_metrics for k , v in d .items ()}
@@ -949,3 +953,44 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
949953 epoch_end_outputs .append (optimizer_idx_outputs )
950954
951955 return epoch_end_outputs
956+
def prepare_optimizers(self):
    """Return the optimizer entries to iterate over for the current batch.

    With automatic optimization every optimizer gets its own pass. In manual
    optimization the user steps all optimizers inside ``training_step``
    themselves, so a single entry (the first one) is enough to drive the
    loop exactly once.
    """
    available = self.get_optimizers_iterable()
    if self.automatic_optimization:
        return available
    return [available[0]]
963+
def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
    """Prepare trainer state for one split of a (possibly truncated) batch.

    Records the split index on the trainer, toggles the active optimizer's
    parameters when several optimizers run under automatic optimization (so
    only its own gradients are produced by the training step), and notifies
    the logger connector so metrics are tracked per split.
    """
    trainer = self.trainer
    trainer.split_idx = split_idx

    # With multiple optimizers under automatic optimization, restrict grads
    # to the current optimizer's parameters to prevent dangling gradients.
    has_multiple_optimizers = len(trainer.optimizers) > 1
    if self.automatic_optimization and has_multiple_optimizers:
        trainer.get_model().toggle_optimizer(optimizer, opt_idx)

    # internal metric tracking for this split
    trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch)
976+
def update_running_loss(self):
    """Fold the mean accumulated loss into the running loss, then reset the accumulator.

    The displayed running loss is the mean over the accumulation window scaled
    back by ``accumulate_grad_batches`` (each accumulated step contributed a
    loss already divided by that factor).
    """
    # Fix: the original recomputed `self.accumulated_loss.mean()` a second
    # time inside the branch; reuse the value computed once here.
    accumulated_loss = self.accumulated_loss.mean()

    if accumulated_loss is not None:
        # calculate running loss for display
        self.running_loss.append(accumulated_loss * self.trainer.accumulate_grad_batches)

    # reset for next set of accumulated grads
    self.accumulated_loss.reset()
986+
def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
    """Clear gradients after an optimization step.

    Automatic optimization: fire the ``on_before_zero_grad`` hook and zero
    only the optimizer that just stepped. Manual optimization: zero every
    configured optimizer (no hook), since the user drives stepping.
    """
    if self.automatic_optimization:
        # hook runs only in the automatic path
        self.on_before_zero_grad(optimizer)
        targets = [optimizer]
    else:
        targets = [opt for _, opt in self.get_optimizers_iterable()]

    # NOTE(review): the same `opt_idx` is forwarded for every optimizer in
    # the manual branch — mirrors the original behavior; confirm downstream.
    for current_optimizer in targets:
        self.optimizer_zero_grad(batch_idx, current_optimizer, opt_idx)
0 commit comments