@@ -93,8 +93,25 @@ class ModelCheckpoint(Callback):
9393 save_weights_only: if ``True``, then only the model's weights will be
9494 saved (``model.save_weights(filepath)``), else the full model
9595 is saved (``model.save(filepath)``).
96+ every_n_train_steps: Number of training steps between checkpoints.
97+ If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
98+ To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
99+ This must be mutually exclusive with ``every_n_val_epochs``.
100+ every_n_val_epochs: Number of validation epochs between checkpoints.
101+ If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end.
102+ To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
103+ This must be mutually exclusive with ``every_n_train_steps``.
104+ Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
105+ ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
106+ will only save checkpoints at epochs 0 < E <= N
107+ where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
96108 period: Interval (number of epochs) between checkpoints.
97109
110+ .. warning::
111+ This argument has been deprecated in v1.3 and will be removed in v1.5.
112+
113+ Use ``every_n_val_epochs`` instead.
114+
98115 Note:
99116 For extra customization, ModelCheckpoint includes the following attributes:
100117
@@ -165,16 +182,17 @@ def __init__(
165182 save_top_k : Optional [int ] = None ,
166183 save_weights_only : bool = False ,
167184 mode : str = "min" ,
168- period : int = 1 ,
169- auto_insert_metric_name : bool = True
185+ auto_insert_metric_name : bool = True ,
186+ every_n_train_steps : Optional [int ] = None ,
187+ every_n_val_epochs : Optional [int ] = None ,
188+ period : Optional [int ] = None ,
170189 ):
171190 super ().__init__ ()
172191 self .monitor = monitor
173192 self .verbose = verbose
174193 self .save_last = save_last
175194 self .save_top_k = save_top_k
176195 self .save_weights_only = save_weights_only
177- self .period = period
178196 self .auto_insert_metric_name = auto_insert_metric_name
179197 self ._last_global_step_saved = - 1
180198 self .current_score = None
@@ -188,6 +206,7 @@ def __init__(
188206
189207 self .__init_monitor_mode (monitor , mode )
190208 self .__init_ckpt_dir (dirpath , filename , save_top_k )
209+ self .__init_triggers (every_n_train_steps , every_n_val_epochs , period )
191210 self .__validate_init_configuration ()
192211
193212 def on_pretrain_routine_start (self , trainer , pl_module ):
@@ -197,10 +216,26 @@ def on_pretrain_routine_start(self, trainer, pl_module):
197216 self .__resolve_ckpt_dir (trainer )
198217 self .save_function = trainer .save_checkpoint
199218
200- def on_validation_end (self , trainer , pl_module ):
def on_train_batch_end(self, trainer, *args, **kwargs) -> None:
    """Save a checkpoint at the end of a training batch when the step trigger fires.

    Nothing is saved when ``_should_skip_saving_checkpoint`` reports that saving
    must be skipped, when ``_every_n_train_steps`` is below 1 (trigger disabled),
    or when ``trainer.global_step + 1`` is not a multiple of ``_every_n_train_steps``.

    Args:
        trainer: object exposing ``global_step``; forwarded to ``save_checkpoint``.
        *args: ignored.
        **kwargs: ignored.
    """
    if self._should_skip_saving_checkpoint(trainer):
        return

    completed_steps = trainer.global_step + 1
    interval = self._every_n_train_steps
    # The interval is checked first so a disabled trigger (0) never reaches the modulo.
    if interval >= 1 and completed_steps % interval == 0:
        self.save_checkpoint(trainer)
228+
def on_validation_end(self, trainer, *args, **kwargs) -> None:
    """Save a checkpoint at the end of the validation loop when the epoch trigger fires.

    Nothing is saved when ``_should_skip_saving_checkpoint`` reports that saving
    must be skipped, when ``_every_n_val_epochs`` is below 1 (trigger disabled),
    or when ``trainer.current_epoch + 1`` is not a multiple of ``_every_n_val_epochs``.

    Args:
        trainer: object exposing ``current_epoch``; forwarded to ``save_checkpoint``.
        *args: ignored.
        **kwargs: ignored.
    """
    if self._should_skip_saving_checkpoint(trainer):
        return

    interval = self._every_n_val_epochs
    if interval < 1:
        # Trigger disabled; returning here also keeps the modulo below from dividing by zero.
        return

    if (trainer.current_epoch + 1) % interval != 0:
        return

    self.save_checkpoint(trainer)
205240
206241 def on_save_checkpoint (self , trainer , pl_module , checkpoint : Dict [str , Any ]) -> Dict [str , Any ]:
@@ -228,20 +263,8 @@ def save_checkpoint(self, trainer, unused: Optional = None):
228263 " has been removed. Support for the old signature will be removed in v1.5" , DeprecationWarning
229264 )
230265
231- epoch = trainer .current_epoch
232266 global_step = trainer .global_step
233267
234- from pytorch_lightning .trainer .states import TrainerState
235- if (
236- trainer .fast_dev_run # disable checkpointing with fast_dev_run
237- or trainer .state != TrainerState .FITTING # don't save anything during non-fit
238- or trainer .sanity_checking # don't save anything during sanity check
239- or self .period < 1 # no models are saved
240- or (epoch + 1 ) % self .period # skip epoch
241- or self ._last_global_step_saved == global_step # already saved at the last step
242- ):
243- return
244-
245268 self ._add_backward_monitor_support (trainer )
246269 self ._validate_monitor_key (trainer )
247270
@@ -260,9 +283,32 @@ def save_checkpoint(self, trainer, unused: Optional = None):
260283 # Mode 3: save last checkpoints
261284 self ._save_last_checkpoint (trainer , monitor_candidates )
262285
def _should_skip_saving_checkpoint(self, trainer) -> bool:
    """Tell whether the current trainer state rules out saving a checkpoint.

    The result is truthy when any of the following holds:

    * ``trainer.fast_dev_run`` is set (checkpointing is disabled for it),
    * ``trainer.state`` is not ``TrainerState.FITTING``,
    * ``trainer.sanity_checking`` is set,
    * a checkpoint was already recorded for ``trainer.global_step``.

    Args:
        trainer: object exposing ``fast_dev_run``, ``state``, ``sanity_checking``
            and ``global_step``.

    Returns:
        The first truthy condition above, otherwise the falsy result of the last one.
    """
    # Imported inside the method, as in the original code.
    from pytorch_lightning.trainer.states import TrainerState

    trainer_rules_it_out = (
        trainer.fast_dev_run
        or trainer.state != TrainerState.FITTING
        or trainer.sanity_checking
    )
    # Avoid writing a second checkpoint for a step that has already been saved.
    return trainer_rules_it_out or self._last_global_step_saved == trainer.global_step
