diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 4ac800f456c06..ddd0bbc90065e 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -32,8 +32,8 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -232,7 +232,8 @@ def save_checkpoint(self, trainer, pl_module):
             return

         self._add_backward_monitor_support(trainer)
-        self._validate_monitor_key(trainer)
+        if not self._validate_monitor_key(trainer):
+            return

         # track epoch when ckpt was last checked
         self.last_global_step_saved = global_step
@@ -501,17 +502,20 @@ def _add_backward_monitor_support(self, trainer):
         if self.save_top_k is None and self.monitor is not None:
             self.save_top_k = 1

-    def _validate_monitor_key(self, trainer):
+    def _validate_monitor_key(self, trainer) -> bool:
         metrics = trainer.logger_connector.callback_metrics

         # validate metric
         if self.monitor is not None and not self._is_valid_monitor_key(metrics):
-            m = (
-                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
-                f" {list(metrics.keys())}. "
-                f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
-            )
-            raise MisconfigurationException(m)
+            if not trainer.checkpoint_connector._one_training_epoch_completed:
+                return False
+            else:
+                raise MisconfigurationException(
+                    f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
+                    f" {list(metrics.keys())}. "
+                    f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
+                )
+        return True

     def _get_metric_interpolated_filepath_name(
         self,
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 429bddd88b77e..36b572aee5f5d 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -25,8 +25,8 @@
 from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 if APEX_AVAILABLE:
     from apex import amp
@@ -41,7 +41,8 @@ def __init__(self, trainer):
         self.trainer = trainer

         # used to validate checkpointing logic
-        self.has_trained = False
+        self._has_trained = False
+        self._one_training_epoch_completed = False

     def restore_weights(self, model: LightningModule):
         """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 35da90625adef..a8db393fa0a94 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -24,7 +24,6 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
-from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -47,6 +46,7 @@
 from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
 from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
+from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
 from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -56,7 +56,7 @@
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import rank_zero_warn, DeviceType
+from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -494,7 +494,8 @@ def train(self):
         # set stage for logging
         self.logger_connector.set_stage("train")

-        self.checkpoint_connector.has_trained = False
+        self.checkpoint_connector._has_trained = False
+        self.checkpoint_connector._one_training_epoch_completed = False

         # enable train mode
         model = self.get_model()
@@ -526,6 +527,8 @@
                 # update LR schedulers
                 self.optimizer_connector.update_learning_rates(interval='epoch')

+                self.checkpoint_connector._one_training_epoch_completed = True
+
                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 68a0f4781c9a9..3c9044f8bdb3c 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -213,7 +213,7 @@ def on_train_end(self):

     def check_checkpoint_callback(self, should_save, is_last=False):
         # TODO bake this logic into the checkpoint callback
-        if should_save and self.trainer.checkpoint_connector.has_trained:
+        if should_save and self.trainer.checkpoint_connector._has_trained:
             checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]

             if is_last and any(c.save_last for c in checkpoint_callbacks):
@@ -597,7 +597,7 @@ def run_training_epoch(self):
             # update LR schedulers
             monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
-            self.trainer.checkpoint_connector.has_trained = True
+            self.trainer.checkpoint_connector._has_trained = True

             # max steps reached, end training
             if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 9724f05247c00..73b15453fdde4 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -15,14 +15,14 @@
 import os
 import warnings
 from functools import wraps
+from typing import Any, Optional, Union

 import torch
+
 from pytorch_lightning import _logger as log
-from typing import Union, Optional, Any

 if torch.distributed.is_available():
-    from torch.distributed import ReduceOp
-    from torch.distributed import group
+    from torch.distributed import ReduceOp, group
 else:
     class ReduceOp:
         SUM = None
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 106c34030051e..97df8a0b44b9d 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -25,15 +25,16 @@
 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch.utils.data import DataLoader, Dataset, random_split

 import pytorch_lightning as pl
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning import LightningModule, Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import BoringModel
+from tests.base import BoringModel, RandomDataset


 class LogInTwoMethods(BoringModel):
@@ -702,7 +703,7 @@ def validation_epoch_end(self, *_):
             ...

     def assert_trainer_init(trainer):
-        assert not trainer.checkpoint_connector.has_trained
+        assert not trainer.checkpoint_connector._has_trained
         assert trainer.global_step == 0
         assert trainer.current_epoch == 0

@@ -739,7 +740,7 @@ def assert_checkpoint_log_dir(idx):

     model = ExtendedBoringModel()
     trainer.fit(model)
-    assert trainer.checkpoint_connector.has_trained
+    assert trainer.checkpoint_connector._has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs - 1
     assert_checkpoint_log_dir(0)
@@ -759,12 +760,12 @@ def assert_checkpoint_log_dir(idx):

         model = ExtendedBoringModel()
         trainer.test(model)
-        assert not trainer.checkpoint_connector.has_trained
+        assert not trainer.checkpoint_connector._has_trained
         assert trainer.global_step == epochs * limit_train_batches
         assert trainer.current_epoch == epochs

         trainer.fit(model)
-        assert not trainer.checkpoint_connector.has_trained
+        assert not trainer.checkpoint_connector._has_trained
         assert trainer.global_step == epochs * limit_train_batches
         assert trainer.current_epoch == epochs
         assert_checkpoint_log_dir(idx)
@@ -940,6 +941,41 @@ def __init__(self, hparams):
     assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


+def test_model_checkpoint_with_training_epoch_end(tmpdir):
+    """
+    This test ensures ModelCheckpoint does not raise a MisconfigurationException when the monitor is only logged in training_epoch_end.
+    """
+    class TestedModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('train_loss', loss)
+            return {"loss": loss}
+
+        def training_epoch_end(self, outputs) -> None:
+            avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+            self.log('epoch_end_train_loss', avg_loss)
+
+    model = TestedModel()
+
+    chk = ModelCheckpoint(dirpath=tmpdir, monitor='epoch_end_train_loss', save_top_k=-1)
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=4,
+        progress_bar_refresh_rate=1,
+        callbacks=[chk],
+    )
+    trainer.current_epoch = 2
+    trainer.fit(model)
+
+    chks = os.listdir(tmpdir)
+    assert 'epoch=4.ckpt' not in chks
+    assert 'epoch=3.ckpt' not in chks
+    assert 'epoch=2.ckpt' not in chks
+
+
+
 @pytest.mark.parametrize('max_epochs', [3, 4])
 @pytest.mark.parametrize(
     'save_top_k, expected',
@@ -976,4 +1012,4 @@ def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, ex
     assert set(ckpt_files) == set(expected)

     epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
-    assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
+    assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
\ No newline at end of file
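
Illustration (not part of the patch): the scenario this change addresses is a ModelCheckpoint that monitors a metric which only reaches callback_metrics after the first training epoch finishes, because it is logged in training_epoch_end. Before the patch the callback raised a MisconfigurationException on the first save attempt; with the patch it skips saving until the key exists. Below is a minimal sketch of such a setup, assuming a PyTorch Lightning build that includes this change; DemoModel and RandomData are illustrative names, not identifiers from the patch.

# Minimal sketch (illustrative only): monitor a metric logged in training_epoch_end.
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomData(Dataset):
    """Tiny random dataset so the example is self-contained."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class DemoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        # The monitored key only appears in callback_metrics once the first
        # epoch has completed, which is the case the patch above handles.
        avg_loss = torch.stack([out["loss"] for out in outputs]).mean()
        self.log("epoch_end_train_loss", avg_loss)

    def train_dataloader(self):
        return DataLoader(RandomData(), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    checkpoint = ModelCheckpoint(monitor="epoch_end_train_loss", save_top_k=1)
    trainer = pl.Trainer(max_epochs=3, callbacks=[checkpoint])
    trainer.fit(DemoModel())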