Merged

Changes from all commits (32 commits)
9c40fee
Separate epoch validation from step validation
rohitgr7 Dec 20, 2020
ed6ebf1
update system
rohitgr7 Dec 20, 2020
236b052
test
rohitgr7 Dec 26, 2020
788203b
baked logic in callbacks
rohitgr7 Dec 27, 2020
c7b24ca
unbake logic in callbacks
rohitgr7 Dec 27, 2020
42b0c7b
fix the call for scheduler
rohitgr7 Dec 30, 2020
0cc0254
use property
rohitgr7 Jan 2, 2021
15e09b0
pep
rohitgr7 Jan 17, 2021
c51f946
correct rebase
rohitgr7 Jan 17, 2021
2c8ed93
gitignore
rohitgr7 Jan 17, 2021
5879528
ref
rohitgr7 Jan 23, 2021
2e6c601
add tests
rohitgr7 Jan 23, 2021
d38dba4
Merge branch 'master' into bugfix/ep_end_ckpt
rohitgr7 Jan 23, 2021
465a6f4
Merge branch 'master' into bugfix/ep_end_ckpt
rohitgr7 Jan 25, 2021
b3d601f
fix
rohitgr7 Jan 25, 2021
d84996a
add early stopping test
rohitgr7 Jan 25, 2021
549eb89
trigger
rohitgr7 Jan 25, 2021
260d1f5
chlog
rohitgr7 Jan 25, 2021
99cc9f5
Merge branch 'master' into bugfix/ep_end_ckpt
rohitgr7 Jan 25, 2021
85af968
rev
rohitgr7 Jan 25, 2021
740a07e
1.3
rohitgr7 Jan 26, 2021
ebbd980
Merge branch 'master' into bugfix/ep_end_ckpt
rohitgr7 Jan 27, 2021
ed8df7b
Merge branch 'master' into bugfix/ep_end_ckpt
tchaton Jan 27, 2021
6a376c1
Merge branch 'master' into bugfix/ep_end_ckpt
rohitgr7 Jan 29, 2021
465579b
log
rohitgr7 Jan 29, 2021
40bf21b
Apply suggestions from code review
rohitgr7 Feb 1, 2021
305d8f9
Merge branch 'master' into bugfix/ep_end_ckpt
tchaton Feb 5, 2021
6feabac
Update pytorch_lightning/trainer/training_loop.py
rohitgr7 Feb 5, 2021
cffe27f
Update CHANGELOG.md
rohitgr7 Feb 5, 2021
65d0797
Merge branch 'master' into bugfix/ep_end_ckpt
mergify[bot] Feb 7, 2021
42276ad
Apply suggestions from code review
Borda Feb 8, 2021
bcdd9ed
date
Borda Feb 8, 2021
3 changes: 3 additions & 0 deletions .gitignore
@@ -145,3 +145,6 @@ pytorch\ lightning
test-reports/
wandb
.forked/

# ctags
tags
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -4,12 +4,14 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [1.1.8] - 2021-02-06
## [1.1.8] - 2021-02-08

### Fixed

- Separate epoch validation from step validation ([#5208](https://github.com/PyTorchLightning/pytorch-lightning/pull/5208))
- Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775))


## [1.1.7] - 2021-02-03

### Fixed
18 changes: 0 additions & 18 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -88,9 +88,6 @@ def __init__(
self.stopped_epoch = 0
self.mode = mode
self.warned_result_obj = False
# Indicates, if eval results are used as basis for early stopping
# It is set to False initially and overwritten, if eval results have been validated
self.based_on_eval_results = False

self.__init_monitor_mode()

@@ -164,21 +161,6 @@ def on_validation_end(self, trainer, pl_module):

self._run_early_stopping_check(trainer, pl_module)

def on_validation_epoch_end(self, trainer, pl_module):
if trainer.fast_dev_run or trainer.running_sanity_check:
return

if self._validate_condition_metric(trainer.callback_metrics):
# turn off early stopping in on_train_epoch_end
self.based_on_eval_results = True

def on_train_epoch_end(self, trainer, pl_module, outputs):
# disable early stopping in train loop when there's a val loop
if self.based_on_eval_results:
return

self._run_early_stopping_check(trainer, pl_module)

def _run_early_stopping_check(self, trainer, pl_module):
"""
Checks whether the early stopping condition is met
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -166,7 +166,7 @@ def __init__(
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.last_global_step_saved = -1
self._last_global_step_saved = -1
self.prefix = prefix
self.current_score = None
self.best_k_models = {}
@@ -231,15 +231,15 @@ def save_checkpoint(self, trainer, pl_module):
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or self.last_global_step_saved == global_step # already saved at the last step
or self._last_global_step_saved == global_step # already saved at the last step
):
return

self._add_backward_monitor_support(trainer)
self._validate_monitor_key(trainer)

# track epoch when ckpt was last checked
self.last_global_step_saved = global_step
self._last_global_step_saved = global_step

# what can be monitored
monitor_candidates = self._monitor_candidates(trainer)
@@ -400,6 +400,7 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(
'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
'Warning, `hyper_parameters` dropped from checkpoint.'
f' An attribute is not picklable {err}'
)
atomic_save(checkpoint, filepath)
13 changes: 2 additions & 11 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -71,17 +71,8 @@ def get_evaluation_dataloaders(self, max_batches):

return dataloaders, max_batches

def should_skip_evaluation(self, dataloaders, max_batches):
# skip when dataloaders aren't defined
if dataloaders is None:
return True

# enable disabling validation step with limit_val_batches = 0
should_skip = sum(max_batches) == 0
if should_skip:
return True

return False
def should_skip_evaluation(self, max_batches):
return sum(max_batches) == 0

def on_evaluation_start(self, *args, **kwargs):
if self.trainer.testing:
11 changes: 6 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -563,9 +563,6 @@ def train(self):
if self.max_steps and self.max_steps <= self.global_step:
return

# update LR schedulers
self.optimizer_connector.update_learning_rates(interval='epoch')

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
@@ -591,7 +588,7 @@
# hook
self.train_loop.on_train_end()

def run_evaluation(self, max_batches=None):
def run_evaluation(self, max_batches=None, on_epoch=False):

# used to know if we are logging for val, test + reset cached results
self.logger_connector.set_stage(self.testing, reset=True)
@@ -603,7 +600,7 @@ def run_evaluation(self, max_batches=None):
dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)

# check if we want to skip this evaluation
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
if self.evaluation_loop.should_skip_evaluation(max_batches):
return [], []

# ref model
@@ -664,6 +661,10 @@ def run_evaluation(self, max_batches=None):
# hook
self.evaluation_loop.on_evaluation_epoch_end()

# update epoch-level lr_schedulers
if on_epoch:
self.optimizer_connector.update_learning_rates(interval='epoch')

# hook
self.evaluation_loop.on_evaluation_end()

78 changes: 53 additions & 25 deletions pytorch_lightning/trainer/training_loop.py
@@ -18,7 +18,7 @@
import torch
import torch.distributed as torch_distrib

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -153,7 +153,7 @@ def on_train_end(self):
# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
self.trainer.global_step -= 1
self.check_checkpoint_callback(should_save=True, is_last=True)
self.check_checkpoint_callback(should_update=True, is_last=True)
self.trainer.global_step += 1

# hook
@@ -176,18 +176,27 @@ def on_train_end(self):
model.cpu()
torch.cuda.empty_cache()

def check_checkpoint_callback(self, should_save, is_last=False):
# TODO bake this logic into the checkpoint callback
if should_save and self.trainer.checkpoint_connector.has_trained:
checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
def check_checkpoint_callback(self, should_update, is_last=False):
# TODO bake this logic into the ModelCheckpoint callback
if should_update and self.trainer.checkpoint_connector.has_trained:
callbacks = self.trainer.checkpoint_callbacks

if is_last and any(c.save_last for c in checkpoint_callbacks):
if is_last and any(cb.save_last for cb in callbacks):
rank_zero_info("Saving latest checkpoint...")

model = self.trainer.get_model()

for callback in checkpoint_callbacks:
callback.on_validation_end(self.trainer, model)
for cb in callbacks:
cb.on_validation_end(self.trainer, model)

def check_early_stopping_callback(self, should_update):
# TODO bake this logic into the EarlyStopping callback
if should_update and self.trainer.checkpoint_connector.has_trained:
callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
model = self.trainer.get_model()

for cb in callbacks:
cb.on_validation_end(self.trainer, model)

def on_train_epoch_start(self, epoch):

@@ -518,7 +527,6 @@ def tbptt_split_batch(self, batch):
return splits

def run_training_epoch(self):

# get model
model = self.trainer.get_model()

@@ -531,7 +539,6 @@
# enable profiling for the dataloader
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
should_check_val = False
for batch_idx, (batch, is_last_batch) in train_dataloader:

self.trainer.batch_idx = batch_idx
@@ -580,11 +587,12 @@
self.trainer.checkpoint_connector.has_trained = True

# max steps reached, end training
if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
accumulation_done = self._accumulated_batches_reached()
# Ensure accumulation across batches has completed before breaking loop
if accumulation_done:
break
if (
self.trainer.max_steps is not None
and self.trainer.max_steps == self.trainer.global_step + 1
and self._accumulated_batches_reached()
):
break

# end epoch early
# stop when the flag is changed or we've gone past the amount
Expand All @@ -595,7 +603,7 @@ def run_training_epoch(self):
self.trainer.total_batch_idx += 1

# stop epoch if we limited the number of training batches
if (batch_idx + 1) >= self.trainer.num_training_batches:
if self._num_training_batches_reached(is_last_batch):
break

# progress global step according to grads progress
@@ -612,8 +620,20 @@
self.num_optimizers
)

# when no val loop is present or fast-dev-run still need to call checkpoints
self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
if should_check_val:
self.trainer.run_evaluation(on_epoch=True)
# reset stage to train
self.trainer.logger_connector.set_stage("train")

should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
Contributor:

Slightly confused about this part. Can you explain why we check val and then decide if we should skip it?

Contributor Author (rohitgr7):

It's just a check for whether any validation data is available. If there isn't, we should run the train-only checks; otherwise not. There are two cases with no validation: either there is no validation_step at all, or there is a validation_step but no validation batches. Previously, having a validation_step with zero validation batches still skipped the train-only checks, but ideally it should not.
Resolves issue #4603.
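A minimal sketch of the two "no validation" cases (illustrative values only, not Trainer code; the helper mirrors the simplified EvaluationLoop.should_skip_evaluation in this PR's evaluation_loop.py diff):

```python
def should_skip_evaluation(max_batches):
    # no validation batches at all -> skip the eval loop entirely
    return sum(max_batches) == 0

# Case 1: no validation_step -> the trainer collects no val batches
print(should_skip_evaluation([]))      # True: sum of an empty list is 0

# Case 2: validation_step exists, but limit_val_batches=0 -> zero batches
print(should_skip_evaluation([0]))     # True

# Validation actually available
print(should_skip_evaluation([2, 3]))  # False
```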

Contributor:

@rohitgr7 I'm debugging an issue right now related to this.

pytorch_lightning/trainer/training_loop.py:753: Input params: batch_idx=1, is_last_batch=True, on_epoch=True
pytorch_lightning/trainer/training_loop.py:755: batch_idx+1=2, trainer.val_check_batch=2, is_val_check_batch=True
pytorch_lightning/trainer/training_loop.py:758: current_epoch+1=1, trainer.check_val_every_n_epoch=1, is_val_check_epoch=True
pytorch_lightning/trainer/training_loop.py:761: enable_validation=True, is_val_check_epoch=True, can_check_val=True
pytorch_lightning/trainer/training_loop.py:765: is_last_batch=True, trainer.val_check_batch=2, is_last_batch_for_infinite_dataset=False
pytorch_lightning/trainer/training_loop.py:768: batch_idx + 1=2, trainer.num_training_batches=2, epoch_end_val_check=True
pytorch_lightning/trainer/training_loop.py:774: is_val_check_batch=True, is_val_check_epoch=True, can_check_val=True, is_last_batch_for_infinite_dataset=False, epoch_end_val_check=True, should_check_val=True
pytorch_lightning/trainer/training_loop.py:775: should_check_val=True, can_check_val=True
pytorch_lightning/trainer/training_loop.py:487: should_check_val=True
pytorch_lightning/trainer/training_loop.py:489: should_skip_eval=True, trainer.num_val_batches=[]

This check for should_skip_eval forces should_train_only to True, which causes the checkpoint callback to run before validation. The checkpoint is configured for a metric that appears only in validation, which leads to a failure. I don't see why should_skip_eval affects should_train_only - shouldn't that be decided entirely by self.trainer.disable_validation?

This could also point to a bug in how self.trainer.num_val_batches is set.
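A tiny illustration of the reported behaviour (values taken from the debug log above; names mirror the diff below, but this is not the actual Trainer internals):

```python
num_val_batches = []                    # as logged: trainer.num_val_batches=[]
disable_validation = False

should_skip_eval = sum(num_val_batches) == 0          # True: sum([]) == 0
should_train_only = disable_validation or should_skip_eval

# True -> the train loop runs the checkpoint/early-stopping callbacks even
# though no validation metrics were ever produced, matching the failure
# described above for a validation-only monitored metric.
print(should_skip_eval, should_train_only)  # True True
```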

should_train_only = self.trainer.disable_validation or should_skip_eval

if should_train_only:
# update epoch level lr_schedulers
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
self.check_checkpoint_callback(True)
self.check_early_stopping_callback(True)

# increment the global step once
# progress global step according to grads progress
@@ -853,25 +873,33 @@ def increment_accumulated_grad_global_step(self):
def _accumulated_batches_reached(self):
return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0

def _num_training_batches_reached(self):
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
def _num_training_batches_reached(self, is_last_batch=False):
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch

def should_accumulate(self):
# checks if backward or backward + optimizer step (via closure)
accumulation_done = self._accumulated_batches_reached()
is_final_batch = self._num_training_batches_reached()
return not (accumulation_done or is_final_batch)

def should_check_val_fx(self, batch_idx, is_last_batch):
def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
# decide if we should run validation
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
can_check_val = self.trainer.enable_validation and is_val_check_epoch
should_check_val = is_val_check_batch or self.trainer.should_stop
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches

should_check_val = (
(is_val_check_batch and epoch_end_val_check)
or self.trainer.should_stop
or is_last_batch_for_infinite_dataset
) if on_epoch else (
is_val_check_batch
and not epoch_end_val_check
)

return should_check_val
return should_check_val and can_check_val

def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
# enable not needing to add opt_idx to training_step
4 changes: 2 additions & 2 deletions tests/callbacks/test_callbacks.py
@@ -86,15 +86,15 @@ def test_trainer_callback_system(torch_save):
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
call.on_batch_end(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
call.on_epoch_end(trainer, model),
call.on_train_epoch_end(trainer, model, ANY),
Comment on lines +89 to +90
Collaborator:

I think @williamFalcon made a point some time ago that training should run up until validation, and the example was validating multiple times over a long training run...
cc: @tchaton @PyTorchLightning/core-contributors

rohitgr7 (Contributor Author), Feb 5, 2021:

Yes, it still works like that, but only if val_check_interval < 1.0 or it is an int where val_check_interval % num_training_batches != 0. If it is set to 1.0, validation here happens after the training epoch, because we create checkpoints in on_validation_end and epoch-level learning rates are updated once training is done: with ReduceLROnPlateau we need the monitored metrics, and when the monitor is training-specific they are only available after the whole training epoch has run.
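A hedged sketch of that split (simplified from the new should_check_val_fx in this PR; the helper name and values are illustrative, and the should_stop / infinite-dataset branches are omitted):

```python
def should_check_val(batch_idx, val_check_batch, num_training_batches, on_epoch):
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    # does the validation interval land exactly on the epoch boundary?
    epoch_end_val_check = val_check_batch == num_training_batches
    if on_epoch:
        # the epoch-end call handles validation aligned with the boundary
        return is_val_check_batch and epoch_end_val_check
    # the in-loop call handles everything that happens mid-epoch
    return is_val_check_batch and not epoch_end_val_check

# val_check_interval=1.0 -> val_check_batch == num_training_batches (here 4):
print(should_check_val(3, 4, 4, on_epoch=True))   # True: runs at epoch end
print(should_check_val(3, 4, 4, on_epoch=False))  # False: never mid-epoch

# val_check_interval=0.5 -> val_check_batch=2: validation runs mid-epoch
print(should_check_val(1, 2, 4, on_epoch=False))  # True
```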

call.on_validation_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model),
call.on_epoch_end(trainer, model),
call.on_train_epoch_end(trainer, model, ANY),
call.on_train_end(trainer, model),
call.on_fit_end(trainer, model),
call.teardown(trainer, model, 'fit'),
39 changes: 36 additions & 3 deletions tests/callbacks/test_early_stopping.py
@@ -113,11 +113,9 @@ def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_ep

class ModelOverrideValidationReturn(EvalModelTemplate):
validation_return_values = torch.Tensor(loss_values)
count = 0

def validation_epoch_end(self, outputs):
loss = self.validation_return_values[self.count]
self.count += 1
loss = self.validation_return_values[self.current_epoch]
return {"test_val_loss": loss}

model = ModelOverrideValidationReturn()
@@ -133,6 +131,41 @@ def validation_epoch_end(self, outputs):
assert trainer.current_epoch == expected_stop_epoch


@pytest.mark.parametrize('validation_step', ['base', None])
@pytest.mark.parametrize(
"loss_values, patience, expected_stop_epoch",
[
([6, 5, 5, 5, 5, 5], 3, 4),
([6, 5, 4, 4, 3, 3], 1, 3),
([6, 5, 6, 5, 5, 5], 3, 4),
],
)
def test_early_stopping_patience_train(tmpdir, validation_step, loss_values, patience, expected_stop_epoch):
"""Test to ensure that early stopping is not triggered before patience is exhausted."""

class ModelOverrideTrainReturn(EvalModelTemplate):
train_return_values = torch.Tensor(loss_values)

def training_epoch_end(self, outputs):
loss = self.train_return_values[self.current_epoch]
self.log('train_loss', loss)

model = ModelOverrideTrainReturn()

if validation_step is None:
model.validation_step = None

early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stop_callback],
num_sanity_val_steps=0,
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch


def test_pickling(tmpdir):
early_stopping = EarlyStopping()

3 changes: 2 additions & 1 deletion tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -56,7 +56,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval,
default_root_dir=tmpdir,
max_epochs=epochs,
weights_summary=None,
val_check_interval=val_check_interval
val_check_interval=val_check_interval,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
