Skip to content

Commit 7aae589

Browse files
carmoccaBorda
andauthored
Add deprecation warning to ModelCheckpoint when logging val_loss with no monitor (#6012)
* Add deprecation warning when logging val_loss with no monitor * EOF * Update CHANGELOG * Clear warning cache before testing * pep8 Co-authored-by: Jirka Borovec <[email protected]>
1 parent 6e8721e commit 7aae589

File tree

5 files changed

+35
-10
lines changed

5 files changed

+35
-10
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
218218
* `xla_device_utils` >> `xla_device`
219219

220220

221+
- Deprecated using `'val_loss'` to set the `ModelCheckpoint` monitor ([#6012](https://github.com/PyTorchLightning/pytorch-lightning/pull/6012))
222+
223+
221224
### Removed
222225

223226
- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
3535
from pytorch_lightning.utilities.cloud_io import get_filesystem
3636
from pytorch_lightning.utilities.exceptions import MisconfigurationException
37+
from pytorch_lightning.utilities.warnings import WarningCache
38+
39+
warning_cache = WarningCache()
3740

3841

3942
class ModelCheckpoint(Callback):
@@ -185,9 +188,6 @@ def __init__(
185188
self.save_function = None
186189
self.warned_result_obj = False
187190

188-
if save_top_k is None and monitor is not None:
189-
self.save_top_k = 1
190-
191191
if prefix:
192192
rank_zero_warn(
193193
'Argument `prefix` is deprecated in v1.1 and will be removed in v1.3.'
@@ -460,17 +460,23 @@ def __resolve_ckpt_dir(self, trainer):
460460

461461
def _add_backward_monitor_support(self, trainer):
462462
metrics = trainer.logger_connector.callback_metrics
463+
deprecation_warning = False
463464

464-
# backward compatibility... need to deprecate
465465
if self.monitor is None and 'val_loss' in metrics:
466466
self.monitor = 'val_loss'
467-
468-
if self.monitor is None and 'checkpoint_on' in metrics:
469-
self.monitor = 'checkpoint_on'
467+
deprecation_warning = True
470468

471469
if self.save_top_k is None and self.monitor is not None:
470+
# TODO: Remove `Optional` from `save_top_k` when this is deleted in v1.4
472471
self.save_top_k = 1
473472

473+
if deprecation_warning:
474+
warning_cache.warn(
475+
"Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"
476+
" and will be removed in v1.4. Please, create your own `mc = ModelCheckpoint(monitor='your_monitor')`"
477+
" and use it as `Trainer(callbacks=[mc])`.", DeprecationWarning
478+
)
479+
474480
def _validate_monitor_key(self, trainer):
475481
metrics = trainer.logger_connector.callback_metrics
476482

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def attach_dataloaders(
9696
train_dataloader=None,
9797
val_dataloaders=None,
9898
test_dataloaders=None,
99-
predict_dataloaders=None
99+
predict_dataloaders=None,
100100
):
101101
# when dataloader is passed via fit, patch the train_dataloader
102102
# functions to overwrite with these implementations

pytorch_lightning/utilities/warnings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ class WarningCache:
1919
def __init__(self):
2020
self.warnings = set()
2121

22-
def warn(self, m):
22+
def warn(self, m, *args, **kwargs):
2323
if m not in self.warnings:
2424
self.warnings.add(m)
25-
rank_zero_warn(m)
25+
rank_zero_warn(m, *args, **kwargs)
2626

2727
def clear(self):
2828
self.warnings.clear()

tests/deprecated_api/test_remove_1-4.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,19 @@ def automatic_optimization(self):
215215
match="`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4"
216216
):
217217
trainer.fit(model)
218+
219+
220+
def test_v1_4_0_deprecated_checkpoint_on(tmpdir):
221+
from pytorch_lightning.callbacks.model_checkpoint import warning_cache
222+
warning_cache.clear()
223+
224+
class TestModel(BoringModel):
225+
226+
def training_step(self, batch, batch_idx):
227+
self.log("val_loss", -batch_idx)
228+
return super().training_step(batch, batch_idx)
229+
230+
trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=True, max_epochs=1)
231+
232+
with pytest.warns(DeprecationWarning, match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
233+
trainer.fit(TestModel())

0 commit comments

Comments
 (0)