Skip to content

Commit 14b8dd4

Browse files
authored
[2/2] Remove training loop force calling early stopping callback (#7069)
* rebase * doc * Update training_loop.py * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md
1 parent a5ac3f8 commit 14b8dd4

File tree

4 files changed

+12
-14
lines changed

4 files changed

+12
-14
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
141141

142142
### Changed
143143

144+
145+
- Changed `EarlyStopping` callback from by default running `EarlyStopping.on_validation_end` if only training is run. Set `check_on_train_epoch_end` to run the callback at the end of the train epoch instead of at the end of the validation epoch ([#7069](https://github.com/PyTorchLightning/pytorch-lightning/pull/7069))
146+
147+
144148
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
145149

146150

@@ -224,6 +228,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
224228
### Removed
225229

226230

231+
- Removed training loop explicitly calling `EarlyStopping.on_validation_end` if no validation is run ([#7069](https://github.com/PyTorchLightning/pytorch-lightning/pull/7069))
232+
233+
227234
- Removed `automatic_optimization` as a property from the training loop in favor of `LightningModule.automatic_optimization` ([#7130](https://github.com/PyTorchLightning/pytorch-lightning/pull/7130))
228235

229236

pytorch_lightning/trainer/training_loop.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import numpy as np
2020
import torch
2121

22-
from pytorch_lightning.callbacks import EarlyStopping
2322
from pytorch_lightning.core.optimizer import LightningOptimizer
2423
from pytorch_lightning.core.step_result import Result
2524
from pytorch_lightning.plugins import ParallelPlugin
@@ -148,15 +147,6 @@ def check_checkpoint_callback(self, should_update, is_last=False):
148147
for cb in callbacks:
149148
cb.on_validation_end(self.trainer, model)
150149

151-
def check_early_stopping_callback(self, should_update):
152-
# TODO bake this logic into the EarlyStopping callback
153-
if should_update and self.trainer.checkpoint_connector.has_trained:
154-
callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
155-
model = self.trainer.lightning_module
156-
157-
for cb in callbacks:
158-
cb.on_validation_end(self.trainer, model)
159-
160150
def on_train_epoch_start(self, epoch):
161151

162152
# update training progress in trainer
@@ -556,7 +546,6 @@ def run_training_epoch(self):
556546

557547
if should_train_only:
558548
self.check_checkpoint_callback(True)
559-
self.check_early_stopping_callback(True)
560549

561550
if should_check_val:
562551
self.trainer.validating = True

tests/callbacks/test_early_stopping.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def training_epoch_end(self, outputs):
169169
if validation_step_none:
170170
model.validation_step = None
171171

172-
early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True)
172+
early_stop_callback = EarlyStopping(
173+
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=validation_step_none
174+
)
173175
trainer = Trainer(
174176
default_root_dir=tmpdir,
175177
callbacks=[early_stop_callback],
@@ -200,7 +202,7 @@ def test_early_stopping_no_val_step(tmpdir):
200202
model.validation_step = None
201203
model.val_dataloader = None
202204

203-
stopping = EarlyStopping(monitor='train_loss', min_delta=0.1, patience=0)
205+
stopping = EarlyStopping(monitor='train_loss', min_delta=0.1, patience=0, check_on_train_epoch_end=True)
204206
trainer = Trainer(
205207
default_root_dir=tmpdir,
206208
callbacks=[stopping],

tests/trainer/test_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ def training_step(self, batch, batch_idx):
548548
return output
549549

550550
model = TestModel()
551-
early_stop = EarlyStopping(monitor="loss", patience=0)
551+
early_stop = EarlyStopping(monitor="loss", patience=0, check_on_train_epoch_end=True)
552552
min_epochs = 5
553553
trainer = Trainer(
554554
default_root_dir=tmpdir,

0 commit comments

Comments
 (0)