 import numpy as np
 import torch

-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.step_result import Result
@@ -161,7 +161,7 @@ def on_train_end(self):
         # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
         # when a checkpoint was saved at the last step
         self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_save=True, is_last=True)
+        self.check_checkpoint_callback(should_update=True, is_last=True)
         self.trainer.global_step += 1

         # hook
@@ -184,18 +184,27 @@ def on_train_end(self):
         model.cpu()
         torch.cuda.empty_cache()

-    def check_checkpoint_callback(self, should_save, is_last=False):
-        # TODO bake this logic into the checkpoint callback
-        if should_save and self.trainer.checkpoint_connector.has_trained:
-            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+    def check_checkpoint_callback(self, should_update, is_last=False):
+        # TODO bake this logic into the ModelCheckpoint callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = self.trainer.checkpoint_callbacks

-            if is_last and any(c.save_last for c in checkpoint_callbacks):
+            if is_last and any(cb.save_last for cb in callbacks):
                 rank_zero_info("Saving latest checkpoint...")

             model = self.trainer.get_model()

-            for callback in checkpoint_callbacks:
-                callback.on_validation_end(self.trainer, model)
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
+    def check_early_stopping_callback(self, should_update):
+        # TODO bake this logic into the EarlyStopping callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
+            model = self.trainer.get_model()
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)

     def on_train_epoch_start(self, epoch):

@@ -521,7 +530,6 @@ def tbptt_split_batch(self, batch):
         return splits

     def run_training_epoch(self):
-
         # get model
         model = self.trainer.get_model()

@@ -584,11 +592,12 @@ def run_training_epoch(self):
             self.trainer.checkpoint_connector.has_trained = True

             # max steps reached, end training
-            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
-                accumulation_done = self._accumulated_batches_reached()
-                # Ensure accumulation across batches has completed before breaking loop
-                if accumulation_done:
-                    break
+            if (
+                self.trainer.max_steps is not None
+                and self.trainer.max_steps == self.trainer.global_step + 1
+                and self._accumulated_batches_reached()
+            ):
+                break

             # end epoch early
             # stop when the flag is changed or we've gone past the amount
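For clarity, the consolidated condition above only breaks the loop once the current gradient-accumulation window is complete. A minimal standalone illustration (not part of the patch), assuming accumulate_grad_batches=2:

def reaches_max_steps(global_step, batch_idx, max_steps=10, accumulate_grad_batches=2):
    # mirrors the break condition: max_steps is hit AND the accumulation window just finished
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    return max_steps is not None and max_steps == global_step + 1 and accumulation_done

assert not reaches_max_steps(global_step=9, batch_idx=18)  # mid-window: keep looping
assert reaches_max_steps(global_step=9, batch_idx=19)      # window complete: break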
@@ -599,7 +608,7 @@ def run_training_epoch(self):
             self.trainer.total_batch_idx += 1

             # stop epoch if we limited the number of training batches
-            if (batch_idx + 1) >= self.trainer.num_training_batches:
+            if self._num_training_batches_reached(is_last_batch):
                 break

             # progress global step according to grads progress
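The `is_last_batch` argument matters mostly for iterable-style datasets, where `num_training_batches` can be `float("inf")` and only the dataloader can signal the end of the epoch. A small standalone sketch of the stop condition (not part of the patch):

def num_training_batches_reached(batch_idx, num_training_batches, is_last_batch=False):
    # the epoch also ends when the dataloader reports its last batch,
    # even if the total batch count is unknown (infinite/iterable datasets)
    return (batch_idx + 1) == num_training_batches or is_last_batch

assert not num_training_batches_reached(batch_idx=4, num_training_batches=float("inf"))
assert num_training_batches_reached(batch_idx=4, num_training_batches=float("inf"), is_last_batch=True)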
@@ -613,8 +622,20 @@ def run_training_epoch(self):
             epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
         )

-        # when no val loop is present or fast-dev-run still need to call checkpoints
-        self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
+        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        if should_check_val:
+            self.trainer.run_evaluation(on_epoch=True)
+            # reset stage to train
+            self.trainer.logger_connector.set_stage("train")
+
+        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
+        should_train_only = self.trainer.disable_validation or should_skip_eval
+
+        if should_train_only:
+            # update epoch level lr_schedulers
+            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+            self.check_checkpoint_callback(True)
+            self.check_early_stopping_callback(True)

         # increment the global step once
         # progress global step according to grads progress
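From the user side, the `should_train_only` branch is what lets checkpointing and early stopping run on a training metric when there is no validation loop. A usage sketch (an assumption-laden example, not part of the patch: Lightning ~1.1-era `Trainer` API, and a model that logs `train_loss` via `self.log` in `training_step`):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

trainer = pl.Trainer(
    max_epochs=20,
    limit_val_batches=0,  # no validation loop; epoch-end checkpoint/early-stopping checks still fire
    callbacks=[
        ModelCheckpoint(monitor="train_loss", save_last=True),
        EarlyStopping(monitor="train_loss"),
    ],
)
# trainer.fit(model, train_dataloader)  # model is assumed to log "train_loss" in training_step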
@@ -840,25 +861,33 @@ def increment_accumulated_grad_global_step(self):
     def _accumulated_batches_reached(self):
         return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0

-    def _num_training_batches_reached(self):
-        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
+    def _num_training_batches_reached(self, is_last_batch=False):
+        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch

     def should_accumulate(self):
         # checks if backward or backward + optimizer step (via closure)
         accumulation_done = self._accumulated_batches_reached()
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)

-    def should_check_val_fx(self, batch_idx, is_last_batch):
+    def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
         is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         can_check_val = self.trainer.enable_validation and is_val_check_epoch
-        should_check_val = is_val_check_batch or self.trainer.should_stop
         is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
+        epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
+
+        should_check_val = (
+            (is_val_check_batch and epoch_end_val_check)
+            or self.trainer.should_stop
+            or is_last_batch_for_infinite_dataset
+        ) if on_epoch else (
+            is_val_check_batch
+            and not epoch_end_val_check
+        )

-        return should_check_val
+        return should_check_val and can_check_val

     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
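Worked examples of the new `on_epoch` split in `should_check_val_fx` (a standalone sketch, not part of the patch; the `enable_validation` / `check_val_every_n_epoch` gating via `can_check_val` is omitted for brevity):

def should_check_val(batch_idx, val_check_batch, num_training_batches, on_epoch,
                     should_stop=False, is_last_batch=False):
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    is_last_batch_for_infinite_dataset = is_last_batch and val_check_batch == float("inf")
    epoch_end_val_check = val_check_batch == num_training_batches
    if on_epoch:
        return (is_val_check_batch and epoch_end_val_check) or should_stop or is_last_batch_for_infinite_dataset
    return is_val_check_batch and not epoch_end_val_check

# default val_check_interval=1.0 -> val_check_batch == num_training_batches,
# so validation runs from the epoch-end call, never mid-epoch:
assert not should_check_val(batch_idx=99, val_check_batch=100, num_training_batches=100, on_epoch=False)
assert should_check_val(batch_idx=99, val_check_batch=100, num_training_batches=100, on_epoch=True)
# val_check_interval=0.5 -> validation runs mid-epoch and is not repeated at epoch end:
assert should_check_val(batch_idx=49, val_check_batch=50, num_training_batches=100, on_epoch=False)
assert not should_check_val(batch_idx=99, val_check_batch=50, num_training_batches=100, on_epoch=True)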