
Commit 04ae832

Add support for early stopping during training epoch end
1 parent b85cfbe commit 04ae832

File tree

2 files changed: +54, -12 lines


pytorch_lightning/callbacks/early_stopping.py

Lines changed: 15 additions & 4 deletions

@@ -80,6 +80,7 @@ def __init__(
         verbose: bool = False,
         mode: str = 'min',
         strict: bool = True,
+        during_training: bool = False,
     ):
         super().__init__()
         self.monitor = monitor
@@ -90,6 +91,7 @@ def __init__(
         self.wait_count = 0
         self.stopped_epoch = 0
         self.mode = mode
+        self.during_training = during_training
 
         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -129,15 +131,24 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) ->
             'patience': self.patience
         }
 
-    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
         self.wait_count = callback_state['wait_count']
         self.stopped_epoch = callback_state['stopped_epoch']
         self.best_score = callback_state['best_score']
         self.patience = callback_state['patience']
 
-    def on_validation_end(self, trainer, pl_module):
+    def _should_skip_check(self, trainer) -> bool:
         from pytorch_lightning.trainer.states import TrainerState
-        if trainer.state != TrainerState.FITTING or trainer.sanity_checking:
+        return trainer.state != TrainerState.FITTING or trainer.sanity_checking
+
+    def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
+        if not self.during_training or self._should_skip_check(trainer):
+            return
+        self._run_early_stopping_check(trainer)
+
+
+    def on_validation_end(self, trainer, pl_module):
+        if self.during_training or self._should_skip_check(trainer):
             return
 
         self._run_early_stopping_check(trainer)
@@ -153,7 +164,7 @@ def _run_early_stopping_check(self, trainer):
             trainer.fast_dev_run  # disable early_stopping with fast_dev_run
             or not self._validate_condition_metric(logs)  # short circuit if metric not present
         ):
-            return  # short circuit if metric not present
+            return
 
         current = logs.get(self.monitor)
 
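
With these changes, the stopping check can run at the end of every training epoch rather than only after validation: construct the callback with during_training=True and make sure the monitored key is logged from a training hook. Below is a minimal usage sketch assuming this commit's API; the module, the 'train_loss' key, and the Trainer arguments are illustrative assumptions, not part of the diff.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping


class TrainMetricModel(LightningModule):
    """Hypothetical module for illustration: logs the monitored value at train-epoch end."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        self.last_loss = loss.detach()  # stash for epoch-end logging (illustrative shortcut)
        return loss

    def training_epoch_end(self, outputs):
        # With during_training=True, EarlyStopping runs its check in on_train_epoch_end,
        # so the monitored key must be logged from a training hook like this one.
        self.log('train_loss', self.last_loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        x = torch.randn(64, 32)
        return DataLoader(TensorDataset(x, torch.randn(64, 1)), batch_size=16)


trainer = Trainer(
    max_epochs=50,  # assumed upper bound; early stopping is expected to end fitting sooner
    callbacks=[EarlyStopping(monitor='train_loss', patience=3, during_training=True)],
)
trainer.fit(TrainMetricModel())

Note that no validation loop is needed here: with during_training=True the check is driven entirely by the training loop.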

tests/callbacks/test_early_stopping.py

Lines changed: 39 additions & 8 deletions

@@ -333,46 +333,77 @@ def test_early_stopping_mode_options():
 
 class EarlyStoppingModel(BoringModel):
 
-    def __init__(self, expected_end_epoch):
+    def __init__(self, expected_end_epoch: int, during_training: bool):
         super().__init__()
         self.expected_end_epoch = expected_end_epoch
+        self.during_training = during_training
+
+    def training_epoch_end(self, outputs):
+        if not self.during_training:
+            return
+        losses = [8, 4, 2, 3, 4, 5, 8, 10]
+        loss = losses[self.current_epoch]
+        self.log('abc', torch.tensor(loss))
+        self.log('cba', torch.tensor(0))
 
     def validation_epoch_end(self, outputs):
+        if self.during_training:
+            return
         losses = [8, 4, 2, 3, 4, 5, 8, 10]
-        val_loss = losses[self.current_epoch]
-        self.log('abc', torch.tensor(val_loss))
+        loss = losses[self.current_epoch]
+        self.log('abc', torch.tensor(loss))
         self.log('cba', torch.tensor(0))
 
     def on_train_end(self) -> None:
         assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'
 
 
 @pytest.mark.parametrize(
-    "callbacks, expected_stop_epoch, accelerator, num_processes",
+    "callbacks, expected_stop_epoch, during_training, accelerator, num_processes",
     [
-        ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
+        ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, False, None, 1),
         ([EarlyStopping(monitor='cba', patience=3),
-          EarlyStopping(monitor='abc')], 3, None, 1),
+          EarlyStopping(monitor='abc')], 3, False, None, 1),
         pytest.param([EarlyStopping(monitor='abc'),
                       EarlyStopping(monitor='cba', patience=3)],
                      3,
+                     False,
                      'ddp_cpu',
                      2,
                      marks=RunIf(skip_windows=True)),
         pytest.param([EarlyStopping(monitor='cba', patience=3),
                       EarlyStopping(monitor='abc')],
                      3,
+                     False,
+                     'ddp_cpu',
+                     2,
+                     marks=RunIf(skip_windows=True)),
+        ([EarlyStopping(monitor='abc', during_training=True), EarlyStopping(monitor='cba', patience=3, during_training=True)], 3, True, None, 1),
+        ([EarlyStopping(monitor='cba', patience=3, during_training=True),
+          EarlyStopping(monitor='abc', during_training=True)], 3, True, None, 1),
+        pytest.param([EarlyStopping(monitor='abc', during_training=True),
+                      EarlyStopping(monitor='cba', patience=3, during_training=True)],
+                     3,
+                     True,
+                     'ddp_cpu',
+                     2,
+                     marks=RunIf(skip_windows=True)),
+        pytest.param([EarlyStopping(monitor='cba', patience=3, during_training=True),
+                      EarlyStopping(monitor='abc', during_training=True)],
+                     3,
+                     True,
                      'ddp_cpu',
                      2,
                      marks=RunIf(skip_windows=True)),
+
     ],
 )
 def test_multiple_early_stopping_callbacks(
-    tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, accelerator: Optional[str], num_processes: int
+    tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, during_training: bool, accelerator: Optional[str], num_processes: int
 ):
     """Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""
 
-    model = EarlyStoppingModel(expected_stop_epoch)
+    model = EarlyStoppingModel(expected_stop_epoch, during_training)
 
     trainer = Trainer(
         default_root_dir=tmpdir,