diff --git a/CHANGELOG.md b/CHANGELOG.md
index 803ece1a51ed2..d42fcaaa7a46f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
 
+- Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336))
+
 ### Removed
diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index 187ff237056a2..b8a4276a2d747 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+
+from typing import Union, Optional
+
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -44,25 +47,31 @@ def on_trainer_init(
         # configure checkpoint callback
         # it is important that this is the last callback to run
         # pass through the required args to figure out defaults
-        checkpoint_callback = self.init_default_checkpoint_callback(checkpoint_callback)
-        if checkpoint_callback:
-            self.trainer.callbacks.append(checkpoint_callback)
-
-        # TODO refactor codebase (tests) to not directly reach into these callbacks
-        self.trainer.checkpoint_callback = checkpoint_callback
+        self.configure_checkpoint_callbacks(checkpoint_callback)
 
         # init progress bar
         self.trainer._progress_bar_callback = self.configure_progress_bar(
             progress_bar_refresh_rate, process_position
         )
 
-    def init_default_checkpoint_callback(self, checkpoint_callback):
-        if checkpoint_callback is True:
-            checkpoint_callback = ModelCheckpoint(dirpath=None, filename=None)
-        elif checkpoint_callback is False:
-            checkpoint_callback = None
+    def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]):
+        if isinstance(checkpoint_callback, ModelCheckpoint):
+            # TODO: deprecated, remove this block in v1.4.0
+            rank_zero_warn(
+                "Passing a ModelCheckpoint instance to Trainer(checkpoint_callback=...)"
+                " is deprecated since v1.1 and will no longer be supported in v1.4.",
+                DeprecationWarning
+            )
+            self.trainer.callbacks.append(checkpoint_callback)
+
+        if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
+            raise MisconfigurationException(
+                "Trainer was configured with checkpoint_callback=False but found ModelCheckpoint"
+                " in callbacks list."
+            )
 
-        return checkpoint_callback
+        if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
+            self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None))
 
     def configure_progress_bar(self, refresh_rate=1, process_position=0):
         progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
@@ -83,3 +92,6 @@ def configure_progress_bar(self, refresh_rate=1, process_position=0):
             progress_bar_callback = None
 
         return progress_bar_callback
+
+    def _trainer_has_checkpoint_callbacks(self):
+        return len(self.trainer.checkpoint_callbacks) > 0
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index afb2f4cb5eb91..8d509d41d52bf 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -17,7 +17,7 @@
 from argparse import ArgumentParser, Namespace
 from typing import List, Optional, Union, Type, TypeVar
 
-from pytorch_lightning.callbacks import ProgressBarBase
+from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
@@ -46,6 +46,7 @@ class TrainerProperties(ABC):
     _weights_save_path: str
     model_connector: ModelConnector
     checkpoint_connector: CheckpointConnector
+    callbacks: List[Callback]
 
     @property
     def use_amp(self) -> bool:
@@ -187,6 +188,20 @@ def weights_save_path(self) -> str:
             return os.path.normpath(self._weights_save_path)
         return self._weights_save_path
 
