@@ -222,7 +222,7 @@ def __init__(
         self.save_weights_only = save_weights_only
         self.auto_insert_metric_name = auto_insert_metric_name
         self._save_on_train_epoch_end = save_on_train_epoch_end
-        self._last_global_step_saved = -1
+        self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score = None
         self.best_k_models = {}
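
A minimal sketch of the intent behind the new default of 0 (assumption: _should_skip_saving_checkpoint, which is not shown in this hunk, compares _last_global_step_saved against trainer.global_step):

# Hypothetical, simplified guard -- not the library's actual implementation.
class _DedupGuardSketch:
    def __init__(self) -> None:
        # 0 means "nothing to save yet": before any optimizer step there is no
        # progress worth checkpointing, so step 0 is treated as already saved
        self._last_global_step_saved = 0

    def should_skip(self, global_step: int) -> bool:
        return self._last_global_step_saved == global_step
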
@@ -275,8 +275,7 @@ def on_train_batch_end(
275275 """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
276276 if self ._should_skip_saving_checkpoint (trainer ):
277277 return
278- step = trainer .global_step
279- skip_batch = self ._every_n_train_steps < 1 or ((step + 1 ) % self ._every_n_train_steps != 0 )
278+ skip_batch = self ._every_n_train_steps < 1 or (trainer .global_step % self ._every_n_train_steps != 0 )
280279
281280 train_time_interval = self ._train_time_interval
282281 skip_time = True
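
A hedged worked example of the new skip condition (assumption, not shown in this diff: by the time on_train_batch_end runs, trainer.global_step already counts the batch that just finished, so the old "+ 1" compensation is no longer needed):

# with every_n_train_steps = 3, saves still land after steps 3, 6, 9, ...
every_n_train_steps = 3
for global_step in range(1, 10):  # optimizer steps already taken when the hook fires
    skip_batch = every_n_train_steps < 1 or (global_step % every_n_train_steps != 0)
    if not skip_batch:
        print(f"would save at global_step={global_step}")  # prints 3, 6, 9
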
@@ -297,16 +296,13 @@ def on_train_batch_end(

     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the training epoch."""
-        # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
-        trainer.fit_loop.global_step -= 1
         if (
             not self._should_skip_saving_checkpoint(trainer)
             and self._save_on_train_epoch_end
             and self._every_n_epochs > 0
             and (trainer.current_epoch + 1) % self._every_n_epochs == 0
         ):
             self.save_checkpoint(trainer)
-        trainer.fit_loop.global_step += 1

     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the validation stage."""
@@ -329,11 +325,8 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
             return
         if self.verbose:
             rank_zero_info("Saving latest checkpoint...")
-        # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
-        monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1)
-        trainer.fit_loop.global_step -= 1
+        monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step)
         self._save_last_checkpoint(trainer, monitor_candidates)
-        trainer.fit_loop.global_step += 1

     def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
@@ -368,12 +361,8 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         """
         self._validate_monitor_key(trainer)

-        # track epoch when ckpt was last checked
-        global_step = trainer.global_step
-        self._last_global_step_saved = global_step
-
         # what can be monitored
-        monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=global_step)
+        monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step)

         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
@@ -638,6 +627,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> D
     def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if not self.save_last:
             return
+        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)

         filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
         # set the last model path before saving because it will be part of the state.
@@ -649,9 +639,9 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
     def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if self.monitor is None or self.save_top_k == 0:
             return
+        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)

         current = monitor_candidates.get(self.monitor)
-
         if self.check_monitor_top_k(trainer, current):
             self._update_best_and_save(current, trainer, monitor_candidates)
         elif self.verbose:
@@ -662,6 +652,7 @@ def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict
     def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if self.monitor is not None or self.save_top_k == 0:
             return
+        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)

         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
         # set the best model path before saving because it will be part of the state.
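
A hedged sketch of the pattern these last three hunks converge on (class and method bodies simplified, not the library's actual code): each helper that writes a checkpoint records the step it saved at, preferring the "step" value captured alongside the monitored metrics and falling back to the live counter, so the duplicate-save guard stays consistent no matter which hook triggered the save.

from typing import Any, Dict

class _SaveBookkeepingSketch:
    def __init__(self) -> None:
        self._last_global_step_saved = 0

    def _record_saved_step(self, monitor_candidates: Dict[str, Any], global_step: int) -> None:
        # prefer the step snapshotted with the metrics, fall back to the live counter
        self._last_global_step_saved = monitor_candidates.get("step", global_step)

    def _should_skip_saving_checkpoint(self, global_step: int) -> bool:
        # a second hook firing at the same step becomes a no-op
        return self._last_global_step_saved == global_step
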