From ec6639139a7bbd0e7c75aef55d2f4f1cea984e73 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Fri, 11 Dec 2020 11:04:56 +0100
Subject: [PATCH 1/8] wip

---
 .../callbacks/model_checkpoint.py            | 17 +++---
 pytorch_lightning/utilities/distributed.py   |  8 +--
 tests/checkpointing/test_model_checkpoint.py | 60 ++++++++++++++++++-
 3 files changed, 72 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 1354f7f5056b3..60a06d3225c21 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -32,8 +32,8 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -508,12 +508,15 @@ def _validate_monitor_key(self, trainer):

         # validate metric
         if self.monitor is not None and not self._is_valid_monitor_key(metrics):
-            m = (
-                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
-                f" {list(metrics.keys())}. "
-                f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
-            )
-            raise MisconfigurationException(m)
+            if trainer.current_epoch == 0:
+                m = (
+                    f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
+                    f" {list(metrics.keys())}. You might "
+                    f"HINT: If you monitor training_epoch_end"
+                )
+                rank_zero_warn(m)
+            else:
+                raise MisconfigurationException(m)

     def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
         filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 9724f05247c00..8fc162bee684d 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -15,14 +15,14 @@
 import os
 import warnings
 from functools import wraps
+from typing import Any, Optional, Union

 import torch
+
 from pytorch_lightning import _logger as log
-from typing import Union, Optional, Any

 if torch.distributed.is_available():
-    from torch.distributed import ReduceOp
-    from torch.distributed import group
+    from torch.distributed import ReduceOp, group
 else:
     class ReduceOp:
         SUM = None
@@ -145,7 +145,7 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD

-    if reduce_op is None:
+    if reduce_op is None or reduce_op == "sum":
         reduce_op = torch.distributed.ReduceOp.SUM
     elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
         reduce_op = torch.distributed.ReduceOp.SUM
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 31154eac1bf0d..738986358a565 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -27,15 +27,16 @@
 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch.utils.data import DataLoader, Dataset, random_split

 import pytorch_lightning as pl
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning import LightningModule, Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import BoringModel
+from tests.base import BoringModel, RandomDataset


 class LogInTwoMethods(BoringModel):
@@ -1020,3 +1021,58 @@ def __init__(self, hparams):
     else:
         # make sure it's not AttributeDict
         assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
+
+
+@mock.patch("torch.save") # need to mock torch.save or we get pickle error
+def test_model_checkpoint_with_training_epoch_End(tmpdir):
+
+    """
+    This test asserts ModelCheckpoint finds monitor metrics when logged on training_epoch_end
+    """
+    class TestedModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.layer = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            return self.layer(x)
+
+        def loss(self, batch, prediction):
+            # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
+            return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('train_loss', loss) # comment this line and it will work
+            return {"loss": loss}
+
+        def training_step_end(self, training_step_outputs):
+            return training_step_outputs
+
+        def training_epoch_end(self, outputs) -> None:
+            avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+            self.log('epoch_end_train_loss', avg_loss)
+            self.log('gb_step', self.global_step)
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"x": loss}
+
+        def validation_epoch_end(self, outputs) -> None:
+            torch.stack([x['x'] for x in outputs]).mean()
+
+    model = TestedModel()
+
+    callbacks=[pl.callbacks.ModelCheckpoint(monitor='epoch_end_train_loss', save_top_k=-1)]
+    # Initialize a trainer
+    trainer = pl.Trainer(
+        max_epochs=3,
+        progress_bar_refresh_rate=1,
+        callbacks=callbacks,
+    )
+
+    # Train the model ⚡
+    trainer.fit(model)

From d6eab2cd7b65c8bb8ddd373ce3ce14ea5e01cfe3 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Fri, 11 Dec 2020 12:35:33 +0100
Subject: [PATCH 2/8] simplify test

---
 .../callbacks/model_checkpoint.py            | 13 +++---
 .../connectors/checkpoint_connector.py       |  5 ++-
 pytorch_lightning/trainer/trainer.py         |  7 ++-
 tests/checkpointing/test_model_checkpoint.py | 43 +++----------------
 4 files changed, 22 insertions(+), 46 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 60a06d3225c21..faad92441d526 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -508,12 +508,15 @@ def _validate_monitor_key(self, trainer):

         # validate metric
         if self.monitor is not None and not self._is_valid_monitor_key(metrics):
-            if trainer.current_epoch == 0:
-                m = (
-                    f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
-                    f" {list(metrics.keys())}. You might "
-                    f"HINT: If you monitor training_epoch_end"
-                )
+            m = (
+                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
+                f" {list(metrics.keys())}. "
+            )
+            if not trainer.checkpoint_connector.one_training_epoch_completed:
+                m += "Running first epoch, a MisconfigurationException will be raised next epoch"
                 rank_zero_warn(m)
             else:
+                m += f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
                 raise MisconfigurationException(m)

     def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
         filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 2311cc767de2d..3c251c638cc21 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -20,11 +20,11 @@
 import pytorch_lightning
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import APEX_AVAILABLE, OMEGACONF_AVAILABLE, AMPType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 if APEX_AVAILABLE:
     from apex import amp
@@ -40,6 +40,7 @@ def __init__(self, trainer):

         # used to validate checkpointing logic
         self.has_trained = False
+        self.one_training_epoch_completed = False

     def restore_weights(self, model: LightningModule):
         """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 35da90625adef..36ec3faeb41f6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -24,7 +24,6 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
-from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -47,6 +46,7 @@
 from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
 from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
+from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
 from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -56,7 +56,7 @@
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import rank_zero_warn, DeviceType
+from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -495,6 +495,7 @@ def train(self):
         self.logger_connector.set_stage("train")

         self.checkpoint_connector.has_trained = False
+        self.checkpoint_connector.one_training_epoch_completed = False

         # enable train mode
         model = self.get_model()
@@ -526,6 +527,8 @@ def train(self):
                 # update LR schedulers
                 self.optimizer_connector.update_learning_rates(interval='epoch')

+                self.checkpoint_connector.one_training_epoch_completed = True
+
                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 738986358a565..fc659e98a3238 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1023,56 +1023,27 @@ def __init__(self, hparams):
     assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


-@mock.patch("torch.save") # need to mock torch.save or we get pickle error
-def test_model_checkpoint_with_training_epoch_End(tmpdir):
+def test_model_checkpoint_with_training_epoch_end(tmpdir):

     """
-    This test asserts ModelCheckpoint finds monitor metrics when logged on training_epoch_end
+    This test asserts a ModelCheckpoint warning is issued when the monitor metric is used in training_epoch_end
     """
     class TestedModel(BoringModel):

-        def __init__(self):
-            super().__init__()
-            self.layer = torch.nn.Linear(32, 2)
-
-        def forward(self, x):
-            return self.layer(x)
-
-        def loss(self, batch, prediction):
-            # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
-            return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
-
-        def training_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            self.log('train_loss', loss) # comment this line and it will work
-            return {"loss": loss}
-
-        def training_step_end(self, training_step_outputs):
-            return training_step_outputs
-
         def training_epoch_end(self, outputs) -> None:
             avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
             self.log('epoch_end_train_loss', avg_loss)
             self.log('gb_step', self.global_step)

-        def validation_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            return {"x": loss}
-
-        def validation_epoch_end(self, outputs) -> None:
-            torch.stack([x['x'] for x in outputs]).mean()
-
     model = TestedModel()

-    callbacks=[pl.callbacks.ModelCheckpoint(monitor='epoch_end_train_loss', save_top_k=-1)]
-    # Initialize a trainer
+    chk = ModelCheckpoint(monitor='epoch_end_train_loss', save_top_k=-1)
+
     trainer = pl.Trainer(
+        default_root_dir=tmpdir,
         max_epochs=3,
         progress_bar_refresh_rate=1,
-        callbacks=callbacks,
+        callbacks=[chk],
     )
-
-    # Train the model ⚡
+    trainer.current_epoch = 2
     trainer.fit(model)

From a4694ac874d9d6b958c5673207e5cb5902d6ec18 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Fri, 11 Dec 2020 13:19:34 +0100
Subject: [PATCH 3/8] test for warning

---
 .../callbacks/model_checkpoint.py            |  2 +-
 tests/checkpointing/test_model_checkpoint.py | 24 ++++++++++++-------
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index faad92441d526..bfd9130858163 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -514,7 +514,7 @@ def _validate_monitor_key(self, trainer):
             )
             if not trainer.checkpoint_connector.one_training_epoch_completed:
                 m += "Running first epoch, a MisconfigurationException will be raised next epoch"
-                rank_zero_warn(m)
+                rank_zero_warn(m, UserWarning)
             else:
                 m += f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
                 raise MisconfigurationException(m)
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index fc659e98a3238..e9b1a57738450 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1030,6 +1030,12 @@ def test_model_checkpoint_with_training_epoch_end(tmpdir):
     """
     class TestedModel(BoringModel):

+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('train_loss', loss)
+            return {"loss": loss}
+
         def training_epoch_end(self, outputs) -> None:
             avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
             self.log('epoch_end_train_loss', avg_loss)
@@ -1038,12 +1044,12 @@ def training_epoch_end(self, outputs) -> None:
     model = TestedModel()

     chk = ModelCheckpoint(monitor='epoch_end_train_loss', save_top_k=-1)
-
-    trainer = pl.Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=3,
-        progress_bar_refresh_rate=1,
-        callbacks=[chk],
-    )
-    trainer.current_epoch = 2
-    trainer.fit(model)
+    with pytest.warns(UserWarning, match="Running first epoch, a MisconfigurationException"):
+        trainer = pl.Trainer(
+            default_root_dir=tmpdir,
+            max_epochs=3,
+            progress_bar_refresh_rate=1,
+            callbacks=[chk],
+        )
+        trainer.current_epoch = 2
+        trainer.fit(model)

From 8e484bf8b47965a92ddf7722282600a428640663 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Fri, 11 Dec 2020 14:51:23 +0100
Subject: [PATCH 4/8] move to protected

---
 pytorch_lightning/callbacks/model_checkpoint.py      | 2 +-
 .../trainer/connectors/checkpoint_connector.py       | 4 ++--
 pytorch_lightning/trainer/trainer.py                 | 6 +++---
 pytorch_lightning/trainer/training_loop.py           | 4 ++--
 tests/checkpointing/test_model_checkpoint.py         | 8 ++++----
 5 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index bfd9130858163..365d3f0fecdb2 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -512,7 +512,7 @@ def _validate_monitor_key(self, trainer):
                 f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
                 f" {list(metrics.keys())}. "
             )
-            if not trainer.checkpoint_connector.one_training_epoch_completed:
+            if not trainer.checkpoint_connector._one_training_epoch_completed:
                 m += "Running first epoch, a MisconfigurationException will be raised next epoch"
                 rank_zero_warn(m, UserWarning)
             else:
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 3c251c638cc21..c03d6596ab0c1 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -39,8 +39,8 @@ def __init__(self, trainer):
         self.trainer = trainer

         # used to validate checkpointing logic
-        self.has_trained = False
-        self.one_training_epoch_completed = False
+        self._has_trained = False
+        self._one_training_epoch_completed = False

     def restore_weights(self, model: LightningModule):
         """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 36ec3faeb41f6..a8db393fa0a94 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -494,8 +494,8 @@ def train(self):
         # set stage for logging
         self.logger_connector.set_stage("train")

-        self.checkpoint_connector.has_trained = False
-        self.checkpoint_connector.one_training_epoch_completed = False
+        self.checkpoint_connector._has_trained = False
+        self.checkpoint_connector._one_training_epoch_completed = False

         # enable train mode
         model = self.get_model()
@@ -527,7 +527,7 @@ def train(self):
                 # update LR schedulers
                 self.optimizer_connector.update_learning_rates(interval='epoch')

-                self.checkpoint_connector.one_training_epoch_completed = True
+                self.checkpoint_connector._one_training_epoch_completed = True

                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 20dfb0f4b380f..7b5e3685c406b 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -213,7 +213,7 @@ def on_train_end(self):

     def check_checkpoint_callback(self, should_save, is_last=False):
         # TODO bake this logic into the checkpoint callback
-        if should_save and self.trainer.checkpoint_connector.has_trained:
+        if should_save and self.trainer.checkpoint_connector._has_trained:
             checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]

             if is_last and any(c.save_last for c in checkpoint_callbacks):
@@ -599,7 +599,7 @@ def run_training_epoch(self):
             # update LR schedulers
             monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
-            self.trainer.checkpoint_connector.has_trained = True
+            self.trainer.checkpoint_connector._has_trained = True

             # max steps reached, end training
             if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index e9b1a57738450..402156b5b2eac 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -768,7 +768,7 @@ def validation_step(self, batch, batch_idx):
         return {"val_loss": loss}

     def assert_trainer_init(trainer):
-        assert not trainer.checkpoint_connector.has_trained
+        assert not trainer.checkpoint_connector._has_trained
         assert trainer.global_step == 0
         assert trainer.current_epoch == 0

@@ -816,7 +816,7 @@ def get_model():
     assert_trainer_init(trainer)

     trainer.fit(model)
-    assert trainer.checkpoint_connector.has_trained
+    assert trainer.checkpoint_connector._has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs - 1
     assert_checkpoint_log_dir(0)
@@ -842,12 +842,12 @@ def get_model():
         assert_trainer_init(trainer)

         trainer.test(model)
-        assert not trainer.checkpoint_connector.has_trained
+        assert not trainer.checkpoint_connector._has_trained
         assert trainer.global_step == epochs * limit_train_batches
         assert trainer.current_epoch == epochs

         trainer.fit(model)
-        assert not trainer.checkpoint_connector.has_trained
+        assert not trainer.checkpoint_connector._has_trained
         assert trainer.global_step == epochs * limit_train_batches
         assert trainer.current_epoch == epochs
         assert_checkpoint_log_dir(idx)

From 31ca923d91f526aa4ff1e3801dd4d4a9c6d777de Mon Sep 17 00:00:00 2001
From: chaton
Date: Sun, 13 Dec 2020 22:07:07 +0100
Subject: [PATCH 5/8] Update tests/checkpointing/test_model_checkpoint.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 tests/checkpointing/test_model_checkpoint.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 402156b5b2eac..0c6bfa413cb6c 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1024,9 +1024,8 @@ def __init__(self, hparams):


 def test_model_checkpoint_with_training_epoch_end(tmpdir):
-
     """
-    This test asserts a ModelCheckpoint warning is issued when the monitor metric is used in training_epoch_end
+    This test ensures ModelCheckpoint issues a warning when the monitor is logged on training_epoch_end
     """
     class TestedModel(BoringModel):


From 3484f2be39453563f242baf5f95f044d881aa5d8 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Sun, 13 Dec 2020 22:29:07 +0100
Subject: [PATCH 6/8] update on comments

---
 .../callbacks/model_checkpoint.py            | 20 +++++++++--------
 pytorch_lightning/utilities/distributed.py   |  2 +-
 tests/checkpointing/test_model_checkpoint.py | 22 +++++++++++--------
 3 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 365d3f0fecdb2..54342050e8e39 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -232,7 +232,8 @@ def save_checkpoint(self, trainer, pl_module):
             return

         self._add_backward_monitor_support(trainer)
-        self._validate_monitor_key(trainer)
+        if not self._validate_monitor_key(trainer):
+            return

         # track epoch when ckpt was last checked
         self.last_global_step_saved = global_step
@@ -247,6 +248,7 @@ def save_checkpoint(self, trainer, pl_module):
         # here we call each mode sequentially
         # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
+            print(epoch, global_step)
             self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)

         # Mode 2: save the last checkpoint
@@ -503,21 +505,21 @@ def _add_backward_monitor_support(self, trainer):
         if self.save_top_k is None and self.monitor is not None:
             self.save_top_k = 1

-    def _validate_monitor_key(self, trainer):
+    def _validate_monitor_key(self, trainer) -> bool:
         metrics = trainer.logger_connector.callback_metrics

         # validate metric
         if self.monitor is not None and not self._is_valid_monitor_key(metrics):
-            m = (
-                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
-                f" {list(metrics.keys())}. "
-            )
             if not trainer.checkpoint_connector._one_training_epoch_completed:
-                m += "Running first epoch, a MisconfigurationException will be raised next epoch"
-                rank_zero_warn(m, UserWarning)
+                return False
             else:
-                m += f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
+                m = (
+                    f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
+                    f" {list(metrics.keys())}. "
+                    f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
+                )
                 raise MisconfigurationException(m)
+        return True

     def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
         filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 8fc162bee684d..73b15453fdde4 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -145,7 +145,7 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD

-    if reduce_op is None or reduce_op == "sum":
+    if reduce_op is None:
         reduce_op = torch.distributed.ReduceOp.SUM
     elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
         reduce_op = torch.distributed.ReduceOp.SUM
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 402156b5b2eac..2ef4a17343ab3 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1044,12 +1044,16 @@ def training_epoch_end(self, outputs) -> None:
     model = TestedModel()

     chk = ModelCheckpoint(monitor='epoch_end_train_loss', save_top_k=-1)
-    with pytest.warns(UserWarning, match="Running first epoch, a MisconfigurationException"):
-        trainer = pl.Trainer(
-            default_root_dir=tmpdir,
-            max_epochs=3,
-            progress_bar_refresh_rate=1,
-            callbacks=[chk],
-        )
-        trainer.current_epoch = 2
-        trainer.fit(model)
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=4,
+        progress_bar_refresh_rate=1,
+        callbacks=[chk],
+    )
+    trainer.current_epoch = 2
+    trainer.fit(model)
+
+    chks = os.listdir(tmpdir / 'lightning_logs' / 'version_0' / 'checkpoints')
+    assert any(['epoch=4' not in c for c in chks])
+    assert any(['epoch=3' in c for c in chks])
+    assert any(['epoch=2' not in c for c in chks])

From 05fd97f031d616ef24e520a5fad183b499ba02fb Mon Sep 17 00:00:00 2001
From: chaton
Date: Mon, 14 Dec 2020 08:17:18 +0100
Subject: [PATCH 7/8] Update pytorch_lightning/callbacks/model_checkpoint.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/callbacks/model_checkpoint.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 54342050e8e39..59a6a3a4e3e5e 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -513,12 +513,11 @@ def _validate_monitor_key(self, trainer) -> bool:
             if not trainer.checkpoint_connector._one_training_epoch_completed:
                 return False
             else:
-                m = (
+                raise MisconfigurationException(
                     f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
                     f" {list(metrics.keys())}. "
                     f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
                 )
-                raise MisconfigurationException(m)
         return True

     def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
         filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)

From 15a88db97d0fc0562c15bcc54e3989b5fc60f39f Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 14 Dec 2020 08:20:51 +0100
Subject: [PATCH 8/8] update on comments

---
 pytorch_lightning/callbacks/model_checkpoint.py |  1 -
 tests/checkpointing/test_model_checkpoint.py    | 11 +++++------
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 54342050e8e39..9ed8c9ae93a0f 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -248,7 +248,6 @@ def save_checkpoint(self, trainer, pl_module):
         # here we call each mode sequentially
         # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
-            print(epoch, global_step)
             self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)

         # Mode 2: save the last checkpoint
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 05ba8becf3bbf..a6cd0d82b3cfa 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1038,11 +1038,10 @@ def training_step(self, batch, batch_idx):
     def training_epoch_end(self, outputs) -> None:
         avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
         self.log('epoch_end_train_loss', avg_loss)
-        self.log('gb_step', self.global_step)

     model = TestedModel()

-    chk = ModelCheckpoint(monitor='epoch_end_train_loss', save_top_k=-1)
+    chk = ModelCheckpoint(dirpath=tmpdir, monitor='epoch_end_train_loss', save_top_k=-1)
     trainer = pl.Trainer(
         default_root_dir=tmpdir,
         max_epochs=4,
@@ -1052,7 +1051,7 @@ def training_epoch_end(self, outputs) -> None:
     trainer.current_epoch = 2
     trainer.fit(model)

-    chks = os.listdir(tmpdir / 'lightning_logs' / 'version_0' / 'checkpoints')
-    assert any(['epoch=4' not in c for c in chks])
-    assert any(['epoch=3' in c for c in chks])
-    assert any(['epoch=2' not in c for c in chks])
+    chks = os.listdir(tmpdir)
+    assert 'epoch=4.ckpt' not in chks
+    assert 'epoch=3.ckpt' not in chks
+    assert 'epoch=2.ckpt' not in chks
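
The behavior this series converges on can be summarized in a standalone sketch. The free function validate_monitor_key and its one_epoch_completed argument below are illustrative stand-ins (assumed names, not Lightning's public API) for ModelCheckpoint._validate_monitor_key and trainer.checkpoint_connector._one_training_epoch_completed: while the first training epoch is still running, a missing monitor key makes the save call skip quietly; once a full training epoch has completed, a missing key raises MisconfigurationException.

# Minimal, runnable sketch of the validation flow introduced by patches 2-7.
from typing import Dict, Optional


class MisconfigurationException(Exception):
    pass


def validate_monitor_key(monitor: Optional[str], metrics: Dict[str, float], one_epoch_completed: bool) -> bool:
    """Return True when saving may proceed, False when this save call should be skipped."""
    if monitor is None or monitor in metrics:
        return True
    if not one_epoch_completed:
        # Metrics logged in `training_epoch_end` do not exist yet while the first
        # epoch is still running, so skip this checkpoint attempt instead of failing.
        return False
    raise MisconfigurationException(
        f"ModelCheckpoint(monitor='{monitor}') not found in the returned metrics:"
        f" {list(metrics.keys())}. "
        f"HINT: Did you call self.log('{monitor}', tensor) in the LightningModule?"
    )


# First epoch: the key logged in `training_epoch_end` is not available yet -> skip the save.
assert validate_monitor_key('epoch_end_train_loss', {'train_loss': 0.5}, False) is False
# After a completed epoch the key exists -> saving proceeds.
assert validate_monitor_key('epoch_end_train_loss', {'epoch_end_train_loss': 0.4}, True) is True

This is why the final test can monitor 'epoch_end_train_loss', a key that only appears after training_epoch_end runs, without tripping the exception during the first epoch.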