
Commit 7d1288b

awaelchli, carmocca, and Jeff Yang authored and committed
deprecate passing ModelCheckpoint instance to Trainer(checkpoint_callback=...) (#4336)
* first attempt
* update tests
* support multiple
* test bugfix
* changelog
* pep
* pep
* import order
* import
* improve test for resuming
* test
* update test
* add references test

Co-authored-by: Carlos Mocholí <[email protected]>

* docstring suggestion deprecation

Co-authored-by: Jeff Yang <[email protected]>

* paramref

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jeff Yang <[email protected]>

(cherry picked from commit d1234c5)
1 parent e94d48c commit 7d1288b

File tree

9 files changed: +167 -27 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
 
 
+- Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336))
+
 ### Removed
 
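For reference, a minimal sketch of what this deprecation means for user code (illustrative, not part of the commit; the dirpath value is a placeholder):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # deprecated since v1.1 (emits a DeprecationWarning, removal planned for v1.4):
    # trainer = Trainer(checkpoint_callback=ModelCheckpoint(dirpath="my_checkpoints/"))

    # recommended: pass the instance through `callbacks`; the argument stays a plain bool
    checkpoint = ModelCheckpoint(dirpath="my_checkpoints/")
    trainer = Trainer(callbacks=[checkpoint])

    # the bool still toggles default checkpointing on or off
    trainer_default = Trainer(checkpoint_callback=True)    # a default ModelCheckpoint is configured
    trainer_disabled = Trainer(checkpoint_callback=False)  # no checkpointing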

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 24 additions & 12 deletions
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+
+from typing import Union, Optional
+
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -44,25 +47,31 @@ def on_trainer_init(
         # configure checkpoint callback
         # it is important that this is the last callback to run
         # pass through the required args to figure out defaults
-        checkpoint_callback = self.init_default_checkpoint_callback(checkpoint_callback)
-        if checkpoint_callback:
-            self.trainer.callbacks.append(checkpoint_callback)
-
-        # TODO refactor codebase (tests) to not directly reach into these callbacks
-        self.trainer.checkpoint_callback = checkpoint_callback
+        self.configure_checkpoint_callbacks(checkpoint_callback)
 
         # init progress bar
         self.trainer._progress_bar_callback = self.configure_progress_bar(
             progress_bar_refresh_rate, process_position
         )
 
-    def init_default_checkpoint_callback(self, checkpoint_callback):
-        if checkpoint_callback is True:
-            checkpoint_callback = ModelCheckpoint(dirpath=None, filename=None)
-        elif checkpoint_callback is False:
-            checkpoint_callback = None
+    def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]):
+        if isinstance(checkpoint_callback, ModelCheckpoint):
+            # TODO: deprecated, remove this block in v1.4.0
+            rank_zero_warn(
+                "Passing a ModelCheckpoint instance to Trainer(checkpoint_callbacks=...)"
+                " is deprecated since v1.1 and will no longer be supported in v1.4.",
+                DeprecationWarning
+            )
+            self.trainer.callbacks.append(checkpoint_callback)
+
+        if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
+            raise MisconfigurationException(
+                "Trainer was configured with checkpoint_callback=False but found ModelCheckpoint"
+                " in callbacks list."
+            )
 
-        return checkpoint_callback
+        if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
+            self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None))
 
     def configure_progress_bar(self, refresh_rate=1, process_position=0):
         progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
@@ -83,3 +92,6 @@ def configure_progress_bar(self, refresh_rate=1, process_position=0):
             progress_bar_callback = None
 
         return progress_bar_callback
+
+    def _trainer_has_checkpoint_callbacks(self):
+        return len(self.trainer.checkpoint_callbacks) > 0
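
The observable behaviour of the new configure_checkpoint_callbacks logic, sketched as a small illustration (assumes this branch; the full matrix is exercised in tests/checkpointing/test_model_checkpoint.py further down):

    import pytest
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # passing an instance still works, but is routed into `callbacks` and warns
    with pytest.warns(DeprecationWarning, match="will no longer be supported in v1.4"):
        Trainer(checkpoint_callback=ModelCheckpoint())

    # checkpoint_callback=False contradicts a ModelCheckpoint passed via `callbacks`
    with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
        Trainer(checkpoint_callback=False, callbacks=[ModelCheckpoint()])

    # checkpoint_callback=True appends a default ModelCheckpoint only when none was given
    assert len(Trainer(checkpoint_callback=True).checkpoint_callbacks) == 1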

pytorch_lightning/trainer/properties.py

Lines changed: 16 additions & 1 deletion
@@ -17,7 +17,7 @@
 from argparse import ArgumentParser, Namespace
 from typing import List, Optional, Union, Type, TypeVar
 
-from pytorch_lightning.callbacks import ProgressBarBase
+from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
@@ -46,6 +46,7 @@ class TrainerProperties(ABC):
     _weights_save_path: str
     model_connector: ModelConnector
     checkpoint_connector: CheckpointConnector
+    callbacks: List[Callback]
 
     @property
     def use_amp(self) -> bool:
@@ -187,6 +188,20 @@ def weights_save_path(self) -> str:
             return os.path.normpath(self._weights_save_path)
         return self._weights_save_path
 
+    @property
+    def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
+        """
+        The first checkpoint callback in the Trainer.callbacks list, or ``None`` if
+        no checkpoint callbacks exist.
+        """
+        callbacks = self.checkpoint_callbacks
+        return callbacks[0] if len(callbacks) > 0 else None
+
+    @property
+    def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
+        """ A list of all instances of ModelCheckpoint found in the Trainer.callbacks list. """
+        return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
+
     def save_checkpoint(self, filepath, weights_only: bool = False):
         self.checkpoint_connector.save_checkpoint(filepath, weights_only)
 
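A short usage sketch of the two new properties (illustrative; the dirpath values and the save_last flag are placeholders):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    ckpt_a = ModelCheckpoint(dirpath="ckpts_a/")
    ckpt_b = ModelCheckpoint(dirpath="ckpts_b/", save_last=True)
    trainer = Trainer(callbacks=[ckpt_a, ckpt_b])

    # every ModelCheckpoint found in Trainer.callbacks, in order
    assert trainer.checkpoint_callbacks == [ckpt_a, ckpt_b]
    # the first one, or None when no checkpoint callback exists; kept for backward compatibility
    assert trainer.checkpoint_callback is ckpt_a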

pytorch_lightning/trainer/trainer.py

Lines changed: 7 additions & 3 deletions
@@ -85,7 +85,7 @@ class Trainer(
     def __init__(
         self,
         logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
-        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
+        checkpoint_callback: bool = True,
         callbacks: Optional[List[Callback]] = None,
         default_root_dir: Optional[str] = None,
         gradient_clip_val: float = 0,
@@ -169,7 +169,12 @@ def __init__(
 
             callbacks: Add a list of callbacks.
 
-            checkpoint_callback: Callback for checkpointing.
+            checkpoint_callback: If ``True``, enable checkpointing.
+                It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
+                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.
+
+                .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
+                    v1.1.0 and will be unsupported from v1.4.0.
 
             check_val_every_n_epoch: Check val every n train epochs.
 
@@ -297,7 +302,6 @@ def __init__(
 
         # init callbacks
        # Declare attributes to be set in callback_connector on_trainer_init
-        self.checkpoint_callback: Union[ModelCheckpoint, bool] = checkpoint_callback
         self.callback_connector.on_trainer_init(
             callbacks,
             checkpoint_callback,
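
A sketch of the documented checkpoint_callback=True semantics (illustrative; the monitor key is a placeholder): a default ModelCheckpoint is only configured when the user has not already supplied one via callbacks.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    user_ckpt = ModelCheckpoint(monitor="val_loss")
    trainer = Trainer(checkpoint_callback=True, callbacks=[user_ckpt])

    # no extra default callback is appended because a user-defined one is already present
    assert trainer.checkpoint_callbacks == [user_ckpt]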

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 0 additions & 2 deletions
@@ -144,7 +144,6 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
     trainer.weights_summary = None  # not needed before full run
     trainer.logger = DummyLogger()
     trainer.callbacks = []  # not needed before full run
-    trainer.checkpoint_callback = False  # required for saving
     trainer.limit_train_batches = 1.0
     trainer.optimizers, trainer.schedulers = [], []  # required for saving
     trainer.model = model  # required for saving
@@ -157,7 +156,6 @@ def __scale_batch_restore_params(trainer):
     trainer.weights_summary = trainer.__dumped_params['weights_summary']
     trainer.logger = trainer.__dumped_params['logger']
     trainer.callbacks = trainer.__dumped_params['callbacks']
-    trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback']
     trainer.auto_scale_batch_size = trainer.__dumped_params['auto_scale_batch_size']
     trainer.limit_train_batches = trainer.__dumped_params['limit_train_batches']
     trainer.model = trainer.__dumped_params['model']

pytorch_lightning/tuner/lr_finder.py

Lines changed: 0 additions & 4 deletions
@@ -155,9 +155,6 @@ def lr_find(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()
 
-    # Disable standard checkpoint & early stopping
-    trainer.checkpoint_callback = False
-
     # Required for saving the model
     trainer.optimizers, trainer.schedulers = [], [],
     trainer.model = model
@@ -212,7 +209,6 @@ def __lr_finder_restore_params(trainer, model):
     trainer.logger = trainer.__dumped_params['logger']
     trainer.callbacks = trainer.__dumped_params['callbacks']
     trainer.max_steps = trainer.__dumped_params['max_steps']
-    trainer.checkpoint_callback = trainer.__dumped_params['checkpoint_callback']
     model.configure_optimizers = trainer.__dumped_params['configure_optimizers']
     del trainer.__dumped_params
 

tests/checkpointing/test_model_checkpoint.py

Lines changed: 40 additions & 0 deletions
@@ -746,3 +746,43 @@ def test_filepath_decomposition_dirpath_filename(tmpdir, filepath, dirpath, file
 
     assert mc_cb.dirpath == dirpath
     assert mc_cb.filename == filename
+
+
+def test_configure_model_checkpoint(tmpdir):
+    """ Test all valid and invalid ways a checkpoint callback can be passed to the Trainer. """
+    kwargs = dict(default_root_dir=tmpdir)
+    callback1 = ModelCheckpoint()
+    callback2 = ModelCheckpoint()
+
+    # no callbacks
+    trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs)
+    assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks)
+    assert trainer.checkpoint_callback is None
+
+    # default configuration
+    trainer = Trainer(checkpoint_callback=True, callbacks=[], **kwargs)
+    assert len([c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)]) == 1
+    assert isinstance(trainer.checkpoint_callback, ModelCheckpoint)
+
+    # custom callback passed to callbacks list, checkpoint_callback=True is ignored
+    trainer = Trainer(checkpoint_callback=True, callbacks=[callback1], **kwargs)
+    assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
+    assert trainer.checkpoint_callback == callback1
+
+    # multiple checkpoint callbacks
+    trainer = Trainer(callbacks=[callback1, callback2], **kwargs)
+    assert trainer.checkpoint_callback == callback1
+    assert trainer.checkpoint_callbacks == [callback1, callback2]
+
+    with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.4'):
+        trainer = Trainer(checkpoint_callback=callback1, callbacks=[], **kwargs)
+    assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
+    assert trainer.checkpoint_callback == callback1
+
+    with pytest.warns(DeprecationWarning, match="will no longer be supported in v1.4"):
+        trainer = Trainer(checkpoint_callback=callback1, callbacks=[callback2], **kwargs)
+    assert trainer.checkpoint_callback == callback2
+    assert trainer.checkpoint_callbacks == [callback2, callback1]
+
+    with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
+        Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)

tests/models/test_restore.py

Lines changed: 72 additions & 5 deletions
@@ -15,6 +15,7 @@
 import logging as log
 import os
 import pickle
+from copy import deepcopy
 
 import cloudpickle
 import pytest
@@ -24,7 +25,7 @@
 
 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer, LightningModule, Callback
+from pytorch_lightning import Trainer, LightningModule, Callback, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from tests.base import EvalModelTemplate, GenericEvalModelTemplate, TrialMNIST
 
@@ -51,24 +52,90 @@ def on_train_end(self, trainer, pl_module):
         self._check_properties(trainer, pl_module)
 
 
-def test_resume_from_checkpoint(tmpdir):
+def test_model_properties_resume_from_checkpoint(tmpdir):
     """ Test that properties like `current_epoch` and `global_step`
     in model and trainer are always the same. """
     model = EvalModelTemplate()
     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
     trainer_args = dict(
         default_root_dir=tmpdir,
-        max_epochs=2,
+        max_epochs=1,
         logger=False,
-        checkpoint_callback=checkpoint_callback,
-        callbacks=[ModelTrainerPropertyParity()]  # this performs the assertions
+        callbacks=[checkpoint_callback, ModelTrainerPropertyParity()]  # this performs the assertions
     )
     trainer = Trainer(**trainer_args)
     trainer.fit(model)
+
+    trainer_args.update(max_epochs=2)
     trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt"))
     trainer.fit(model)
 
 
+class CaptureCallbacksBeforeTraining(Callback):
+    callbacks = []
+
+    def on_train_start(self, trainer, pl_module):
+        self.callbacks = deepcopy(trainer.callbacks)
+
+
+def test_callbacks_state_resume_from_checkpoint(tmpdir):
+    """ Test that resuming from a checkpoint restores callbacks that persist state. """
+    model = EvalModelTemplate()
+    callback_capture = CaptureCallbacksBeforeTraining()
+
+    def get_trainer_args():
+        checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
+        trainer_args = dict(
+            default_root_dir=tmpdir,
+            max_steps=1,
+            logger=False,
+            callbacks=[
+                checkpoint,
+                callback_capture,
+            ]
+        )
+        assert checkpoint.best_model_path == ""
+        assert checkpoint.best_model_score == 0
+        return trainer_args
+
+    # initial training
+    trainer = Trainer(**get_trainer_args())
+    trainer.fit(model)
+    callbacks_before_resume = deepcopy(trainer.callbacks)
+
+    # resumed training
+    trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt"))
+    trainer.fit(model)
+
+    assert len(callbacks_before_resume) == len(callback_capture.callbacks)
+
+    for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
+        if isinstance(before, ModelCheckpoint):
+            assert before.best_model_path == after.best_model_path
+            assert before.best_model_score == after.best_model_score
+
+
+def test_callbacks_references_resume_from_checkpoint(tmpdir):
+    """ Test that resuming from a checkpoint sets references as expected. """
+    model = EvalModelTemplate()
+    args = {'default_root_dir': tmpdir, 'max_steps': 1, 'logger': False}
+
+    # initial training
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
+    trainer = Trainer(**args, callbacks=[checkpoint])
+    assert checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
+    trainer.fit(model)
+
+    # resumed training
+    new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
+    # pass in a new checkpoint object, which should take
+    # precedence over the one in the last.ckpt file
+    trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt"))
+    assert checkpoint is not new_checkpoint
+    assert new_checkpoint is trainer.callbacks[0] is trainer.checkpoint_callback
+    trainer.fit(model)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_running_test_pretrained_model_distrib_dp(tmpdir):
     """Verify `test()` on pretrained model."""

tests/test_deprecated.py

Lines changed: 6 additions & 0 deletions
@@ -15,6 +15,12 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
+def test_tbd_remove_in_v1_4_0(tmpdir):
+    with pytest.deprecated_call(match='will no longer be supported in v1.4'):
+        callback = ModelCheckpoint()
+        Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir)
+
+
 def test_tbd_remove_in_v1_2_0():
     with pytest.deprecated_call(match='will be removed in v1.2'):
         checkpoint_cb = ModelCheckpoint(filepath='.')
