Skip to content

Commit 78b2562

Browse files
committed
pre-commit
1 parent 02ec2d2 commit 78b2562

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,9 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
238238

239239
def _should_skip_saving_checkpoint(self, trainer) -> bool:
240240
return (
241-
trainer.fast_dev_run # disable checkpointing with fast_dev_run
241+
trainer.fast_dev_run # disable checkpointing with fast_dev_run
242242
or trainer.running_sanity_check # don't save anything during sanity check
243-
or self.save_top_k == 0 # no models are saved
243+
or self.save_top_k == 0 # no models are saved
244244
or self._last_global_step_saved == global_step # already saved at the last step
245245
)
246246

@@ -282,9 +282,13 @@ def __validate_init_configuration(self):
282282
if self.save_top_k is not None and self.save_top_k < -1:
283283
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
284284
if self.every_n_epochs == 0 or self.every_n_epochs < -1:
285-
raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1')
285+
raise MisconfigurationException(
286+
f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1'
287+
)
286288
if self.every_n_batches == 0 or self.every_n_batches < -1:
287-
raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1')
289+
raise MisconfigurationException(
290+
f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1'
291+
)
288292
if self.monitor is None:
289293
# None: save last epoch, -1: save all epochs, 0: nothing is saved
290294
if self.save_top_k not in [None, -1, 0]:

tests/checkpointing/test_model_checkpoint.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -499,21 +499,19 @@ def test_none_monitor_top_k(tmpdir):
499499
ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
500500
ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
501501

502+
502503
def test_invalid_every_n_epoch(tmpdir):
503504
""" Test that an exception is raised for every_n_epochs = 0 or < -1. """
504-
with pytest.raises(
505-
MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'
506-
):
505+
with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'):
507506
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0)
508-
with pytest.raises(
509-
MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'
510-
):
507+
with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'):
511508
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2)
512509

513510
# These should not fail
514511
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1)
515512
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3)
516513

514+
517515
def test_invalid_every_n_batches(tmpdir):
518516
""" Test that an exception is raised for every_n_batches = 0 or < -1. """
519517
with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'):

0 commit comments

Comments
 (0)