diff --git a/.gitignore b/.gitignore
index d6ae2ef48ed01..0e51e29f9ae4e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -145,3 +145,6 @@ pytorch\ lightning
 test-reports/
 wandb
 .forked/
+
+# ctags
+tags
diff --git a/CHANGELOG.md b/CHANGELOG.md
index cc39f2dd98b0c..7934003915850 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,12 +4,14 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
-## [1.1.8] - 2021-02-06
+## [1.1.8] - 2021-02-08
 
 ### Fixed
 
+- Separate epoch validation from step validation ([#5208](https://github.com/PyTorchLightning/pytorch-lightning/pull/5208))
 - Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775))
 
+
 ## [1.1.7] - 2021-02-03
 
 ### Fixed
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 72e7a7944cfcf..8f678dfea9ef1 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -88,9 +88,6 @@ def __init__(
         self.stopped_epoch = 0
         self.mode = mode
         self.warned_result_obj = False
-        # Indicates, if eval results are used as basis for early stopping
-        # It is set to False initially and overwritten, if eval results have been validated
-        self.based_on_eval_results = False
 
         self.__init_monitor_mode()
 
@@ -164,21 +161,6 @@ def on_validation_end(self, trainer, pl_module):
 
         self._run_early_stopping_check(trainer, pl_module)
 
-    def on_validation_epoch_end(self, trainer, pl_module):
-        if trainer.fast_dev_run or trainer.running_sanity_check:
-            return
-
-        if self._validate_condition_metric(trainer.callback_metrics):
-            # turn off early stopping in on_train_epoch_end
-            self.based_on_eval_results = True
-
-    def on_train_epoch_end(self, trainer, pl_module, outputs):
-        # disable early stopping in train loop when there's a val loop
-        if self.based_on_eval_results:
-            return
-
-        self._run_early_stopping_check(trainer, pl_module)
-
     def _run_early_stopping_check(self, trainer, pl_module):
         """
         Checks whether the early stopping condition is met
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 6ec444eaa3838..8faecb07dc21f 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -166,7 +166,7 @@ def __init__(
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
         self.period = period
-        self.last_global_step_saved = -1
+        self._last_global_step_saved = -1
        self.prefix = prefix
         self.current_score = None
         self.best_k_models = {}
@@ -231,7 +231,7 @@ def save_checkpoint(self, trainer, pl_module):
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
-            or self.last_global_step_saved == global_step  # already saved at the last step
+            or self._last_global_step_saved == global_step  # already saved at the last step
         ):
             return
 
@@ -239,7 +239,7 @@ def save_checkpoint(self, trainer, pl_module):
         self._validate_monitor_key(trainer)
 
         # track epoch when ckpt was last checked
-        self.last_global_step_saved = global_step
+        self._last_global_step_saved = global_step
 
         # what can be monitored
         monitor_candidates = self._monitor_candidates(trainer)
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 34d572de84c51..6953ff29330d6 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -400,6 +400,7 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
                 if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                     del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
                 rank_zero_warn(
-                    'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
+                    'Warning, `hyper_parameters` dropped from checkpoint.'
+                    f' An attribute is not picklable {err}'
                 )
                 atomic_save(checkpoint, filepath)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 63f65bead2579..ea759a1c29750 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -71,17 +71,8 @@ def get_evaluation_dataloaders(self, max_batches):
 
         return dataloaders, max_batches
 
-    def should_skip_evaluation(self, dataloaders, max_batches):
-        # skip when dataloaders aren't defined
-        if dataloaders is None:
-            return True
-
-        # enable disabling validation step with limit_val_batches = 0
-        should_skip = sum(max_batches) == 0
-        if should_skip:
-            return True
-
-        return False
+    def should_skip_evaluation(self, max_batches):
+        return sum(max_batches) == 0
 
     def on_evaluation_start(self, *args, **kwargs):
         if self.trainer.testing:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index bbfc3fd3202f2..b65fcf9c72d37 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -563,9 +563,6 @@ def train(self):
                 if self.max_steps and self.max_steps <= self.global_step:
                     return
 
-                # update LR schedulers
-                self.optimizer_connector.update_learning_rates(interval='epoch')
-
                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
@@ -591,7 +588,7 @@ def train(self):
             # hook
             self.train_loop.on_train_end()
 
-    def run_evaluation(self, max_batches=None):
+    def run_evaluation(self, max_batches=None, on_epoch=False):
 
         # used to know if we are logging for val, test + reset cached results
         self.logger_connector.set_stage(self.testing, reset=True)
@@ -603,7 +600,7 @@ def run_evaluation(self, max_batches=None):
         dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
 
         # check if we want to skip this evaluation
-        if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
+        if self.evaluation_loop.should_skip_evaluation(max_batches):
             return [], []
 
         # ref model
@@ -664,6 +661,10 @@ def run_evaluation(self, max_batches=None):
         # hook
         self.evaluation_loop.on_evaluation_epoch_end()
 
+        # update epoch-level lr_schedulers
+        if on_epoch:
+            self.optimizer_connector.update_learning_rates(interval='epoch')
+
         # hook
         self.evaluation_loop.on_evaluation_end()
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 83af076eaa8d2..0b030b1e39a85 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -18,7 +18,7 @@
 import torch
 import torch.distributed as torch_distrib
 
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -153,7 +153,7 @@ def on_train_end(self):
         # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
         # when a checkpoint was saved at the last step
         self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_save=True, is_last=True)
+        self.check_checkpoint_callback(should_update=True, is_last=True)
         self.trainer.global_step += 1
 
         # hook
@@ -176,18 +176,27 @@ def on_train_end(self):
             model.cpu()
             torch.cuda.empty_cache()
 
-    def check_checkpoint_callback(self, should_save, is_last=False):
-        # TODO bake this logic into the checkpoint callback
-        if should_save and self.trainer.checkpoint_connector.has_trained:
-            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+    def check_checkpoint_callback(self, should_update, is_last=False):
+        # TODO bake this logic into the ModelCheckpoint callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = self.trainer.checkpoint_callbacks
 
-            if is_last and any(c.save_last for c in checkpoint_callbacks):
+            if is_last and any(cb.save_last for cb in callbacks):
                 rank_zero_info("Saving latest checkpoint...")
 
             model = self.trainer.get_model()
 
-            for callback in checkpoint_callbacks:
-                callback.on_validation_end(self.trainer, model)
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
+    def check_early_stopping_callback(self, should_update):
+        # TODO bake this logic into the EarlyStopping callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
+            model = self.trainer.get_model()
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
 
     def on_train_epoch_start(self, epoch):
 
@@ -518,7 +527,6 @@ def tbptt_split_batch(self, batch):
         return splits
 
     def run_training_epoch(self):
-
         # get model
         model = self.trainer.get_model()
 
@@ -531,7 +539,6 @@ def run_training_epoch(self):
         # enable profiling for the dataloader
         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
-        should_check_val = False
         for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.trainer.batch_idx = batch_idx
 
@@ -580,11 +587,12 @@ def run_training_epoch(self):
             self.trainer.checkpoint_connector.has_trained = True
 
             # max steps reached, end training
-            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
-                accumulation_done = self._accumulated_batches_reached()
-                # Ensure accumulation across batches has completed before breaking loop
-                if accumulation_done:
-                    break
+            if (
+                self.trainer.max_steps is not None
+                and self.trainer.max_steps == self.trainer.global_step + 1
+                and self._accumulated_batches_reached()
+            ):
+                break
 
             # end epoch early
             # stop when the flag is changed or we've gone past the amount
@@ -595,7 +603,7 @@ def run_training_epoch(self):
             self.trainer.total_batch_idx += 1
 
             # stop epoch if we limited the number of training batches
-            if (batch_idx + 1) >= self.trainer.num_training_batches:
+            if self._num_training_batches_reached(is_last_batch):
                 break
 
             # progress global step according to grads progress
@@ -612,8 +620,20 @@ def run_training_epoch(self):
             self.num_optimizers
         )
 
-        # when no val loop is present or fast-dev-run still need to call checkpoints
-        self.check_checkpoint_callback(not (should_check_val or is_overridden('validation_step', model)))
+        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        if should_check_val:
+            self.trainer.run_evaluation(on_epoch=True)
+            # reset stage to train
+            self.trainer.logger_connector.set_stage("train")
+
+        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
+        should_train_only = self.trainer.disable_validation or should_skip_eval
+
+        if should_train_only:
+            # update epoch level lr_schedulers
+            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+            self.check_checkpoint_callback(True)
+            self.check_early_stopping_callback(True)
 
         # increment the global step once
         # progress global step according to grads progress
@@ -853,8 +873,8 @@ def increment_accumulated_grad_global_step(self):
     def _accumulated_batches_reached(self):
         return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
 
-    def _num_training_batches_reached(self):
-        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
+    def _num_training_batches_reached(self, is_last_batch=False):
+        return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch
 
     def should_accumulate(self):
         # checks if backward or backward + optimizer step (via closure)
@@ -862,16 +882,24 @@ def should_accumulate(self):
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)
 
-    def should_check_val_fx(self, batch_idx, is_last_batch):
+    def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
         is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         can_check_val = self.trainer.enable_validation and is_val_check_epoch
-        should_check_val = is_val_check_batch or self.trainer.should_stop
         is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
+        epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
+
+        should_check_val = (
+            (is_val_check_batch and epoch_end_val_check)
+            or self.trainer.should_stop
+            or is_last_batch_for_infinite_dataset
+        ) if on_epoch else (
+            is_val_check_batch
+            and not epoch_end_val_check
+        )
 
-        return should_check_val
+        return should_check_val and can_check_val
 
     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index c9baf0db6976d..d6421877e80ad 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -86,6 +86,8 @@ def test_trainer_callback_system(torch_save):
         call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
         call.on_batch_end(trainer, model),
         call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
+        call.on_epoch_end(trainer, model),
+        call.on_train_epoch_end(trainer, model, ANY),
         call.on_validation_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
@@ -93,8 +95,6 @@ def test_trainer_callback_system(torch_save):
         call.on_validation_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
         call.on_save_checkpoint(trainer, model),
-        call.on_epoch_end(trainer, model),
-        call.on_train_epoch_end(trainer, model, ANY),
         call.on_train_end(trainer, model),
         call.on_fit_end(trainer, model),
         call.teardown(trainer, model, 'fit'),
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 6fa7e5f3567b6..c8099d11f6707 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -113,11 +113,9 @@ def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_ep
 
     class ModelOverrideValidationReturn(EvalModelTemplate):
         validation_return_values = torch.Tensor(loss_values)
-        count = 0
 
         def validation_epoch_end(self, outputs):
-            loss = self.validation_return_values[self.count]
-            self.count += 1
+            loss = self.validation_return_values[self.current_epoch]
             return {"test_val_loss": loss}
 
     model = ModelOverrideValidationReturn()
@@ -133,6 +131,41 @@ def validation_epoch_end(self, outputs):
     assert trainer.current_epoch == expected_stop_epoch
 
 
+@pytest.mark.parametrize('validation_step', ['base', None])
+@pytest.mark.parametrize(
+    "loss_values, patience, expected_stop_epoch",
+    [
+        ([6, 5, 5, 5, 5, 5], 3, 4),
+        ([6, 5, 4, 4, 3, 3], 1, 3),
+        ([6, 5, 6, 5, 5, 5], 3, 4),
+    ],
+)
+def test_early_stopping_patience_train(tmpdir, validation_step, loss_values, patience, expected_stop_epoch):
+    """Test to ensure that early stopping is not triggered before patience is exhausted."""
+
+    class ModelOverrideTrainReturn(EvalModelTemplate):
+        train_return_values = torch.Tensor(loss_values)
+
+        def training_epoch_end(self, outputs):
+            loss = self.train_return_values[self.current_epoch]
+            self.log('train_loss', loss)
+
+    model = ModelOverrideTrainReturn()
+
+    if validation_step is None:
+        model.validation_step = None
+
+    early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[early_stop_callback],
+        num_sanity_val_steps=0,
+        max_epochs=10,
+    )
+    trainer.fit(model)
+    assert trainer.current_epoch == expected_stop_epoch
+
+
 def test_pickling(tmpdir):
     early_stopping = EarlyStopping()
 
diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py
index f9686dce159dd..69a8407a88590 100644
--- a/tests/checkpointing/test_checkpoint_callback_frequency.py
+++ b/tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -56,7 +56,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval,
         default_root_dir=tmpdir,
         max_epochs=epochs,
         weights_summary=None,
-        val_check_interval=val_check_interval
+        val_check_interval=val_check_interval,
+        progress_bar_refresh_rate=0,
     )
     trainer.fit(model)
 
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 337d293df02ee..e8888ce22565a 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 import os
 import pickle
 import platform
 import re
 from argparse import Namespace
+from distutils.version import LooseVersion
 from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock
@@ -50,26 +52,91 @@ def validation_epoch_end(self, outputs):
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-@pytest.mark.parametrize('save_top_k', [-1])
-def test_model_checkpoint_correct_score(tmpdir, save_top_k):
-    """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path"""
-    tutils.reset_seed()
+@pytest.mark.parametrize(
+    "validation_step,val_dataloaders,monitor",
+    [
+        ('base', "base", 'val_log'),
+        ('base', "base", 'train_log_epoch'),
+        (None, "base", 'train_log_epoch'),
+        ("base", None, 'train_log_epoch')
+    ],
+)
+def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
+    """
+    Test that when a model checkpoint is saved, it saves with
+    the correct score appended to ckpt_path and checkpoint data
+    """
+    max_epochs = 3
+    limit_train_batches = 5
+    limit_val_batches = 7
 
-    model = LogInTwoMethods()
+    class CustomBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.train_log_epochs = torch.randn(max_epochs, limit_train_batches)
+            self.val_logs = torch.randn(max_epochs, limit_val_batches)
+
+        def training_step(self, batch, batch_idx):
+            out = super().training_step(batch, batch_idx)
+            log_value = self.train_log_epochs[self.current_epoch, batch_idx]
+            self.log('train_log', log_value, on_epoch=True)
+            return out
 
-    filename = "{val_acc:.4f}-{epoch}"
+        def validation_step(self, batch, batch_idx):
+            out = super().validation_step(batch, batch_idx)
+            log_value = self.val_logs[self.current_epoch, batch_idx]
+            self.log('val_log', log_value)
+            self.log('epoch', self.current_epoch, on_epoch=True)
+            return out
 
-    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor='val_acc', save_top_k=save_top_k)
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
 
-    trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint], overfit_batches=0.20, max_epochs=2)
+    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
+
+    model = CustomBoringModel()
+
+    if validation_step is None:
+        model.validation_step = None
+    if val_dataloaders is None:
+        model.val_dataloaders = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        max_epochs=max_epochs,
+        progress_bar_refresh_rate=0,
+    )
     trainer.fit(model)
 
     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
-
-    metrics = trainer.dev_debugger.logged_metrics
-    expected_filenames = {f'val_acc={metric["val_acc"]:.4f}-epoch={metric["epoch"]}.ckpt' for metric in metrics}
-    for ckpt_file in ckpt_files:
-        assert os.path.basename(ckpt_file) in expected_filenames
+    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    assert len(ckpt_files) == len(scores) == max_epochs
+
+    for epoch in range(max_epochs):
+        score = scores[epoch]
+        expected_score = getattr(model, f'{monitor}s')[epoch].mean().item()
+        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
+        assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+        assert chk['epoch'] == epoch + 1
+        assert chk['global_step'] == limit_train_batches * (epoch + 1)
+
+        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        assert mc_specific_data['dirpath'] == checkpoint.dirpath
+        assert mc_specific_data['monitor'] == monitor
+        assert mc_specific_data['current_score'] == score
+
+        lr_scheduler_specific_data = chk['lr_schedulers'][0]
+        assert lr_scheduler_specific_data['_step_count'] == epoch + 2
+        if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+            assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1 ** (epoch + 1))
 
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index b5773bee87358..cc4560a3e01f7 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -425,6 +425,8 @@ def on_test_model_train(self):
         'on_after_backward',
         'on_before_zero_grad',
         'on_train_batch_end',
+        'on_epoch_end',
+        'on_train_epoch_end',
         'on_validation_model_eval',
         'on_validation_epoch_start',
         'on_validation_batch_start',
@@ -432,8 +434,6 @@ def on_test_model_train(self):
         'on_validation_epoch_end',
         'on_save_checkpoint',
         'on_validation_model_train',
-        'on_epoch_end',
-        'on_train_epoch_end',
         'on_train_end',
         'on_fit_end',
     ]