294+
263295 def __validate_init_configuration (self ):
264296 if self .save_top_k is not None and self .save_top_k < - 1 :
265297 raise MisconfigurationException (f'Invalid value for save_top_k={ self .save_top_k } . Must be None or >= -1' )
# Trigger intervals must be non-negative; a value of 0 disables that trigger.
if self._every_n_train_steps < 0:
    raise MisconfigurationException(
        f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
    )
if self._every_n_val_epochs < 0:
    raise MisconfigurationException(
        f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
    )
# The step-based and the epoch-based triggers are mutually exclusive.
if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
    # Every fragment that interpolates a value needs its own ``f`` prefix: adjacent
    # string literals are concatenated, but only f-prefixed pieces are formatted.
    raise MisconfigurationException(
        f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
        f' and every_n_val_epochs={self._every_n_val_epochs}.'
        ' Both cannot be enabled at the same time.'
    )
266312 if self .monitor is None :
267313 # None: save last epoch, -1: save all epochs, 0: nothing is saved
268314 if self .save_top_k not in (None , - 1 , 0 ):
@@ -309,6 +355,46 @@ def __init_monitor_mode(self, monitor, mode):
309355
310356 self .kth_value , self .mode = mode_dict [mode ]
311357
def __init_triggers(
    self,
    every_n_train_steps: Optional[int],
    every_n_val_epochs: Optional[int],
    period: Optional[int],
) -> None:
    """Resolve the checkpoint trigger intervals and store them on the instance.

    Sets ``_every_n_train_steps``, ``_every_n_val_epochs`` and ``_period``.

    Args:
        every_n_train_steps: requested step interval; ``None`` or ``0`` is stored as ``0``.
        every_n_val_epochs: requested validation-epoch interval; ``None`` or ``0``
            is stored as ``0``, unless both intervals are ``None``, in which case
            it defaults to ``1``.
        period: deprecated alias; when not ``None`` it overrides the resolved
            ``every_n_val_epochs`` and a ``DeprecationWarning`` is emitted.
    """
    if every_n_train_steps is None and every_n_val_epochs is None:
        # Neither trigger was requested: fall back to one checkpoint per validation epoch.
        val_interval, step_interval = 1, 0
        log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
    else:
        # A trigger left as None is stored as 0, i.e. disabled.
        val_interval = every_n_val_epochs or 0
        step_interval = every_n_train_steps or 0

    # For backwards compatibility the deprecated ``period`` wins over every_n_val_epochs.
    if period is not None:
        rank_zero_warn(
            'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
            ' Please use `every_n_val_epochs` instead.', DeprecationWarning
        )
        val_interval = period

    self._every_n_val_epochs = val_interval
    self._every_n_train_steps = step_interval
    # ``_period`` backs the deprecated ``period`` property.
    self._period = val_interval
