3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -579,6 +579,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `ParallelPlugin.torch_distributed_backend` in favor of `DDPStrategy.process_group_backend` property ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745))
 
 
+- Deprecated `ModelCheckpoint.save_checkpoint` in favor of `Trainer.save_checkpoint` ([#12456](https://github.com/PyTorchLightning/pytorch-lightning/pull/12456))
+
+
 - Deprecated `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))
 
 
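Migration note (illustrative, not part of the diff): downstream code that called the callback's `save_checkpoint` can switch to `Trainer.save_checkpoint`, which takes an explicit filepath instead of reusing the callback's monitor/top-k bookkeeping. A minimal sketch, where `TinyModel` and the checkpoint paths are hypothetical placeholders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(pl.LightningModule):
    # Throwaway module, for demonstration only.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/")
trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=1)
trainer.fit(TinyModel(), train_loader)

# Before (deprecated in v1.6, removed in v1.8):
# checkpoint_callback.save_checkpoint(trainer)

# After: save a checkpoint manually at an explicit path.
trainer.save_checkpoint("checkpoints/manual.ckpt")
```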
7 changes: 5 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -37,7 +37,7 @@
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import _name, _version
-from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -352,7 +352,10 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover
         This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
         behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
         """
-        # TODO: unused method. deprecate it
+        rank_zero_deprecation(
+            f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
+            " Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
+        )
         monitor_candidates = self._monitor_candidates(trainer)
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)
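`rank_zero_deprecation` emits the warning only on the rank-zero process, so in distributed runs the message is printed once rather than once per worker. A simplified sketch of that pattern; the real helper lives in `pytorch_lightning.utilities.rank_zero` and uses Lightning's own `LightningDeprecationWarning` category rather than the plain `DeprecationWarning` used here to keep the example self-contained:

```python
import warnings
from functools import wraps
from typing import Any, Callable, Optional


def rank_zero_only(fn: Callable) -> Callable:
    """Call ``fn`` only when the global rank is 0 (simplified sketch)."""

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        # The real decorator reads the rank from the environment/strategy;
        # here it defaults to 0 so the sketch runs standalone.
        if getattr(rank_zero_only, "rank", 0) == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped_fn


@rank_zero_only
def rank_zero_deprecation(message: str) -> None:
    # Warn once, on rank zero only.
    warnings.warn(message, DeprecationWarning)
```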
11 changes: 6 additions & 5 deletions tests/checkpointing/test_model_checkpoint.py
@@ -22,7 +22,7 @@
 from pathlib import Path
 from typing import Union
 from unittest import mock
-from unittest.mock import call, MagicMock, Mock, patch
+from unittest.mock import call, Mock, patch
 
 import cloudpickle
 import pytest
@@ -834,7 +834,7 @@ def validation_epoch_end(self, outputs):
         val_check_interval=1.0,
         max_epochs=len(monitor),
     )
-    trainer.save_checkpoint = MagicMock()
+    trainer.save_checkpoint = Mock()
 
     trainer.fit(model)
 
@@ -1275,9 +1275,10 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir):
 def test_last_global_step_saved():
     # this should not save anything
     model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
-    trainer = MagicMock()
-    trainer.callback_metrics = {"foo": 123}
-    model_checkpoint.save_checkpoint(trainer)
+    trainer = Mock()
+    monitor_candidates = {"foo": 123}
+    model_checkpoint._save_topk_checkpoint(trainer, monitor_candidates)
+    model_checkpoint._save_last_checkpoint(trainer, monitor_candidates)
     assert model_checkpoint._last_global_step_saved == 0


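Side note on the test churn above: `MagicMock` is swapped for plain `Mock` because these tests only need call recording and attribute stubbing, not the pre-configured dunder methods (`__len__`, `__iter__`, ...) that `MagicMock` provides; a plain `Mock` also fails loudly if the code under test starts depending on a magic method. A quick standard-library illustration:

```python
from unittest.mock import MagicMock, Mock

magic, plain = MagicMock(), Mock()

len(magic)  # fine: MagicMock pre-configures __len__ (returns 0)
try:
    len(plain)  # TypeError: plain Mock implements no magic methods
except TypeError:
    pass

# Call recording works identically on both.
plain.save_checkpoint("epoch=0.ckpt")
plain.save_checkpoint.assert_called_once_with("epoch=0.ckpt")
```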
10 changes: 10 additions & 0 deletions tests/deprecated_api/test_remove_1-8.py
@@ -24,6 +24,7 @@
 
 import pytorch_lightning
 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -1055,6 +1056,15 @@ def test_trainer_data_parallel_device_ids(monkeypatch, trainer_kwargs, expected_data_parallel_device_ids):
     assert trainer.data_parallel_device_ids == expected_data_parallel_device_ids
 
 
+def test_deprecated_mc_save_checkpoint():
+    mc = ModelCheckpoint()
+    trainer = Trainer()
+    with mock.patch.object(trainer, "save_checkpoint"), pytest.deprecated_call(
+        match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6"
+    ):
+        mc.save_checkpoint(trainer)
+
+
 def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir):
     class TestCallbackLoadHook(Callback):
         def on_load_checkpoint(self, trainer, pl_module, callback_state):
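The new test combines two standard tools: `mock.patch.object` stubs out the real side effect (`trainer.save_checkpoint`) so nothing touches disk, while `pytest.deprecated_call(match=...)` fails the test unless a deprecation warning matching the regex is raised. A self-contained sketch of the same pattern; the `LegacyAPI` class is hypothetical and not part of Lightning:

```python
import warnings
from unittest import mock

import pytest


class LegacyAPI:
    def new_save(self, path: str) -> None:
        ...  # the real side effect would happen here

    def old_save(self, path: str) -> None:
        warnings.warn("`old_save()` was deprecated in v1.6", DeprecationWarning)
        self.new_save(path)


def test_old_save_warns() -> None:
    api = LegacyAPI()
    # Stub the side effect and require a matching DeprecationWarning.
    with mock.patch.object(api, "new_save"), pytest.deprecated_call(match=r"deprecated in v1.6"):
        api.old_save("some.ckpt")
```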