@@ -251,13 +251,13 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_train_epoch_start")
 
     def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+        # hook
+        self.trainer.call_hook('on_batch_end')
+        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+
         # figure out what to track for epoch end
         self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
 
-        # hook
-        self.trainer.call_hook("on_batch_end")
-        self.trainer.call_hook("on_train_batch_end", epoch_end_outputs, batch, batch_idx, dataloader_idx)
-
     def reset_train_val_dataloaders(self, model):
         if not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_train_dataloader(model)
@@ -303,6 +303,12 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         # when in dev debugging track the losses
         self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
 
+    def _check_training_step_output(self, training_step_output):
+        if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
+            if training_step_output.grad_fn is None:
+                # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
+                raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
+
     def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         # give the PL module a result for logging
         model = self.trainer.get_model()
@@ -312,6 +318,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         with self.trainer.profiler.profile("model_forward"):
             args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
             training_step_output = self.trainer.accelerator_backend.training_step(args)
+            self._check_training_step_output(training_step_output)
+
             training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
 
             training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
@@ -724,6 +732,8 @@ def train_step_and_backward_closure():
 
                 if self._curr_step_result is None:
                     # user decided to skip optimization
+                    # make sure to zero grad.
+                    self.zero_grad_handler(batch_idx, optimizer, opt_idx)
                     continue
 
                 batch_outputs = self._process_closure_result(
@@ -736,20 +746,11 @@ def train_step_and_backward_closure():
                 grad_norm_dic = self._cur_grad_norm_dict
                 self._cur_grad_norm_dict = None
 
-                # hook
-                self.on_before_zero_grad(optimizer)
+                # hook + clear gradients
+                self.zero_grad_handler(batch_idx, optimizer, opt_idx)
 
-                # clear gradients
-                self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
-
-                accumulated_loss = self.accumulated_loss.mean()
-
-                if accumulated_loss is not None:
-                    # calculate running loss for display
-                    self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
-
-                # reset for next set of accumulated grads
-                self.accumulated_loss.reset()
+                # update running loss + reset accumulated loss
+                self.update_running_loss()
 
         # collapse all metrics into one dict
         batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
@@ -950,3 +951,44 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             epoch_end_outputs.append(optimizer_idx_outputs)
 
         return epoch_end_outputs
+
+    def prepare_optimizers(self):
+        # in manual optimization we loop over all optimizers at once
+        optimizers = self.get_optimizers_iterable()
+        if not self.automatic_optimization:
+            optimizers = [optimizers[0]]
+        return optimizers
+
+    def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
+        # set split_idx to trainer for tracking
+        self.trainer.split_idx = split_idx
+
+        # make sure only the gradients of the current optimizer's parameters are calculated
+        # in the training step to prevent dangling gradients in multiple-optimizer setup.
+        if self.automatic_optimization and len(self.trainer.optimizers) > 1:
+            model = self.trainer.get_model()
+            model.toggle_optimizer(optimizer, opt_idx)
+
+        # use to track metrics internally
+        self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch)
+
+    def update_running_loss(self):
+        accumulated_loss = self.accumulated_loss.mean()
+
+        if accumulated_loss is not None:
+            # calculate running loss for display
+            self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
+
+        # reset for next set of accumulated grads
+        self.accumulated_loss.reset()
+
+    def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
+        if self.automatic_optimization:
+            # hook
+            self.on_before_zero_grad(optimizer)
+            optimizers = enumerate([optimizer])
+        else:
+            optimizers = self.get_optimizers_iterable()
+
+        for idx, optimizer in optimizers:
+            self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
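
The effect of these changes on a user module can be sketched as follows. This is a minimal, hypothetical example, not part of the commit: the model, layer sizes, and the `automatic_optimization=False` Trainer flag are assumptions that depend on the Lightning version in use, and `self.manual_backward` is called with the optimizer argument it accepted at the time. In manual optimization, `training_step` runs backward and steps the optimizer itself and returns `None`; the `None` result sends `run_training_batch` through the `self._curr_step_result is None` branch, where `zero_grad_handler` clears the gradients of every optimizer from `get_optimizers_iterable()`. Returning a detached `Tensor` instead would raise the new `MisconfigurationException` from `_check_training_step_output`.

import torch
from torch import nn
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    # hypothetical model exercising the manual-optimization branch above
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # trainer.optimizers is the same list zero_grad_handler iterates over
        opt = self.trainer.optimizers[0]

        loss = self.layer(batch).sum()
        self.manual_backward(loss, opt)
        opt.step()

        # return None rather than a Tensor: a detached Tensor would trip
        # _check_training_step_output, and a None result routes the loop through
        # zero_grad_handler, which calls optimizer_zero_grad for every optimizer
        return None

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# assumption: manual optimization is switched on via the Trainer flag in this version
trainer = pl.Trainer(automatic_optimization=False, max_steps=4)
trainer.fit(ManualOptModel(), torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8))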