@@ -303,12 +303,6 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
303303 # when in dev debugging track the losses
304304 self .trainer .dev_debugger .track_train_loss_history (batch_idx , untouched_loss .detach ())
305305
306- def _check_training_step_output (self , training_step_output ):
307- if isinstance (training_step_output , torch .Tensor ) and not self .automatic_optimization :
308- if training_step_output .grad_fn is None :
309- # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
310- raise MisconfigurationException ("In manual optimization, `training_step` should not return a Tensor" )
311-
312306 def training_step (self , split_batch , batch_idx , opt_idx , hiddens ):
313307 # give the PL module a result for logging
314308 model = self .trainer .get_model ()
@@ -318,8 +312,6 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
318312 with self .trainer .profiler .profile ("model_forward" ):
319313 args = self .build_train_args (split_batch , batch_idx , opt_idx , hiddens )
320314 training_step_output = self .trainer .accelerator_backend .training_step (args )
321- self ._check_training_step_output (training_step_output )
322-
323315 training_step_output = self .trainer .call_hook ("training_step_end" , training_step_output )
324316
325317 training_step_output_for_epoch_end , training_step_output = self ._process_training_step_output (
@@ -620,9 +612,6 @@ def run_training_epoch(self):
620612 # progress global step according to grads progress
621613 self .increment_accumulated_grad_global_step ()
622614
623- # epoch end hook
624- self .run_on_epoch_end_hook (epoch_output )
625-
626615 # log epoch metrics
627616 self .trainer .logger_connector .log_train_epoch_end_metrics (
628617 epoch_output , self .checkpoint_accumulator , self .early_stopping_accumulator , self .num_optimizers
@@ -734,8 +723,6 @@ def train_step_and_backward_closure():
734723
735724 if self ._curr_step_result is None :
736725 # user decided to skip optimization
737- # make sure to zero grad.
738- self .zero_grad_handler (batch_idx , optimizer , opt_idx )
739726 continue
740727
741728 batch_outputs = self ._process_closure_result (
@@ -748,11 +735,20 @@ def train_step_and_backward_closure():
748735 grad_norm_dic = self ._cur_grad_norm_dict
749736 self ._cur_grad_norm_dict = None
750737
751- # hook + clear gradients
752- self .zero_grad_handler (batch_idx , optimizer , opt_idx )
738+ # hook
739+ self .on_before_zero_grad (optimizer )
740+
741+ # clear gradients
742+ self .optimizer_zero_grad (batch_idx , optimizer , opt_idx )
753743
754- # update running loss + reset accumulated loss
755- self .update_running_loss ()
744+ accumulated_loss = self .accumulated_loss .mean ()
745+
746+ if accumulated_loss is not None :
747+ # calculate running loss for display
748+ self .running_loss .append (self .accumulated_loss .mean () * self .trainer .accumulate_grad_batches )
749+
750+ # reset for next set of accumulated grads
751+ self .accumulated_loss .reset ()
756752
757753 # collapse all metrics into one dict
758754 batch_log_metrics = {k : v for d in batch_log_metrics for k , v in d .items ()}
@@ -953,44 +949,3 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
953949 epoch_end_outputs .append (optimizer_idx_outputs )
954950
955951 return epoch_end_outputs
956-
957- def prepare_optimizers (self ):
958- # in manual optimization we loop over all optimizers at once
959- optimizers = self .get_optimizers_iterable ()
960- if not self .automatic_optimization :
961- optimizers = [optimizers [0 ]]
962- return optimizers
963-
964- def run_train_split_start (self , split_idx , split_batch , opt_idx , optimizer ):
965- # set split_idx to trainer for tracking
966- self .trainer .split_idx = split_idx
967-
968- # make sure only the gradients of the current optimizer's parameters are calculated
969- # in the training step to prevent dangling gradients in multiple-optimizer setup.
970- if self .automatic_optimization and len (self .trainer .optimizers ) > 1 :
971- model = self .trainer .get_model ()
972- model .toggle_optimizer (optimizer , opt_idx )
973-
974- # use to track metrics internally
975- self .trainer .logger_connector .on_train_split_start (split_idx , opt_idx , split_batch )
976-
977- def update_running_loss (self ):
978- accumulated_loss = self .accumulated_loss .mean ()
979-
980- if accumulated_loss is not None :
981- # calculate running loss for display
982- self .running_loss .append (self .accumulated_loss .mean () * self .trainer .accumulate_grad_batches )
983-
984- # reset for next set of accumulated grads
985- self .accumulated_loss .reset ()
986-
987- def zero_grad_handler (self , batch_idx , optimizer , opt_idx ):
988- if self .automatic_optimization :
989- # hook
990- self .on_before_zero_grad (optimizer )
991- optimizers = enumerate ([optimizer ])
992- else :
993- optimizers = self .get_optimizers_iterable ()
994-
995- for idx , optimizer in optimizers :
996- self .optimizer_zero_grad (batch_idx , optimizer , opt_idx )
0 commit comments