381+
@property
def period(self) -> Optional[int]:
    """Deprecated accessor for the stored ``_period`` value.

    Emits a ``DeprecationWarning`` through ``rank_zero_warn`` on every read.

    Returns:
        The value currently held in ``_period``.
    """
    deprecation_message = (
        'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
        ' Please use `every_n_val_epochs` instead.'
    )
    rank_zero_warn(deprecation_message, DeprecationWarning)
    return self._period
389+
@period.setter
def period(self, value: Optional[int]) -> None:
    """Deprecated mutator for the stored ``_period`` value.

    Emits a ``DeprecationWarning`` through ``rank_zero_warn`` on every write.

    Args:
        value: new value stored in ``_period``.
    """
    deprecation_message = (
        'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
        ' Please use `every_n_val_epochs` instead.'
    )
    rank_zero_warn(deprecation_message, DeprecationWarning)
    # NOTE(review): only ``_period`` is written here; ``_every_n_val_epochs``, which
    # on_validation_end reads, is left untouched. Confirm that is intended.
    self._period = value
397+
312398 @rank_zero_only
313399 def _del_model (self , filepath : str ):
314400 if self ._fs .exists (filepath ):
@@ -422,11 +508,8 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
422508
423509 """
424510 filename = self ._format_checkpoint_name (
425- self .filename ,
426- epoch ,
427- step ,
428- metrics ,
429- auto_insert_metric_name = self .auto_insert_metric_name )
511+ self .filename , epoch , step , metrics , auto_insert_metric_name = self .auto_insert_metric_name
512+ )
430513
431514 if ver is not None :
432515 filename = self .CHECKPOINT_JOIN_CHAR .join ((filename , f"v{ ver } " ))
@@ -581,9 +664,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A
581664 self ._save_model (trainer , filepath )
582665
583666 if (
584- self .save_top_k is None
585- and self .best_model_path
586- and self .best_model_path != filepath
667+ self .save_top_k is None and self .best_model_path and self .best_model_path != filepath
587668 and trainer .is_global_zero
588669 ):
589670 self ._del_model (self .best_model_path )
0 commit comments