Skip to content

Commit cb67e1d

Browse files
rohitgr7 authored and Borda committed
Separate epoch validation from step validation (#5208)
* Separate epoch validation from step validation * update system * test * baked logic in callbacks * unbake logic in callbacks * fix the call for scheduler * use property * pep * correct rebase * gitignore * ref * add tests * fix * add early stopping test * trigger * chlog * rev * 1.3 * log * Apply suggestions from code review Co-authored-by: Carlos Mocholí <[email protected]> * Update pytorch_lightning/trainer/training_loop.py * Update CHANGELOG.md * Apply suggestions from code review Co-authored-by: chaton <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> (cherry picked from commit e429f97)
1 parent e7c6e9d commit cb67e1d

File tree

11 files changed

+189
-80
lines changed

11 files changed

+189
-80
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,6 @@ wandb
151151

152152
# dataset generated from bolts in examples.
153153
cifar-10-batches-py
154+
155+
# ctags
156+
tags

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
175175
- Fixed support for custom DataLoaders with DDP if they can be re-instantiated ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745))
176176

177177

178-
## [1.1.8] - 2021-02-06
178+
## [1.1.8] - 2021-02-08
179179

180180
### Fixed
181181

182+
- Separated epoch validation from step validation ([#5208](https://github.com/PyTorchLightning/pytorch-lightning/pull/5208))
182183
- Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775))
183184

