Closed

Commits (55)
89f284d
Fix some test errors
Mar 23, 2021
80cfbff
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 23, 2021
536c132
checkpoint consolidation
Mar 24, 2021
f172101
Update ddp_spawn.py
shuyingsunshine21 Mar 24, 2021
bf70e43
Update test_metric_result_integration.py
shuyingsunshine21 Mar 24, 2021
ea74906
Update test_results.py
shuyingsunshine21 Mar 24, 2021
a9aae99
Update utils.py
shuyingsunshine21 Mar 24, 2021
70fe5da
Update utils.py
shuyingsunshine21 Mar 24, 2021
0d23d75
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
ca6f98b
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
c5053da
Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkp…
shuyingsunshine21 Mar 24, 2021
9d4a2b8
Update test_results.py
shuyingsunshine21 Mar 24, 2021
7635b4f
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
d64f90c
Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine2…
shuyingsunshine21 Mar 24, 2021
dcdcd29
Revert "Update test_all_gather_grad.py"
shuyingsunshine21 Mar 24, 2021
8651d54
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
15f4b9e
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
250d0aa
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
6c095b2
Revert "Update test_metric_result_integration.py"
shuyingsunshine21 Mar 24, 2021
8222dc9
Revert "Update ddp_spawn.py"
shuyingsunshine21 Mar 24, 2021
3a9fde9
Revert "checkpoint consolidation"
shuyingsunshine21 Mar 24, 2021
7a369f4
Revert "Revert "checkpoint consolidation""
shuyingsunshine21 Mar 24, 2021
b4a0b9e
Revert "Revert "Revert "checkpoint consolidation"""
shuyingsunshine21 Mar 24, 2021
5cf1db1
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
0ce7e05
Revert "Revert "Update ddp_spawn.py""
shuyingsunshine21 Mar 24, 2021
fe9736d
Revert "Revert "Update test_metric_result_integration.py""
shuyingsunshine21 Mar 24, 2021
c314ef6
Revert "Revert "Update test_results.py""
shuyingsunshine21 Mar 24, 2021
c3feda0
Revert "Revert "Update utils.py""
shuyingsunshine21 Mar 24, 2021
c759477
Revert "Revert "Update test_all_gather_grad.py""
shuyingsunshine21 Mar 24, 2021
7a8e540
Merge branch 'master' of https://github.com/shuyingsunshine21/pytorch…
Mar 24, 2021
ab8b849
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
4e67db2
modify distributed environment to make test pass
Mar 24, 2021
4211f0c
consolidate training loop checkpoints v1
Mar 24, 2021
0bf5398
consolidate training loop checkpoints v2
Mar 25, 2021
9db02f8
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 25, 2021
67b6188
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 25, 2021
bbb5f83
consolidate training loop checkpoints v3
Mar 25, 2021
db37add
consolidate training loop checkpoints v4
Mar 25, 2021
51aefb8
consolidate training end model checkpoint
Mar 26, 2021
78ea90b
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 26, 2021
e013b19
remove distributed environment hack
Mar 26, 2021
6f56167
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 26, 2021
d90cd64
consolidate on_train_end only
Mar 26, 2021
707f987
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 29, 2021
5a2a967
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 10, 2021
7a5a925
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 10, 2021
ccd771f
rebase
Apr 10, 2021
70ebc9f
modify
Apr 10, 2021
ab5d5ca
modify
Apr 10, 2021
ddf76c4
add one more unittest for end of training with invalid monitor
Apr 10, 2021
9f83b6f
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 10, 2021
48a34e8
add changelog
Apr 10, 2021
af7806e
rebase
Apr 11, 2021
f9616a3
comments, call _save_last_checkpoint directly for train end
Apr 13, 2021
2a9e882
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 13, 2021
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -236,9 +236,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))


- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))

- Fixed model checkpointing at end of training ([#6671](https://github.com/PyTorchLightning/pytorch-lightning/pull/6671))


- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))

34 changes: 30 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -111,6 +111,9 @@ class ModelCheckpoint(Callback):
This argument has been deprecated in v1.3 and will be removed in v1.5.

Use ``every_n_val_epochs`` instead.
trigger_on_train_end: Whether to save a checkpoint at the end of training.
Turned off by default. When enabled, the model is saved to the file ``last.ckpt``.


Note:
For extra customization, ModelCheckpoint includes the following attributes:
@@ -186,6 +189,7 @@ def __init__(
every_n_train_steps: Optional[int] = None,
every_n_val_epochs: Optional[int] = None,
period: Optional[int] = None,
trigger_on_train_end: bool = False,
):
super().__init__()
self.monitor = monitor
@@ -205,7 +209,7 @@ def __init__(

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, period, trigger_on_train_end)
self.__validate_init_configuration()

def on_pretrain_routine_start(self, trainer, pl_module):
@@ -239,6 +243,22 @@ def on_validation_end(self, trainer, pl_module) -> None:
return
self.save_checkpoint(trainer)

def on_train_end(self, trainer, *args, **kwargs) -> None:
"""
checkpoints can be saved at the end of the trianing
"""
if not self._trigger_on_train_end:
return
# as we advance one step at end of training, we use global_step - 1
# to avoid saving duplicates
trainer.global_step -= 1
if (not self._should_skip_saving_checkpoint(trainer) and trainer.checkpoint_connector.has_trained):
if self.save_last and self.verbose:
rank_zero_info("Saving last checkpoint...")
monitor_candidates = self._monitor_candidates(trainer)
self._save_last_checkpoint(trainer, monitor_candidates)
trainer.global_step += 1

Comment on lines +246 to +261
Contributor:

Suggested change

Existing:

def on_train_end(self, trainer, *args, **kwargs) -> None:
    """
    Checkpoints can be saved at the end of training.
    """
    if not self._trigger_on_train_end:
        return
    # as we advance one step at end of training, we use global_step - 1
    # to avoid saving duplicates
    trainer.global_step -= 1
    if (not self._should_skip_saving_checkpoint(trainer) and trainer.checkpoint_connector.has_trained):
        if self.save_last and self.verbose:
            rank_zero_info("Saving last checkpoint...")
        monitor_candidates = self._monitor_candidates(trainer)
        self._save_last_checkpoint(trainer, monitor_candidates)
    trainer.global_step += 1

Suggested:

def on_train_end(self, trainer, pl_module) -> None:
    """Save a checkpoint at the very end of training.

    This will only save a checkpoint if `save_last` is also enabled,
    as the monitor metrics produced by training or validation steps or end of epochs
    are not guaranteed to be available at this stage.
    """
    if self._should_skip_saving_checkpoint(trainer) or not trainer.checkpoint_connector.has_trained:
        return
    initial_save_last = self.save_last
    if self._save_on_train_end and not self.save_last:
        rank_zero_warn(
            "Requested to save a checkpoint at the end of training but save_last is not set. Temporarily setting save_last=True to save."
        )
        self.save_last = True
    if self.verbose:
        rank_zero_info("Saving last checkpoint...")
    # as we advance one step at end of training, we use global_step - 1
    # to avoid saving duplicates
    trainer.global_step -= 1
    monitor_candidates = self._monitor_candidates(trainer)
    self._save_last_checkpoint(trainer, monitor_candidates)
    trainer.global_step += 1
    self.save_last = initial_save_last

What do you think of this?

Also, what should happen if save_last is not set to True? Should save-on-train-end take precedence and temporarily override it? Should we move the save_last check out of _save_last_checkpoint, so the property has to be checked before we call _save_last_checkpoint?

Contributor Author:

I think the original thought was that save_on_train_end depends on save_last, so it is only enabled when save_last is also set. What you proposed is to always enable it regardless of save_last. Making save_on_train_end an independent trigger makes sense as well.

Contributor:

@carmocca @awaelchli what do you think?

Contributor:

I think I prefer the current implementation, maybe throwing a warning so people know they should set both.
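
For anyone following the thread, here is a minimal usage sketch of the API as it currently stands in this diff (the dirpath and epoch count are placeholders, and the exact semantics may still change depending on how the save_last question above is resolved):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # As implemented in this PR, the end-of-training save only happens when BOTH flags
    # are enabled: trigger_on_train_end gates the new on_train_end hook, and
    # _save_last_checkpoint returns early unless save_last is set.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",
        save_last=True,             # required for last.ckpt to actually be written
        trigger_on_train_end=True,  # new flag added by this PR, off by default
    )
    trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=3)
    # trainer.fit(your_lightning_module)  # after fit() ends, checkpoints/last.ckpt exists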

def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
return {
"monitor": self.monitor,
@@ -286,6 +306,7 @@ def save_checkpoint(self, trainer, unused: Optional = None):

def _should_skip_saving_checkpoint(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerState

return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state != TrainerState.FITTING # don't save anything during non-fit
@@ -357,7 +378,11 @@ def __init_monitor_mode(self, monitor, mode):
self.kth_value, self.mode = mode_dict[mode]

def __init_triggers(
self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
self,
every_n_train_steps: Optional[int],
every_n_val_epochs: Optional[int],
period: Optional[int],
trigger_on_train_end: bool,
) -> None:

# Default to running once after each validation epoch if neither
@@ -379,6 +404,7 @@ def __init_triggers(
self._every_n_val_epochs = period

self._period = self._every_n_val_epochs
self._trigger_on_train_end = trigger_on_train_end

@property
def period(self) -> Optional[int]:
@@ -585,11 +611,10 @@ def _add_backward_monitor_support(self, trainer):

def _validate_monitor_key(self, trainer):
metrics = trainer.logger_connector.callback_metrics

# validate metric
if self.monitor is not None and not self._is_valid_monitor_key(metrics):
m = (
f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics "
f" {list(metrics.keys())}. "
f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
)
@@ -618,6 +643,7 @@ def _monitor_candidates(self, trainer):
return monitor_candidates

def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):

if not self.save_last:
return

12 changes: 1 addition & 11 deletions pytorch_lightning/trainer/training_loop.py
@@ -26,10 +26,9 @@
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.warnings import WarningCache

@@ -112,12 +111,6 @@ def on_train_end(self):
return
self._teardown_already_run = True

# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
self.trainer.global_step -= 1
self.check_checkpoint_callback(should_update=True, is_last=True)
self.trainer.global_step += 1

# hook
self.trainer.call_hook("on_train_end")

@@ -141,9 +134,6 @@ def check_checkpoint_callback(self, should_update, is_last=False):
if should_update and self.trainer.checkpoint_connector.has_trained:
callbacks = self.trainer.checkpoint_callbacks

if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
rank_zero_info("Saving latest checkpoint...")

model = self.trainer.lightning_module

for cb in callbacks:
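My reading of this change (an inference from the two files above, not stated explicitly in the PR): the forced end-of-training save moves out of the TrainLoop and into the ModelCheckpoint callback, reached through the regular hook mechanism. Roughly:

    # Rough call flow after this PR (a sketch, not verbatim library code):
    # TrainLoop.on_train_end()
    #     trainer.call_hook("on_train_end")           # still called by the loop
    #         ModelCheckpoint.on_train_end(trainer)   # new hook added in this PR
    #             self._save_last_checkpoint(...)     # only runs when trigger_on_train_end
    #                                                 # and save_last are both enabled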
4 changes: 2 additions & 2 deletions tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -50,7 +50,7 @@ def test_mc_called(tmpdir):
@mock.patch('torch.save')
@pytest.mark.parametrize(
['epochs', 'val_check_interval', 'expected'],
[(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)],
[(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 6)],
)
def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval: float, expected: int):

@@ -73,7 +73,7 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
(1, 1, 1.0, 1),
(2, 2, 1.0, 2),
(2, 1, 0.25, 4),
(2, 2, 0.3, 7),
(2, 2, 0.3, 6),
])
def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int):

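A possible reading of why the expected save counts drop from 7 to 6 (my arithmetic, not spelled out in the diff): with val_check_interval=0.3 the validation loop runs three times per epoch, and the extra save that used to come from the forced checkpoint check in TrainLoop.on_train_end (removed above) no longer happens.

    # Sketch of the arithmetic behind the updated expectations (an assumption on my part):
    saves_per_epoch = int(1 / 0.3)        # 3 validation-end saves per epoch
    expected_saves = 2 * saves_per_epoch  # 6 for two epochs; the old 7th save came from
                                          # the now-removed end-of-training checkpoint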
108 changes: 91 additions & 17 deletions tests/checkpointing/test_model_checkpoint.py
@@ -452,6 +452,7 @@ def test_model_checkpoint_file_extension(tmpdir):
dirpath=tmpdir,
save_top_k=1,
save_last=True,
trigger_on_train_end=True,
)
trainer = Trainer(
default_root_dir=tmpdir,
@@ -460,8 +461,7 @@
logger=False,
)
trainer.fit(model)

expected = ['epoch=0-step=0.tpkc', 'last.tpkc']
expected = ['last.tpkc']
assert set(expected) == set(os.listdir(tmpdir))


@@ -593,10 +593,19 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):


@pytest.mark.parametrize("period", list(range(4)))
def test_model_checkpoint_period(tmpdir, period: int):
@pytest.mark.parametrize('trigger_on_train_end', [False, True])
@pytest.mark.parametrize('save_last', [False, True])
def test_model_checkpoint_period(tmpdir, period: int, trigger_on_train_end: bool, save_last: bool):
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period)
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
filename='{epoch}',
save_top_k=-1,
save_last=save_last,
period=period,
trigger_on_train_end=trigger_on_train_end,
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
@@ -608,16 +617,25 @@ def test_model_checkpoint_period(tmpdir, period: int):
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
expected = ([f"epoch={e}.ckpt" for e in range(epochs) if (e + 1) % period == 0] if period > 0 else [])
if save_last and (period > 0 or trigger_on_train_end):
expected.append("last.ckpt")
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
@pytest.mark.parametrize('trigger_on_train_end', [False, True])
@pytest.mark.parametrize('save_last', [False, True])
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs, trigger_on_train_end: bool, save_last: bool):
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
dirpath=tmpdir,
filename='{epoch}',
save_top_k=-1,
save_last=save_last,
every_n_val_epochs=every_n_val_epochs,
trigger_on_train_end=trigger_on_train_end,
)
trainer = Trainer(
default_root_dir=tmpdir,
@@ -630,22 +648,31 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
expected = ([f"epoch={e}.ckpt" for e in range(epochs)
if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else [])

if save_last and (every_n_val_epochs > 0 or trigger_on_train_end):
expected.append("last.ckpt")
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
@pytest.mark.parametrize('trigger_on_train_end', [False, True])
@pytest.mark.parametrize('save_last', [False, True])
def test_model_checkpoint_every_n_val_epochs_and_period(
tmpdir, every_n_val_epochs, trigger_on_train_end: bool, save_last: bool
):
""" Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
filename='{epoch}',
save_top_k=-1,
save_last=save_last,
every_n_val_epochs=(2 * every_n_val_epochs),
period=every_n_val_epochs
period=every_n_val_epochs,
trigger_on_train_end=trigger_on_train_end,
)
trainer = Trainer(
default_root_dir=tmpdir,
@@ -658,8 +685,10 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
expected = ([f"epoch={e}.ckpt" for e in range(epochs)
if (e + 1) % every_n_val_epochs == 0] if every_n_val_epochs > 0 else [])
if save_last and (every_n_val_epochs > 0 or trigger_on_train_end):
expected.append("last.ckpt")
assert set(os.listdir(tmpdir)) == set(expected)


@@ -794,26 +823,71 @@ def test_default_checkpoint_behavior(tmpdir):
assert ckpts[0] == 'epoch=2-step=14.ckpt'


@pytest.mark.parametrize('save_last', [False, True])
def test_ckpt_on_train_end_with_invalid_monitor(tmpdir, save_last: bool):
""" Tests that the checkpoints are saved at end of training with invalid monitor."""

model = LogInTwoMethods()
model_cpt = ModelCheckpoint(
filename="{epoch}",
dirpath=tmpdir,
every_n_val_epochs=2,
monitor="invalid", # monitor is invalid, save_last is not set
save_last=save_last,
trigger_on_train_end=True,
)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
progress_bar_refresh_rate=0,
callbacks=[model_cpt],
logger=False,
)
trainer.fit(model)
expected = ['last.ckpt'] if save_last else []
assert set(expected) == set(os.listdir(tmpdir))


@pytest.mark.parametrize('max_epochs', [1, 2])
@pytest.mark.parametrize('every_n_val_epochs', [2, 3])
@pytest.mark.parametrize('should_validate', [True, False])
@pytest.mark.parametrize('save_last', [True, False])
@pytest.mark.parametrize('verbose', [True, False])
@pytest.mark.parametrize('trigger_on_train_end', [False, True])
def test_model_checkpoint_save_last_warning(
tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool
tmpdir,
caplog,
max_epochs: int,
every_n_val_epochs: int,
should_validate: bool,
save_last: bool,
verbose: bool,
trigger_on_train_end: bool,
):
"""Tests 'Saving latest checkpoint...' log"""
"""Tests 'Saving last checkpoint...' log"""
model = LogInTwoMethods()
if not should_validate:
model.validation_step = None
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose)
ckpt = ModelCheckpoint(
monitor='early_stop_on',
dirpath=tmpdir,
every_n_val_epochs=every_n_val_epochs,
save_top_k=0,
save_last=save_last,
verbose=verbose,
trigger_on_train_end=trigger_on_train_end,
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[ckpt],
max_epochs=max_epochs,
)
with caplog.at_level(logging.INFO):
trainer.fit(model)
assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
expected = False
if save_last and verbose and trigger_on_train_end:
expected = (max_epochs % every_n_val_epochs != 0)
assert caplog.messages.count('Saving last checkpoint...') == expected


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
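A note on the expected value in test_model_checkpoint_save_last_warning above (my interpretation, relying on _should_skip_saving_checkpoint also skipping when the current global step was already saved, which is not visible in this hunk): when max_epochs is a multiple of every_n_val_epochs, the final validation epoch already saved at that global step, so the train-end save is skipped as a duplicate and no "Saving last checkpoint..." message is logged; otherwise the train-end hook performs the save and logs it. The assertion could equivalently be written as:

    # Hypothetical restatement of the test's expectation:
    expected = save_last and verbose and trigger_on_train_end and (max_epochs % every_n_val_epochs != 0)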
2 changes: 1 addition & 1 deletion tests/loggers/test_tensorboard.py
@@ -55,7 +55,7 @@ def __init__(self, b1=0.5, b2=0.999):
assert len(yaml_params.keys()) == 2

# verify artifacts
assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1
assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 0

# verify tb logs
event_acc = EventAccumulator(folder_path)
4 changes: 2 additions & 2 deletions tests/trainer/connectors/test_callback_connector.py
@@ -57,7 +57,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):

callback0 = StatefulCallback0()
callback1 = StatefulCallback1()
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True, trigger_on_train_end=True)
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
@@ -67,7 +67,7 @@
)
trainer.fit(model)

ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
ckpt = torch.load(str(tmpdir / "last.ckpt"))
state0 = ckpt["callbacks"][type(callback0)]
state1 = ckpt["callbacks"][type(callback1)]
assert "content0" in state0 and state0["content0"] == 0