
Commit 579ddce

Disable {save,check}_on_train_epoch_end with check_val_every_n_epoch>1 (#9156)
1 parent 68dcd06 commit 579ddce
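
What changed, in short: both `EarlyStopping` and `ModelCheckpoint` pick automatically whether to run their check at the end of each training epoch or at the end of validation. The old heuristic looked only at `val_check_interval`, so with `check_val_every_n_epoch>1` they still fired on every train epoch end, including epochs where validation (and the monitored metric) never happened. A minimal sketch of the configuration this commit fixes, assuming the public Trainer and callback APIs of this era ("foo" is a stand-in metric name):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# "foo" is only logged during validation, which runs every 2nd epoch.
# After this commit, the stopping check is deferred to validation end
# instead of firing on train epoch ends where "foo" does not exist yet.
trainer = Trainer(
    max_epochs=10,
    check_val_every_n_epoch=2,
    callbacks=[EarlyStopping(monitor="foo")],
)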

File tree

5 files changed: +70 -3 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   [#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))


+- Fixed `EarlyStopping` running on train epoch end when `check_val_every_n_epoch>1` is set ([#9156](https://github.com/PyTorchLightning/pytorch-lightning/pull/9156))
+
+
 - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))


pytorch_lightning/callbacks/early_stopping.py

Lines changed: 10 additions & 0 deletions

@@ -120,6 +120,16 @@ def __init__(
         )
         self.monitor = monitor or "early_stop_on"

+    @property
+    def state_key(self) -> str:
+        return self._generate_state_key(monitor=self.monitor, mode=self.mode)
+
+    def on_init_end(self, trainer: "pl.Trainer") -> None:
+        if self._check_on_train_epoch_end is None:
+            # if the user runs validation multiple times per training epoch or multiple training epochs without
+            # validation, then we run after validation instead of on train epoch end
+            self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
+
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)
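
The resolved condition is worth isolating: the callback runs on train epoch end only when validation happens exactly once per training epoch, and defers to validation end otherwise. A standalone sketch of that decision (the function below is illustrative, not library code; its two parameters are the Trainer arguments of the same names):

def check_on_train_epoch_end(val_check_interval: float, check_val_every_n_epoch: int) -> bool:
    # Hook on train epoch end only when validation runs exactly once per epoch.
    return val_check_interval == 1.0 and check_val_every_n_epoch == 1

assert check_on_train_epoch_end(1.0, 1) is True   # default schedule
assert check_on_train_epoch_end(0.3, 1) is False  # mid-epoch validation
assert check_on_train_epoch_end(1.0, 2) is False  # the case this commit fixes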

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 3 additions & 3 deletions

@@ -256,9 +256,9 @@ def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
         self.__resolve_ckpt_dir(trainer)
         self._save_function = trainer.save_checkpoint
         if self._save_on_train_epoch_end is None:
-            # if the user runs validation multiple times per training epoch, we try to save checkpoint after
-            # validation instead of on train epoch end
-            self._save_on_train_epoch_end = trainer.val_check_interval == 1.0
+            # if the user runs validation multiple times per training epoch or multiple training epochs without
+            # validation, then we run after validation instead of on train epoch end
+            self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1

     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._last_time_checked = time.monotonic()
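
Note that in both callbacks the heuristic is guarded by an `is None` check, so it only applies when the user did not decide explicitly. Judging from the private attributes above, each callback appears to take a matching constructor flag; a hedged sketch (argument names assumed from `_check_on_train_epoch_end` / `_save_on_train_epoch_end`, verify against your installed version):

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Explicit values skip the heuristic entirely: the `is None` branch is never
# taken, keeping the behavior of running on every train epoch end.
es = EarlyStopping(monitor="foo", check_on_train_epoch_end=True)
mc = ModelCheckpoint(dirpath="checkpoints/", save_on_train_epoch_end=True)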

tests/callbacks/test_early_stopping.py

Lines changed: 36 additions & 0 deletions

@@ -416,3 +416,39 @@ def test_multiple_early_stopping_callbacks(
         num_processes=num_processes,
     )
     trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    "case",
+    {
+        "val_check_interval": {"val_check_interval": 0.3, "limit_train_batches": 10, "max_epochs": 10},
+        "check_val_every_n_epoch": {"check_val_every_n_epoch": 2, "max_epochs": 5},
+    }.items(),
+)
+def test_check_on_train_epoch_end_smart_handling(tmpdir, case):
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            self.log("foo", 1)
+            return super().validation_step(batch, batch_idx)
+
+    case, kwargs = case
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_val_batches=1,
+        callbacks=EarlyStopping(monitor="foo"),
+        progress_bar_refresh_rate=0,
+        **kwargs,
+    )
+
+    side_effect = [(False, "A"), (True, "B")]
+    with mock.patch(
+        "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria", side_effect=side_effect
+    ) as es_mock:
+        trainer.fit(model)
+
+    assert es_mock.call_count == len(side_effect)
+    if case == "val_check_interval":
+        assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval)
+    else:
+        assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1
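
The final assertions encode where each parametrization must stop; worked out as plain arithmetic (mirroring the numbers in the test, not library calls):

# "val_check_interval" case: validation fires every int(10 * 0.3) == 3 training
# batches; stopping is requested on the 2nd validation run, so 2 * 3 == 6 steps.
assert 2 * int(10 * 0.3) == 6

# "check_val_every_n_epoch" case: validation fires after (zero-indexed) epochs
# 1 and 3; stopping on the 2nd run leaves current_epoch == 2 * 2 - 1 == 3.
assert 2 * 2 - 1 == 3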

tests/checkpointing/test_model_checkpoint.py

Lines changed: 18 additions & 0 deletions

@@ -1222,3 +1222,21 @@ def test_trainer_checkpoint_callback_bool(tmpdir):
     mc = ModelCheckpoint(dirpath=tmpdir)
     with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"):
         Trainer(checkpoint_callback=mc)
+
+
+def test_check_val_every_n_epochs_top_k_integration(tmpdir):
+    model = BoringModel()
+    mc = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", save_top_k=-1, filename="{epoch}")
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        max_epochs=5,
+        check_val_every_n_epoch=2,
+        callbacks=mc,
+        weights_summary=None,
+        logger=False,
+    )
+    trainer.fit(model)
+    assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"}
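
The expected filenames follow from the validation schedule: with `max_epochs=5` and `check_val_every_n_epoch=2`, validation runs only after (zero-indexed) epochs 1 and 3, and with this fix checkpoints are saved only at those points instead of on every train epoch end. A quick sanity check of that schedule (plain Python, not library code):

max_epochs, every_n = 5, 2
# Zero-indexed epochs whose 1-indexed number is divisible by check_val_every_n_epoch.
val_epochs = [e for e in range(max_epochs) if (e + 1) % every_n == 0]
assert val_epochs == [1, 3]
assert {f"epoch={e}.ckpt" for e in val_epochs} == {"epoch=1.ckpt", "epoch=3.ckpt"}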