184185

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,6 @@ def __init__(
8686
self.stopped_epoch = 0
8787
self.mode = mode
8888
self.warned_result_obj = False
89-
# Indicates, if eval results are used as basis for early stopping
90-
# It is set to False initially and overwritten, if eval results have been validated
91-
self.based_on_eval_results = False
9289

9390
self.__init_monitor_mode()
9491

@@ -159,21 +156,6 @@ def on_validation_end(self, trainer, pl_module):
159156

160157
self._run_early_stopping_check(trainer, pl_module)
161158

162-
def on_validation_epoch_end(self, trainer, pl_module):
163-
if trainer.fast_dev_run or trainer.running_sanity_check:
164-
return
165-
166-
if self._validate_condition_metric(trainer.callback_metrics):
167-
# turn off early stopping in on_train_epoch_end
168-
self.based_on_eval_results = True
169-
170-
def on_train_epoch_end(self, trainer, pl_module, outputs):
171-
# disable early stopping in train loop when there's a val loop
172-
if self.based_on_eval_results:
173-
return
174-
175-
self._run_early_stopping_check(trainer, pl_module)
176-
177159
def _run_early_stopping_check(self, trainer, pl_module):
178160
"""
179161
Checks whether the early stopping condition is met

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def __init__(
167167
self.save_top_k = save_top_k
168168
self.save_weights_only = save_weights_only
169169
self.period = period
170-
self.last_global_step_saved = -1
170+
self._last_global_step_saved = -1
171171
self.prefix = prefix
172172
self.current_score = None
173173
self.best_k_models = {}
@@ -232,15 +232,15 @@ def save_checkpoint(self, trainer, pl_module):
232232
or self.period < 1 # no models are saved
233233
or (epoch + 1) % self.period # skip epoch
234234
or trainer.running_sanity_check # don't save anything during sanity check
235-
or self.last_global_step_saved == global_step # already saved at the last step
235+
or self._last_global_step_saved == global_step # already saved at the last step
236236
):
237237
return
238238

239239
self._add_backward_monitor_support(trainer)
240240
self._validate_monitor_key(trainer)
241241

242242
# track the global step at which the checkpoint was last checked
243-
self.last_global_step_saved = global_step
243+
self._last_global_step_saved = global_step
244244

245245
# what can be monitored
246246
monitor_candidates = self._monitor_candidates(trainer)

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,8 @@ def get_evaluation_dataloaders(self, max_batches):
7070

7171
return dataloaders, max_batches
7272

73-
def should_skip_evaluation(self, dataloaders, max_batches):
74-
# skip when dataloaders aren't defined
75-
if dataloaders is None:
76-
return True
77-
78-
# enable disabling validation step with limit_val_batches = 0
79-
should_skip = sum(max_batches) == 0
80-
if should_skip:
81-
return True
82-
83-
return False
73+
def should_skip_evaluation(self, max_batches):
74+
return sum(max_batches) == 0
8475

8576
def on_evaluation_start(self, *args, **kwargs):
8677
if self.trainer.testing:

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,6 @@ def train(self):
598598
if self.max_steps and self.max_steps <= self.global_step:
599599
return
600600

601-
# update LR schedulers
602-
self.optimizer_connector.update_learning_rates(interval='epoch')
603-
604601
# early stopping
605602
met_min_epochs = epoch >= self.min_epochs - 1
606603
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
@@ -626,7 +623,7 @@ def train(self):
626623
# hook
627624
self.train_loop.on_train_end()
628625

629-
def run_evaluation(self, max_batches=None):
626+
def run_evaluation(self, max_batches=None, on_epoch=False):
630627

631628
# used to know if we are logging for val, test + reset cached results
632629
self._set_wide_running_stage(RunningStage.TESTING if self.testing else RunningStage.EVALUATING)
@@ -639,7 +636,7 @@ def run_evaluation(self, max_batches=None):
639636
dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
640637

641638
# check if we want to skip this evaluation
642-
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
639+
if self.evaluation_loop.should_skip_evaluation(max_batches):
643640
return [], []
644641

645642
# ref model
@@ -705,6 +702,10 @@ def run_evaluation(self, max_batches=None):
705702
# hook
706703
self.evaluation_loop.on_evaluation_epoch_end()
707704

705+
# update epoch-level lr_schedulers
706+
if on_epoch:
707+
self.optimizer_connector.update_learning_rates(interval='epoch')
708+
708709
# hook
709710
self.evaluation_loop.on_evaluation_end()
710711

pytorch_lightning/trainer/training_loop.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919
import torch
2020

21-
from pytorch_lightning.callbacks import ModelCheckpoint
21+
from pytorch_lightning.callbacks import EarlyStopping
2222
from pytorch_lightning.core.memory import ModelSummary
2323
from pytorch_lightning.core.optimizer import LightningOptimizer
2424
from pytorch_lightning.core.step_result import Result
@@ -161,7 +161,7 @@ def on_train_end(self):
161161
# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
162162
# when a checkpoint was saved at the last step
163163
self.trainer.global_step -= 1
164-
self.check_checkpoint_callback(should_save=True, is_last=True)
164+
self.check_checkpoint_callback(should_update=True, is_last=True)
165165
self.trainer.global_step += 1
166166

167167
# hook
@@ -184,18 +184,27 @@ def on_train_end(self):
184184
model.cpu()
185185
torch.cuda.empty_cache()
186186

187-
def check_checkpoint_callback(self, should_save, is_last=False):
188-
# TODO bake this logic into the checkpoint callback
189-
if should_save and self.trainer.checkpoint_connector.has_trained:
190-
checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
187+
def check_checkpoint_callback(self, should_update, is_last=False):
188+
# TODO bake this logic into the ModelCheckpoint callback
189+
if should_update and self.trainer.checkpoint_connector.has_trained:
190+
callbacks = self.trainer.checkpoint_callbacks
191191

192-
if is_last and any(c.save_last for c in checkpoint_callbacks):
192+
if is_last and any(cb.save_last for cb in callbacks):
193193
rank_zero_info("Saving latest checkpoint...")
194194

195195
model = self.trainer.get_model()
196196

197-
for callback in checkpoint_callbacks:
198-
callback.on_validation_end(self.trainer, model)
197+
for cb in callbacks:
198+
cb.on_validation_end(self.trainer, model)
199+
200+
def check_early_stopping_callback(self, should_update):
201+
# TODO bake this logic into the EarlyStopping callback
202+
if should_update and self.trainer.checkpoint_connector.has_trained:
203+
callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
204+
model = self.trainer.get_model()
205+
206+
for cb in callbacks:
207+
cb.on_validation_end(self.trainer, model)
199208

200209
def on_train_epoch_start(self, epoch):
201210

@@ -521,7 +530,6 @@ def tbptt_split_batch(self, batch):
521530
return splits
522531

523532
def run_training_epoch(self):
524-
525533
# get model
526534
model = self.trainer.get_model()
527535

@@ -584,11 +592,12 @@ def run_training_epoch(self):
584592
self.trainer.checkpoint_connector.has_trained = True
585593

586594
# max steps reached, end training
587-
if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
588-
accumulation_done = self._accumulated_batches_reached()
589-
# Ensure accumulation across batches has completed before breaking loop
590-
if accumulation_done:
591-
break
595+
if (
596+
self.trainer.max_steps is not None
597+
and self.trainer.max_steps == self.trainer.global_step + 1
598+
and self._accumulated_batches_reached()
599+
):
600+
break
592601

593602
# end epoch early
594603
# stop when the flag is changed or we've gone past the amount
@@ -599,7 +608,7 @@ def run_training_epoch(self):
599608
self.trainer.total_batch_idx += 1
600609

601610
# stop epoch if we limited the number of training batches
602-
if (batch_idx + 1) >= self.trainer.num_training_batches:
611+
if self._num_training_batches_reached(is_last_batch):
603612
break
604613

605614
# progress global step according to grads progress
@@ -613,8 +622,20 @@ def run_training_epoch(self):
613622
epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
614623
)
615624

616-
# when no val loop is present or fast-dev-run still need to call checkpoints
617-
self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
625+
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
626+
if should_check_val:
627+
self.trainer.run_evaluation(on_epoch=True)
628+
# reset stage to train
629+
self.trainer.logger_connector.set_stage("train")
630+
631+
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
632+
should_train_only = self.trainer.disable_validation or should_skip_eval
633+
634+
if should_train_only:
635+
# update epoch level lr_schedulers
636+
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
637+
self.check_checkpoint_callback(True)
638+
self.check_early_stopping_callback(True)
618639

619640
# increment the global step once
620641
# progress global step according to grads progress
@@ -840,25 +861,33 @@ def increment_accumulated_grad_global_step(self):
840861
def _accumulated_batches_reached(self):
841862
return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
842863

843-
def _num_training_batches_reached(self):
844-
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
864+
def _num_training_batches_reached(self, is_last_batch=False):
865+
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch
845866

846867
def should_accumulate(self):
847868
# checks if backward or backward + optimizer step (via closure)
848869
accumulation_done = self._accumulated_batches_reached()
849870
is_final_batch = self._num_training_batches_reached()
850871
return not (accumulation_done or is_final_batch)
851872

852-
def should_check_val_fx(self, batch_idx, is_last_batch):
873+
def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
853874
# decide if we should run validation
854875
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
855876
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
856877
can_check_val = self.trainer.enable_validation and is_val_check_epoch
857-
should_check_val = is_val_check_batch or self.trainer.should_stop
858878
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
859-
should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
879+
epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
880+
881+
should_check_val = (
882+
(is_val_check_batch and epoch_end_val_check)
883+
or self.trainer.should_stop
884+
or is_last_batch_for_infinite_dataset
885+
) if on_epoch else (
886+
is_val_check_batch
887+
and not epoch_end_val_check
888+
)
860889

861-
return should_check_val
890+
return should_check_val and can_check_val
862891

863892
def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
864893
# enable not needing to add opt_idx to training_step

tests/callbacks/test_early_stopping.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,9 @@ def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_ep
115115

116116
class ModelOverrideValidationReturn(EvalModelTemplate):
117117
validation_return_values = torch.Tensor(loss_values)
118-
count = 0
119118

120119
def validation_epoch_end(self, outputs):
121-
loss = self.validation_return_values[self.count]
122-
self.count += 1
120+
loss = self.validation_return_values[self.current_epoch]
123121
return {"test_val_loss": loss}
124122

125123
model = ModelOverrideValidationReturn()
@@ -135,6 +133,41 @@ def validation_epoch_end(self, outputs):
135133
assert trainer.current_epoch == expected_stop_epoch
136134

137135

136+
@pytest.mark.parametrize('validation_step', ['base', None])
137+
@pytest.mark.parametrize(
138+
"loss_values, patience, expected_stop_epoch",
139+
[
140+
([6, 5, 5, 5, 5, 5], 3, 4),
141+
([6, 5, 4, 4, 3, 3], 1, 3),
142+
([6, 5, 6, 5, 5, 5], 3, 4),
143+
],
144+
)
145+
def test_early_stopping_patience_train(tmpdir, validation_step, loss_values, patience, expected_stop_epoch):
146+
"""Test to ensure that early stopping is not triggered before patience is exhausted."""
147+
148+
class ModelOverrideTrainReturn(EvalModelTemplate):
149+
train_return_values = torch.Tensor(loss_values)
150+
151+
def training_epoch_end(self, outputs):
152+
loss = self.train_return_values[self.current_epoch]
153+
self.log('train_loss', loss)
154+
155+
model = ModelOverrideTrainReturn()
156+
157+
if validation_step is None:
158+
model.validation_step = None
159+
160+
early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True)
161+
trainer = Trainer(
162+
default_root_dir=tmpdir,
163+
callbacks=[early_stop_callback],
164+
num_sanity_val_steps=0,
165+
max_epochs=10,
166+
)
167+
trainer.fit(model)
168+
assert trainer.current_epoch == expected_stop_epoch
169+
170+
138171
def test_pickling(tmpdir):
139172
early_stopping = EarlyStopping()
140173

tests/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval,
5959
max_epochs=epochs,
6060
weights_summary=None,
6161
val_check_interval=val_check_interval,
62+
progress_bar_refresh_rate=0,
6263
)
6364
trainer.fit(model)
6465

0 commit comments

Comments
 (0)