Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
 
 
+- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
+
+
 ### Fixed
 
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
3 changes: 0 additions & 3 deletions docs/source/common/hyperparameters.rst
@@ -167,9 +167,6 @@ improve readability and reproducibility.
         def train_dataloader(self):
             return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
 
-.. warning:: Deprecated since v1.1.0. This method of assigning hyperparameters to the LightningModule
-   will no longer be supported from v1.3.0. Use the ``self.save_hyperparameters()`` method from above instead.
-
 
 4. You can also save full objects such as `dict` or `Namespace` to the checkpoint.
 
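For context, the pattern the docs now point to exclusively is ``self.save_hyperparameters()``. A minimal sketch of that usage (class name and arguments are illustrative, not taken from this PR):

    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self, learning_rate: float = 1e-3, batch_size: int = 32):
            super().__init__()
            # Collects the __init__ arguments under self.hparams and stores them
            # in the checkpoint, replacing the removed direct hparams assignment.
            self.save_hyperparameters()

    model = LitModel(learning_rate=3e-4)
    print(model.hparams.learning_rate)  # 0.0003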
49 changes: 9 additions & 40 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -86,26 +86,14 @@ class ModelCheckpoint(Callback):
             if ``save_top_k >= 2`` and the callback is called multiple
             times inside an epoch, the name of the saved file will be
             appended with a version count starting with ``v1``.
-        mode: one of {auto, min, max}.
-            If ``save_top_k != 0``, the decision
-            to overwrite the current save file is made
-            based on either the maximization or the
-            minimization of the monitored quantity. For `val_acc`,
-            this should be `max`, for `val_loss` this should
-            be `min`, etc. In `auto` mode, the direction is
-            automatically inferred from the name of the monitored quantity.
-
-            .. warning::
-               Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3.
-
+        mode: one of {min, max}.
+            If ``save_top_k != 0``, the decision to overwrite the current save file is made
+            based on either the maximization or the minimization of the monitored quantity.
+            For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
         save_weights_only: if ``True``, then only the model's weights will be
             saved (``model.save_weights(filepath)``), else the full model
             is saved (``model.save(filepath)``).
         period: Interval (number of epochs) between checkpoints.
-        prefix: A string to put at the beginning of checkpoint filename.
-
-            .. warning::
-               This argument has been deprecated in v1.1 and will be removed in v1.3
-
 
     Note:
         For extra customization, ModelCheckpoint includes the following attributes:
@@ -122,7 +110,7 @@ class ModelCheckpoint(Callback):
         MisconfigurationException:
             If ``save_top_k`` is neither ``None`` nor more than or equal to ``-1``,
             if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
-            if ``mode`` is none of ``"min"``, ``"max"``, and ``"auto"``.
+            if ``mode`` is none of ``"min"`` or ``"max"``.
         ValueError:
             If ``trainer.save_checkpoint`` is ``None``.
 
@@ -166,9 +154,8 @@ def __init__(
         save_last: Optional[bool] = None,
         save_top_k: Optional[int] = None,
         save_weights_only: bool = False,
-        mode: str = "auto",
+        mode: str = "min",
         period: int = 1,
-        prefix: str = "",
     ):
         super().__init__()
         self.monitor = monitor
@@ -178,7 +165,6 @@
         self.save_weights_only = save_weights_only
         self.period = period
         self._last_global_step_saved = -1
-        self.prefix = prefix
         self.current_score = None
         self.best_k_models = {}
         self.kth_best_model_path = ""
@@ -188,12 +174,6 @@
         self.save_function = None
         self.warned_result_obj = False
 
-        if prefix:
-            rank_zero_warn(
-                'Argument `prefix` is deprecated in v1.1 and will be removed in v1.3.'
-                ' Please prepend your prefix in `filename` instead.', DeprecationWarning
-            )
-
         self.__init_monitor_mode(monitor, mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)
         self.__validate_init_configuration()
@@ -300,18 +280,8 @@ def __init_monitor_mode(self, monitor, mode):
             "max": (-torch_inf, "max"),
         }
 
-        if mode not in mode_dict and mode != 'auto':
-            raise MisconfigurationException(f"`mode` can be auto, {', '.join(mode_dict.keys())}, got {mode}")
-
-        # TODO: Update with MisconfigurationException when auto mode is removed in v1.3
-        if mode == 'auto':
-            rank_zero_warn(
-                "mode='auto' is deprecated in v1.1 and will be removed in v1.3."
-                " Default value for mode with be 'min' in v1.3.", DeprecationWarning
-            )
-
-            _condition = monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure"))
-            mode_dict['auto'] = ((-torch_inf, "max") if _condition else (torch_inf, "min"))
+        if mode not in mode_dict:
+            raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())} but got {mode}")
 
         self.kth_value, self.mode = mode_dict[mode]
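
For readers tracking the behaviour change: with ``mode='auto'``, the direction used to be inferred from the monitor name. A standalone sketch of that now-deleted heuristic, reconstructed from the removed lines above:

    # Reconstruction of the removed mode="auto" inference, for reference only:
    # accuracy-like metrics were maximized, everything else was minimized.
    def infer_auto_mode(monitor):
        if monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure")):
            return "max"
        return "min"

    assert infer_auto_mode("val_acc") == "max"
    assert infer_auto_mode("val_loss") == "min"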

@@ -410,7 +380,7 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
             'step=0.ckpt'
 
         """
-        filename = self._format_checkpoint_name(self.filename, epoch, step, metrics, prefix=self.prefix)
+        filename = self._format_checkpoint_name(self.filename, epoch, step, metrics)
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
 
@@ -523,7 +493,6 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
                 trainer.current_epoch,
                 trainer.global_step,
                 ckpt_name_metrics,
-                prefix=self.prefix
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
         else:
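Taken together, the user-facing migration is small: fold the old ``prefix`` into ``filename`` (as the deprecation message advised) and state ``mode`` explicitly. A sketch with purely illustrative monitor and filename values:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Before this PR (both arguments deprecated since v1.1):
    #   ModelCheckpoint(monitor="val_loss", filename="name", prefix="test", mode="auto")

    # After this PR:
    ckpt = ModelCheckpoint(
        monitor="val_loss",
        filename="test-name",  # the prefix is prepended to the filename by hand
        mode="min",            # the direction is now explicit; "min" is the new default
    )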
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/callback_connector.py
@@ -79,7 +79,7 @@ def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint,
             )
 
         if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
-            self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min'))
+            self.trainer.callbacks.append(ModelCheckpoint())
 
     def _configure_swa_callbacks(self):
         if not self.trainer._stochastic_weight_avg:
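One consequence of the new default: a bare ``ModelCheckpoint()`` now carries the behaviour the connector previously spelled out, so the explicit arguments could be dropped. A quick sketch of the equivalence:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # The explicit arguments below are simply the new defaults.
    a = ModelCheckpoint()
    b = ModelCheckpoint(dirpath=None, filename=None, mode="min")
    assert (a.dirpath, a.filename, a.mode) == (b.dirpath, b.filename, b.mode)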
6 changes: 3 additions & 3 deletions tests/checkpointing/test_model_checkpoint.py
@@ -294,9 +294,9 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
     assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')
 
     # with version
-    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test')
+    ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name')
     ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3)
-    assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
+    assert ckpt_name == tmpdir / 'name-v3.ckpt'
 
     # using slashes
     ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}')
@@ -975,5 +975,5 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir):
 
 
 def test_model_checkpoint_mode_options():
-    with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
+    with pytest.raises(MisconfigurationException, match="`mode` can be .* but got unknown_option"):
         ModelCheckpoint(mode="unknown_option")
14 changes: 0 additions & 14 deletions tests/deprecated_api/test_remove_1-2.py

This file was deleted.

17 changes: 1 addition & 16 deletions tests/deprecated_api/test_remove_1-3.py
@@ -15,25 +15,10 @@
 
 import pytest
 
-from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning import LightningModule
 
 
 def test_v1_3_0_deprecated_arguments(tmpdir):
-    with pytest.deprecated_call(match='will no longer be supported in v1.3'):
-        callback = ModelCheckpoint()
-        Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir)
-
-    # Deprecate prefix
-    with pytest.deprecated_call(match='will be removed in v1.3'):
-        ModelCheckpoint(prefix='temp')
-
-    # Deprecate auto mode
-    with pytest.deprecated_call(match='will be removed in v1.3'):
-        ModelCheckpoint(mode='auto')
-
-    with pytest.deprecated_call(match='will be removed in v1.3'):
-        EarlyStopping(mode='auto')
-
     with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"):
34 changes: 8 additions & 26 deletions tests/trainer/test_trainer.py
@@ -421,34 +421,17 @@ def test_dp_output_reduce():
 
 
 @pytest.mark.parametrize(
-    ["save_top_k", "save_last", "file_prefix", "expected_files"],
+    "save_top_k,save_last,expected_files",
     [
-        pytest.param(
-            -1,
-            False,
-            "",
-            {"epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt", "epoch=1.ckpt", "epoch=0.ckpt"},
-            id="CASE K=-1 (all)",
-        ),
-        pytest.param(1, False, "test_prefix", {"test_prefix-epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"),
-        pytest.param(2, False, "", {"epoch=4.ckpt", "epoch=2.ckpt"}, id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
-        pytest.param(
-            4,
-            False,
-            "",
-            {"epoch=1.ckpt", "epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt"},
-            id="CASE K=4 (save all 4 base)",
-        ),
-        pytest.param(
-            3,
-            False,
-            "", {"epoch=2.ckpt", "epoch=3.ckpt", "epoch=4.ckpt"},
-            id="CASE K=3 (save the 2nd, 3rd, 4th model)"
-        ),
-        pytest.param(1, True, "", {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"),
+        pytest.param(-1, False, [f"epoch={i}.ckpt" for i in range(5)], id="CASE K=-1 (all)"),
+        pytest.param(1, False, {"epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"),
+        pytest.param(2, False, [f"epoch={i}.ckpt" for i in (2, 4)], id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
+        pytest.param(4, False, [f"epoch={i}.ckpt" for i in range(1, 5)], id="CASE K=4 (save all 4 base)"),
+        pytest.param(3, False, [f"epoch={i}.ckpt" for i in range(2, 5)], id="CASE K=3 (save the 2nd, 3rd, 4th model)"),
+        pytest.param(1, True, {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"),
     ],
 )
-def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
+def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files):
     """Test ModelCheckpoint options."""
 
     def mock_save_function(filepath, *args):
@@ -463,7 +446,6 @@ def mock_save_function(filepath, *args):
         monitor='checkpoint_on',
         save_top_k=save_top_k,
         save_last=save_last,
-        prefix=file_prefix,
         verbose=1
     )
     checkpoint_callback.save_function = mock_save_function