+    @property
+    def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
+        """
+        The first checkpoint callback in the Trainer.callbacks list, or ``None`` if
+        no checkpoint callbacks exist.
+        """
+        callbacks = self.checkpoint_callbacks
+        return callbacks[0] if len(callbacks) > 0 else None
+
+    @property
+    def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
+        """ A list of all instances of ModelCheckpoint found in the Trainer.callbacks list. """
+        return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
+
     def save_checkpoint(self, filepath, weights_only: bool = False):
         self.checkpoint_connector.save_checkpoint(filepath, weights_only)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 337eb4c4ed567..008633273a0d1 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -85,7 +85,7 @@ class Trainer(
     def __init__(
         self,
         logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
-        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
+        checkpoint_callback: bool = True,
         callbacks: Optional[List[Callback]] = None,
         default_root_dir: Optional[str] = None,
         gradient_clip_val: float = 0,
@@ -169,7 +169,12 @@ def __init__(
 
             callbacks: Add a list of callbacks.
 
-            checkpoint_callback: Callback for checkpointing.
+            checkpoint_callback: If ``True``, enable checkpointing.
+                It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
+                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.
+
+                .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
+                    v1.1.0 and will be unsupported from v1.4.0.
 
             check_val_every_n_epoch: Check val every n train epochs.
@@ -297,7 +302,6 @@ def __init__(
         # init callbacks
         # Declare attributes to be set in callback_connector on_trainer_init
-        self.checkpoint_callback: Union[ModelCheckpoint, bool] = checkpoint_callback
         self.callback_connector.on_trainer_init(
             callbacks,
             checkpoint_callback,
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 87783fbde5d1f..10de8a2d289e5 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -144,7 +144,6 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
     trainer.weights_summary = None  # not needed before full run
     trainer.logger = DummyLogger()
     trainer.callbacks = []  # not needed before full run
-    trainer.checkpoint_callback = False  # required for saving
     trainer.limit_train_batches = 1.0
     trainer.optimizers, trainer.schedulers = [], []  # required for saving
     trainer.model = model  # required for saving
@@ -157,7 +156,6 @@ def __scale_batch_restore_params(trainer):
     trainer.weights_summary = trainer.__dumped_params['weights_summary']
     trainer.logger = trainer.__dumped_params['logger']
     trainer.callbacks = trainer.__dumped_params['callbacks']
-    trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback']
     trainer.auto_scale_batch_size = trainer.__dumped_params['auto_scale_batch_size']
     trainer.limit_train_batches = trainer.__dumped_params['limit_train_batches']
     trainer.model = trainer.__dumped_params['model']
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index d0ab33df8e1b8..3107f9f44824a 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -155,9 +155,6 @@ def lr_find(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()
 
-    # Disable standard checkpoint & early stopping
-    trainer.checkpoint_callback = False
-
     # Required for saving the model
     trainer.optimizers, trainer.schedulers = [], [],
     trainer.model = model
@@ -212,7 +209,6 @@ def __lr_finder_restore_params(trainer, model):
     trainer.logger = trainer.__dumped_params['logger']
     trainer.callbacks = trainer.__dumped_params['callbacks']
     trainer.max_steps = trainer.__dumped_params['max_steps']
-    trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback']
     model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
 
     del trainer.__dumped_params
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 1634b73424dd1..3bc2ca436ec15 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -743,3 +743,43 @@ def test_filepath_decomposition_dirpath_filename(tmpdir, filepath, dirpath, file
     assert mc_cb.dirpath == dirpath
     assert mc_cb.filename == filename
+
+
+def test_configure_model_checkpoint(tmpdir):
+    """ Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. """
""" + kwargs = dict(default_root_dir=tmpdir) + callback1 = ModelCheckpoint() + callback2 = ModelCheckpoint() + + # no callbacks + trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs) + assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks) + assert trainer.checkpoint_callback is None + + # default configuration + trainer = Trainer(checkpoint_callback=True, callbacks=[], **kwargs) + assert len([c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)]) == 1 + assert isinstance(trainer.checkpoint_callback, ModelCheckpoint) + + # custom callback passed to callbacks list, checkpoint_callback=True is ignored + trainer = Trainer(checkpoint_callback=True, callbacks=[callback1], **kwargs) + assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1] + assert trainer.checkpoint_callback == callback1 + + # multiple checkpoint callbacks + trainer = Trainer(callbacks=[callback1, callback2], **kwargs) + assert trainer.checkpoint_callback == callback1 + assert trainer.checkpoint_callbacks == [callback1, callback2] + + with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.4'): + trainer = Trainer(checkpoint_callback=callback1, callbacks=[], **kwargs) + assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1] + assert trainer.checkpoint_callback == callback1 + + with pytest.warns(DeprecationWarning, match="will no longer be supported in v1.4"): + trainer = Trainer(checkpoint_callback=callback1, callbacks=[callback2], **kwargs) + assert trainer.checkpoint_callback == callback2 + assert trainer.checkpoint_callbacks == [callback2, callback1] + + with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"): + Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 862294e64765f..848d6127c4cdb 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -15,6 +15,7 @@ import logging as log import os import pickle +from copy import deepcopy import cloudpickle import pytest @@ -24,7 +25,7 @@ import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils -from pytorch_lightning import Trainer, LightningModule, Callback +from pytorch_lightning import Trainer, LightningModule, Callback, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST @@ -51,24 +52,90 @@ def on_train_end(self, trainer, pl_module): self._check_properties(trainer, pl_module) -def test_resume_from_checkpoint(tmpdir): +def test_model_properties_resume_from_checkpoint(tmpdir): """ Test that properties like `current_epoch` and `global_step` in model and trainer are always the same. 
""" model = EvalModelTemplate() checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) trainer_args = dict( default_root_dir=tmpdir, - max_epochs=2, + max_epochs=1, logger=False, - checkpoint_callback=checkpoint_callback, - callbacks=[ModelTrainerPropertyParity()] # this performs the assertions + callbacks=[checkpoint_callback, ModelTrainerPropertyParity()] # this performs the assertions ) trainer = Trainer(**trainer_args) trainer.fit(model) + + trainer_args.update(max_epochs=2) trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt")) trainer.fit(model) +class CaptureCallbacksBeforeTraining(Callback): + callbacks = [] + + def on_train_start(self, trainer, pl_module): + self.callbacks = deepcopy(trainer.callbacks) + + +def test_callbacks_state_resume_from_checkpoint(tmpdir): + """ Test that resuming from a checkpoint restores callbacks that persist state. """ + model = EvalModelTemplate() + callback_capture = CaptureCallbacksBeforeTraining() + + def get_trainer_args(): + checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) + trainer_args = dict( + default_root_dir=tmpdir, + max_steps=1, + logger=False, + callbacks=[ + checkpoint, + callback_capture, + ] + ) + assert checkpoint.best_model_path == "" + assert checkpoint.best_model_score == 0 + return trainer_args + + # initial training + trainer = Trainer(**get_trainer_args()) + trainer.fit(model) + callbacks_before_resume = deepcopy(trainer.callbacks) + + # resumed training + trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt")) + trainer.fit(model) + + assert len(callbacks_before_resume) == len(callback_capture.callbacks) + + for before, after in zip(callbacks_before_resume, callback_capture.callbacks): + if isinstance(before, ModelCheckpoint): + assert before.best_model_path == after.best_model_path + assert before.best_model_score == after.best_model_score + + +def test_callbacks_references_resume_from_checkpoint(tmpdir): + """ Test that resuming from a checkpoint sets references as expected. 
""" + model = EvalModelTemplate() + args = {'default_root_dir': tmpdir, 'max_steps': 1, 'logger': False} + + # initial training + checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) + trainer = Trainer(**args, callbacks=[checkpoint]) + assert checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback + trainer.fit(model) + + # resumed training + new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) + # pass in a new checkpoint object, which should take + # precedence over the one in the last.ckpt file + trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt")) + assert checkpoint is not new_checkpoint + assert new_checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback + trainer.fit(model) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_running_test_pretrained_model_distrib_dp(tmpdir): """Verify `test()` on pretrained model.""" diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 60f13383d3777..67f38568e2103 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -15,6 +15,12 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException +def test_tbd_remove_in_v1_4_0(tmpdir): + with pytest.deprecated_call(match='will no longer be supported in v1.4'): + callback = ModelCheckpoint() + Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir) + + def test_tbd_remove_in_v1_2_0(): with pytest.deprecated_call(match='will be removed in v1.2'): checkpoint_cb = ModelCheckpoint(filepath='.')