@@ -512,28 +512,17 @@ def run_training_epoch(self):
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
             self.trainer.checkpoint_connector.has_trained = True

-            # max steps reached, end training
-            if (
-                self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
-                and self._accumulated_batches_reached()
-            ):
-                break
-
-            # end epoch early
-            # stop when the flag is changed or we've gone past the amount
-            # requested in the batches
-            if self.trainer.should_stop:
-                break
-
             self.trainer.total_batch_idx += 1

-            # stop epoch if we limited the number of training batches
-            if self._num_training_batches_reached(is_last_batch):
-                break
-
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()

+            max_steps_reached = (
+                self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step
+            )
+            if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
+                break
+
         if batch_idx is None:
             # dataloader/iterator did not produce a batch
             return
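Taken together, this hunk folds the three separate break checks into one exit test that now runs after the global step has been advanced, which is why the max-steps comparison drops both the `+ 1` offset and the `_accumulated_batches_reached()` guard. A runnable toy of the resulting control flow; `ToyTrainer` and the loop scaffolding are stand-ins invented for this sketch, and only the break condition itself mirrors the diff:

```python
# Toy model of the consolidated exit check; field and helper names mirror the
# diff, while the dataloader and step bookkeeping are assumptions for the demo.
class ToyTrainer:
    def __init__(self, max_steps, num_batches):
        self.max_steps = max_steps
        self.num_batches = num_batches
        self.global_step = 0
        self.total_batch_idx = 0
        self.should_stop = False

def num_training_batches_reached(trainer, is_last_batch):
    # Stand-in for self._num_training_batches_reached(is_last_batch)
    return is_last_batch or trainer.total_batch_idx >= trainer.num_batches

def run_epoch(trainer):
    for batch_idx in range(trainer.num_batches):
        is_last_batch = batch_idx == trainer.num_batches - 1
        trainer.total_batch_idx += 1
        trainer.global_step += 1  # stand-in for increment_accumulated_grad_global_step()

        # Single break condition, evaluated *after* the step increment, so
        # max_steps is compared to global_step directly (no "+ 1" offset).
        max_steps_reached = (
            trainer.max_steps is not None and trainer.max_steps <= trainer.global_step
        )
        if max_steps_reached or trainer.should_stop or num_training_batches_reached(trainer, is_last_batch):
            break

trainer = ToyTrainer(max_steps=3, num_batches=10)
run_epoch(trainer)
print(trainer.global_step)  # 3 -- the loop stops once global_step reaches max_steps
```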
@@ -552,18 +541,6 @@ def run_training_epoch(self):
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

-        if should_train_only:
-            self.check_checkpoint_callback(True)
-
-        if should_check_val:
-            self.trainer.validating = True
-            self.trainer.run_evaluation(on_epoch=True)
-            self.trainer.training = True
-
-        # increment the global step once
-        # progress global step according to grads progress
-        self.increment_accumulated_grad_global_step()
-
     def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
         # inform logger the batch loop has finished
         self.trainer.logger_connector.on_train_epoch_end()
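Within the context shown, the deletion leaves only the LR-scheduler guard at the end of the epoch; the epoch-end validation run and the extra global-step increment are gone, which appears to pair with the `_should_check_val_fx` simplification in the next hunk. A minimal toy of the retained guard, with the three flags as plain booleans rather than values derived from trainer state:

```python
# Epoch-interval LR schedulers appear to step here only when validation already
# ran this epoch (and none is still pending) or the run is train-only.
def should_update_lr_on_epoch(val_loop_called: bool, should_check_val: bool, should_train_only: bool) -> bool:
    return (val_loop_called and not should_check_val) or should_train_only

assert should_update_lr_on_epoch(True, False, False)     # validation already ran
assert should_update_lr_on_epoch(False, False, True)     # train-only run
assert not should_update_lr_on_epoch(True, True, False)  # validation still pending
```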
@@ -863,16 +840,12 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo
         elif self.trainer.val_check_batch != float('inf'):
             is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0

-        # Note: num_training_batches is also inf for iterable datasets with no length defined
-        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
         is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")

         if on_epoch:
-            return (
-                is_val_check_batch and epoch_end_val_check
-            ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
+            return is_val_check_batch or self.trainer.should_stop or is_last_batch_for_infinite_dataset
         else:
-            return is_val_check_batch and not epoch_end_val_check
+            return is_val_check_batch

     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
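The net change reads best as a standalone predicate: with `epoch_end_val_check` removed, a `val_check_batch` hit now fires on both the epoch-end and the in-loop path instead of being split between them. A runnable sketch that models only the branches visible in this hunk (the real method reads this state from `self.trainer`, and an earlier branch above the shown `elif` can also set `is_val_check_batch`):

```python
def should_check_val(
    batch_idx: int,
    is_last_batch: bool,
    on_epoch: bool,
    val_check_batch: float,
    should_stop: bool = False,
) -> bool:
    # Mirrors the simplified branches from the hunk.
    is_val_check_batch = (
        val_check_batch != float("inf") and (batch_idx + 1) % val_check_batch == 0
    )
    is_last_batch_for_infinite_dataset = is_last_batch and val_check_batch == float("inf")
    if on_epoch:
        return is_val_check_batch or should_stop or is_last_batch_for_infinite_dataset
    return is_val_check_batch

assert should_check_val(3, False, on_epoch=False, val_check_batch=4)           # every 4th batch triggers
assert should_check_val(7, True, on_epoch=True, val_check_batch=float("inf"))  # infinite dataset, last batch
assert not should_check_val(2, False, on_epoch=True, val_check_batch=4)        # mid-interval batch: no trigger
```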