|
34 | 34 | from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn |
35 | 35 | from pytorch_lightning.utilities.cloud_io import get_filesystem |
36 | 36 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
| 37 | +from pytorch_lightning.utilities.warnings import WarningCache |
| 38 | + |
| 39 | +warning_cache = WarningCache() |
37 | 40 |
|
38 | 41 |
|
39 | 42 | class ModelCheckpoint(Callback): |
@@ -185,9 +188,6 @@ def __init__( |
185 | 188 | self.save_function = None |
186 | 189 | self.warned_result_obj = False |
187 | 190 |
|
188 | | - if save_top_k is None and monitor is not None: |
189 | | - self.save_top_k = 1 |
190 | | - |
191 | 191 | if prefix: |
192 | 192 | rank_zero_warn( |
193 | 193 | 'Argument `prefix` is deprecated in v1.1 and will be removed in v1.3.' |
@@ -460,17 +460,23 @@ def __resolve_ckpt_dir(self, trainer): |
460 | 460 |
|
461 | 461 | def _add_backward_monitor_support(self, trainer): |
462 | 462 | metrics = trainer.logger_connector.callback_metrics |
| 463 | + deprecation_warning = False |
463 | 464 |
|
464 | | - # backward compatibility... need to deprecate |
465 | 465 | if self.monitor is None and 'val_loss' in metrics: |
466 | 466 | self.monitor = 'val_loss' |
467 | | - |
468 | | - if self.monitor is None and 'checkpoint_on' in metrics: |
469 | | - self.monitor = 'checkpoint_on' |
| 467 | + deprecation_warning = True |
470 | 468 |
|
471 | 469 | if self.save_top_k is None and self.monitor is not None: |
| 470 | + # TODO: Remove `Optional` from `save_top_k` when this is deleted in v1.4 |
472 | 471 | self.save_top_k = 1 |
473 | 472 |
|
| 473 | + if deprecation_warning: |
| 474 | + warning_cache.warn( |
| 475 | + "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2" |
| 476 | + " and will be removed in v1.4. Please, create your own `mc = ModelCheckpoint(monitor='your_monitor')`" |
| 477 | + " and use it as `Trainer(callbacks=[mc])`.", DeprecationWarning |
| 478 | + ) |
| 479 | + |
474 | 480 | def _validate_monitor_key(self, trainer): |
475 | 481 | metrics = trainer.logger_connector.callback_metrics |
476 | 482 |
|
|
0 commit comments