 
 from collections import OrderedDict
 from contextlib import contextmanager, suppress
-from copy import copy, deepcopy
+from copy import copy
 from functools import partial, update_wrapper
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -478,7 +478,6 @@ def run_training_epoch(self):
 
         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
-
         batch_idx = None
         is_last_batch = None
 
@@ -525,8 +524,7 @@ def run_training_epoch(self):
             self.save_loggers_on_train_batch_end()
 
             # update LR schedulers
-            monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
-            self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
+            self.update_lr_schedulers('step')
             self.trainer.checkpoint_connector.has_trained = True
 
             self.total_batch_idx += 1
@@ -567,7 +565,7 @@ def run_training_epoch(self):
 
         # update epoch level lr_schedulers if no val loop outside train loop is triggered
         if not should_check_val or should_train_only:
-            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+            self.update_lr_schedulers('epoch')
 
         if should_train_only:
             self.check_checkpoint_callback(True)
@@ -863,17 +861,16 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):
         # track gradients
         result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)
 
-    def update_train_loop_lr_schedulers(self, monitor_metrics=None):
-        num_accumulated_batches_reached = self._accumulated_batches_reached()
-        num_training_batches_reached = self._num_training_batches_reached()
-
-        if num_accumulated_batches_reached or num_training_batches_reached:
-            # update lr
-            self.trainer.optimizer_connector.update_learning_rates(
-                interval="step",
-                monitor_metrics=monitor_metrics,
-                opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
-            )
+    def update_lr_schedulers(self, interval: str) -> None:
+        if interval == "step":
+            finished_accumulation = self._accumulated_batches_reached()
+            finished_epoch = self._num_training_batches_reached()
+            if not finished_accumulation and not finished_epoch:
+                return
+        self.trainer.optimizer_connector.update_learning_rates(
+            interval=interval,
+            opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
+        )
 
     def increment_accumulated_grad_global_step(self):
         num_accumulated_batches_reached = self._accumulated_batches_reached()
@@ -897,15 +894,21 @@ def should_accumulate(self):
 
     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
         """ Decide if we should run validation. """
-
         if not self.trainer.enable_validation:
             return False
 
-        # check if this epoch is eligible to run validation
-        if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
+        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+        if not is_val_check_epoch:
             return False
 
         # val_check_batch is inf for iterable datasets with no length defined
+        is_infinite_dataset = self.trainer.val_check_batch == float('inf')
+        if on_epoch and is_last_batch and is_infinite_dataset:
+            return True
+
+        if on_epoch and self.trainer.should_stop:
+            return True
+
         # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = False
         if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
@@ -915,12 +918,9 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
 
         # Note: num_training_batches is also inf for iterable datasets with no length defined
         epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
-        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
 
         if on_epoch:
-            return (
-                is_val_check_batch and epoch_end_val_check
-            ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
+            return is_val_check_batch and epoch_end_val_check
         else:
             return is_val_check_batch and not epoch_end_val_check
 
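For readers skimming the diff, here is a minimal, self-contained sketch of how the consolidated `update_lr_schedulers(interval)` hook is driven from `run_training_epoch`: the `'step'` interval is gated on gradient accumulation or epoch completion, while `'epoch'` delegates unconditionally. The `OptimizerConnector` and `TrainLoop` classes and counter values below are simplified stand-ins for illustration, not the real Lightning implementations.

```python
class OptimizerConnector:
    """Stand-in for trainer.optimizer_connector; just records the call."""

    def update_learning_rates(self, interval, opt_indices=None):
        print(f"update {interval!r} schedulers for optimizers {opt_indices}")


class TrainLoop:
    """Toy loop mirroring only the scheduler-update logic from the diff."""

    def __init__(self, accumulate_grad_batches=2, num_training_batches=6):
        self.optimizer_connector = OptimizerConnector()
        self.accumulate_grad_batches = accumulate_grad_batches
        self.num_training_batches = num_training_batches
        self.batch_idx = 0

    def _accumulated_batches_reached(self):
        return (self.batch_idx + 1) % self.accumulate_grad_batches == 0

    def _num_training_batches_reached(self):
        return self.batch_idx + 1 == self.num_training_batches

    def update_lr_schedulers(self, interval: str) -> None:
        # 'step' schedulers only fire once accumulation (or the epoch) finishes;
        # 'epoch' schedulers are updated unconditionally by the epoch-end caller.
        if interval == "step":
            if not self._accumulated_batches_reached() and not self._num_training_batches_reached():
                return
        self.optimizer_connector.update_learning_rates(interval=interval, opt_indices=[0])


loop = TrainLoop()
for loop.batch_idx in range(loop.num_training_batches):  # mirrors the batch loop
    loop.update_lr_schedulers('step')
loop.update_lr_schedulers('epoch')  # mirrors the epoch-end call site
```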
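Similarly, a condensed sketch of the decision order the reworked `_should_check_val_fx` follows, written as a free function whose keyword defaults are hypothetical stand-ins for the `Trainer` attributes; it omits the `limit_train_batches` handling and is only meant to illustrate the early-return structure:

```python
def should_check_val(
    batch_idx: int,
    is_last_batch: bool,
    on_epoch: bool = False,
    *,
    enable_validation: bool = True,
    current_epoch: int = 0,
    check_val_every_n_epoch: int = 1,
    val_check_batch: float = float('inf'),
    should_stop: bool = False,
    num_training_batches: float = float('inf'),
) -> bool:
    if not enable_validation:
        return False

    # only epochs lining up with check_val_every_n_epoch are eligible
    if (current_epoch + 1) % check_val_every_n_epoch != 0:
        return False

    # iterable dataset with no length: validate on the last batch at epoch end
    if on_epoch and is_last_batch and val_check_batch == float('inf'):
        return True

    # early stopping requested: run one last validation at epoch end
    if on_epoch and should_stop:
        return True

    is_val_check_batch = val_check_batch != float('inf') and (batch_idx + 1) % val_check_batch == 0
    epoch_end_val_check = (batch_idx + 1) % num_training_batches == 0
    return is_val_check_batch and (epoch_end_val_check if on_epoch else not epoch_end_val_check)


print(should_check_val(batch_idx=9, is_last_batch=True, on_epoch=True))  # True: infinite dataset, epoch end
print(should_check_val(batch_idx=4, is_last_batch=False,
                       val_check_batch=5, num_training_batches=10))      # True: mid-epoch val check
```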