@@ -121,12 +121,6 @@ def on_train_end(self):
121121 return
122122 self ._teardown_already_run = True
123123
124- # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
125- # when a checkpoint was saved at the last step
126- self .trainer .global_step -= 1
127- self .check_checkpoint_callback (should_update = True , is_last = True )
128- self .trainer .global_step += 1
129-
130124 # hook
131125 self .trainer .call_hook ("on_train_end" )
132126
@@ -145,28 +139,6 @@ def on_train_end(self):
145139 # reset bookkeeping
146140 self .trainer ._running_stage = None
147141
148- def check_checkpoint_callback (self , should_update , is_last = False ):
149- # TODO bake this logic into the ModelCheckpoint callback
150- if should_update and self .trainer .checkpoint_connector .has_trained :
151- callbacks = self .trainer .checkpoint_callbacks
152-
153- if is_last and any (cb .save_last and cb .verbose for cb in callbacks ):
154- rank_zero_info ("Saving latest checkpoint..." )
155-
156- model = self .trainer .lightning_module
157-
158- for cb in callbacks :
159- cb .on_validation_end (self .trainer , model )
160-
161- def check_early_stopping_callback (self , should_update ):
162- # TODO bake this logic into the EarlyStopping callback
163- if should_update and self .trainer .checkpoint_connector .has_trained :
164- callbacks = [c for c in self .trainer .callbacks if isinstance (c , EarlyStopping )]
165- model = self .trainer .lightning_module
166-
167- for cb in callbacks :
168- cb .on_validation_end (self .trainer , model )
169-
170142 def on_train_epoch_start (self , epoch ):
171143
172144 # update training progress in trainer
@@ -562,15 +534,14 @@ def run_training_epoch(self):
562534 if (val_loop_called and not should_check_val ) or should_train_only :
563535 self .trainer .optimizer_connector .update_learning_rates (interval = 'epoch' )
564536
565- if should_train_only :
566- self .check_checkpoint_callback (True )
567- self .check_early_stopping_callback (True )
568-
569537 if should_check_val :
570538 self .trainer .validating = True
571539 self .trainer .run_evaluation (on_epoch = True )
572540 self .trainer .training = True
573541
542+ if should_train_only :
543+ self .trainer .call_hook ('on_train_epoch_final_end' )
544+
574545 # increment the global step once
575546 # progress global step according to grads progress
576547 self .increment_accumulated_grad_global_step ()
0 commit comments