@@ -626,19 +626,13 @@ def run_training_epoch(self):
             self.trainer.total_batch_idx += 1

             # stop epoch if we limited the number of training batches
-            if self._num_training_batches_reached():
+            if self._num_training_batches_reached(is_last_batch):
                 break

             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()

         # epoch end hook
-        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
-        if should_check_val:
-            self.trainer.run_evaluation(test_mode=False)
-            # reset stage to train
-            self.trainer.logger_connector.set_stage("train")
-
         self.run_on_epoch_end_hook(epoch_output)

         # log epoch metrics
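Note that the removed validation block is relocated, not deleted: it reappears after the epoch-end hook and epoch-level logging in the next hunk. The other change here is that the epoch can now also end on `is_last_batch`, a flag presumably obtained by looking one element ahead in the dataloader. A toy sketch of that pairing (the names below are illustrative, not Lightning's):

def with_is_last(iterable):
    # yield (item, is_last) pairs by prefetching one element, so the final
    # item is flagged even when the iterable's length is unknown up front
    it = iter(iterable)
    prev = next(it)
    for item in it:
        yield prev, False
        prev = item
    yield prev, True

def batches():
    # stand-in for a dataloader whose length is not known in advance
    for i in range(3):
        yield f"batch-{i}"

for batch_idx, (batch, is_last_batch) in enumerate(with_is_last(batches())):
    print(batch_idx, batch, is_last_batch)  # is_last_batch is True only on batch-2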
@@ -649,10 +643,19 @@ def run_training_epoch(self):
             self.num_optimizers
         )

-        # update LR schedulers
-        self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        if should_check_val:
+            self.trainer.run_evaluation(test_mode=False, on_epoch=True)
+            # reset stage to train
+            self.trainer.logger_connector.set_stage("train")
+
+        should_skip_eval = sum(self.trainer.num_val_batches) == 0
+        should_train_only_check = not self.trainer.enable_validation and should_skip_eval
+
+        if should_skip_eval or should_train_only_check:
+            # update epoch level lr_schedulers
+            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

-        should_train_only_check = not self.trainer.enable_validation and (sum(self.trainer.num_val_batches) == 0)
         self.check_checkpoint_callback(should_train_only_check)
         self.check_early_stopping_callback(should_train_only_check)
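The relocated validation check now runs after the epoch-end hook and epoch-level logging, and the epoch-level LR scheduler step is gated: it only fires here when evaluation is skipped (presumably the validation path steps the schedulers otherwise). A standalone sketch of how the two flags combine, with the trainer attributes passed in as plain arguments (`num_val_batches` is Lightning's per-dataloader list):

def train_only_flags(num_val_batches, enable_validation):
    # evaluation would find nothing to run: every val dataloader is empty
    should_skip_eval = sum(num_val_batches) == 0
    # validation disabled and nothing to evaluate: take the train-only
    # checkpoint / early-stopping path
    should_train_only_check = not enable_validation and should_skip_eval
    return should_skip_eval, should_train_only_check

assert train_only_flags([10, 10], True) == (False, False)  # normal train+val run
assert train_only_flags([0], True) == (True, False)        # val configured but 0 batches
assert train_only_flags([0], False) == (True, True)        # pure training run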
@@ -890,8 +893,8 @@ def increment_accumulated_grad_global_step(self):
     def _accumulated_batches_reached(self):
         return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0

-    def _num_training_batches_reached(self):
-        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
+    def _num_training_batches_reached(self, is_last_batch=False):
+        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch

     def should_accumulate(self):
         # checks if backward or backward + optimizer step (via closure)
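Read together with the first hunk, the predicate now ends the epoch either when the configured batch budget is spent or when the dataloader reports its last batch. The second condition matters when `num_training_batches` is unknown or infinite, where the equality check alone would never fire. A minimal standalone restatement, with plain arguments standing in for the trainer attributes:

def num_training_batches_reached(batch_idx, num_training_batches, is_last_batch=False):
    # True once the batch budget is spent OR the loader signalled its final
    # batch; the flag covers loaders with num_training_batches == float('inf')
    return (batch_idx + 1) == num_training_batches or is_last_batch

assert num_training_batches_reached(4, 5)          # budget of 5 batches spent
assert not num_training_batches_reached(3, 5)      # mid-epoch
assert num_training_batches_reached(3, float('inf'), is_last_batch=True)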