Skip to content
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))


- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))


### Removed

- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
Expand Down
47 changes: 23 additions & 24 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@

import pytorch_lightning
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities import (
_OMEGACONF_AVAILABLE,
DeviceType,
rank_zero_deprecation,
rank_zero_info,
rank_zero_warn,
)
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

Expand All @@ -45,7 +50,7 @@ def hpc_resume_path(self) -> Optional[str]:
dir_path_hpc = str(self.trainer.weights_save_path)
max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_version is not None:
return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")

def resume_start(self) -> None:
"""
Expand Down Expand Up @@ -129,6 +134,10 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
# hook: give user access to checkpoint if needed.
model.on_load_checkpoint(checkpoint)

# call hpc specific hook
if self.hpc_resume_path is not None:
model.on_hpc_load(self._loaded_checkpoint)

# restore model state_dict
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

Expand Down Expand Up @@ -248,6 +257,7 @@ def restore_lr_schedulers(self) -> None:
# ----------------------------------
# PRIVATE OPS
# ----------------------------------

def hpc_save(self, folderpath: str, logger):
# make sure the checkpoint folder exists
folderpath = str(folderpath) # because the tests pass a path object
Expand Down Expand Up @@ -365,29 +375,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

return checkpoint

def hpc_load(self, checkpoint_path: str, on_gpu: bool):
"""
Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
All restored states are listed in return value description of `dump_checkpoint`.
def hpc_load(self, checkpoint_path: str) -> None:
"""
Attempts to restore the full training and model state from a HPC checkpoint file.

# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# acquire the model
model = self.trainer.lightning_module

# restore model and datamodule state
self.restore_model_state(model, checkpoint)

if self.trainer.root_gpu is not None:
model.cuda(self.trainer.root_gpu)

# restore training state
self.restore_training_state(checkpoint)

# call hpc specific hook
model.on_hpc_load(checkpoint)
.. deprecated::v1.4
Will be removed in v1.6. Use :meth:`restore` instead.
"""
rank_zero_deprecation(
"`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6."
" Use `CheckpointConnector.restore()` instead."
)
self.restore(checkpoint_path)

def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.
Expand Down
13 changes: 13 additions & 0 deletions tests/deprecated_api/test_remove_1-4.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,16 @@ def training_step(self, batch, batch_idx):

with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
trainer.fit(TestModel())


def test_v1_4_0_deprecated_hpc_load(tmpdir):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
)
trainer.fit(model)
trainer.checkpoint_connector.hpc_save(tmpdir, trainer.logger)
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(str(tmpdir))
with pytest.deprecated_call(match=r"`CheckpointConnector.hpc_load\(\)` was deprecated in v1.4"):
trainer.checkpoint_connector.hpc_load(checkpoint_path)
2 changes: 1 addition & 1 deletion tests/helpers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def run_model_test(
trainer.checkpoint_connector.hpc_save(save_dir, logger)
# test HPC loading
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
trainer.checkpoint_connector.restore(checkpoint_path)


@torch.no_grad()
Expand Down
2 changes: 1 addition & 1 deletion tests/models/data/horovod/train_default_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def training_epoch_end(self, outputs) -> None:
trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
# test HPC loading
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
trainer.checkpoint_connector.restore(checkpoint_path)

if on_gpu:
trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)
Expand Down
13 changes: 13 additions & 0 deletions tests/trainer/connectors/test_callback_connector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import Mock

Expand Down
155 changes: 155 additions & 0 deletions tests/trainer/connectors/test_checkpoint_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import Mock

import torch

from pytorch_lightning import Trainer
from tests.helpers import BoringModel


class HPCHookdedModel(BoringModel):

def __init__(self):
super().__init__()
self.hpc_save_called = 0
self.hpc_load_called = 0

def on_hpc_save(self, checkpoint):
assert "state_dict" in checkpoint
self.hpc_save_called += 1

def on_hpc_load(self, checkpoint):
assert "state_dict" in checkpoint
self.hpc_load_called += 1


def test_hpc_hook_calls(tmpdir):
model = HPCHookdedModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
checkpoint_callback=False,
logger=False,
)
trainer.fit(model)
connector = trainer.checkpoint_connector
connector.hpc_save(tmpdir, logger=Mock())
assert model.hpc_save_called == 1
assert model.hpc_load_called == 0

# new training run, restore from hpc checkpoint file automatically
assert set(os.listdir(tmpdir)) == {"hpc_ckpt_1.ckpt"}
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
checkpoint_callback=False,
logger=False,
)
trainer.fit(model)
assert model.hpc_save_called == 1
assert model.hpc_load_called == 1


def test_preloaded_checkpoint_lifecycle(tmpdir):
""" Tests that the preloaded checkpoint contents gets cleared from memory when it is not required anymore. """
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
)
trainer.fit(model)

connector = trainer.checkpoint_connector

assert not trainer.resume_from_checkpoint
assert not connector.resume_checkpoint_path
assert not connector._loaded_checkpoint

connector.resume_start()
assert not connector.resume_checkpoint_path
assert not connector._loaded_checkpoint
connector.resume_end()
assert not connector.resume_checkpoint_path
assert not connector._loaded_checkpoint

ckpt_path = trainer.checkpoint_callback.best_model_path
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
connector = trainer.checkpoint_connector
connector.resume_start()
assert connector.resume_checkpoint_path == ckpt_path
assert connector._loaded_checkpoint
assert isinstance(connector._loaded_checkpoint, dict)
connector.resume_end()
assert not connector.resume_checkpoint_path
assert not connector._loaded_checkpoint


def test_hpc_restore_attempt(tmpdir):
""" Test that restore() attempts to restore the hpc_ckpt with highest priority. """
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
checkpoint_callback=False,
logger=False,
)
trainer.fit(model)

hpc_ckpt_path = tmpdir / "hpc_ckpt_3.ckpt"
trainer.save_checkpoint(hpc_ckpt_path)
assert os.listdir(tmpdir) == ["hpc_ckpt_3.ckpt"]

# set weights to zero
for param in model.parameters():
torch.nn.init.constant_(param, 0)

# case 1: restore hpc first, no explicit resume path provided
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=2,
checkpoint_callback=False,
logger=False,
)
trainer.fit(model)

for param in model.parameters():
assert param.abs().sum() > 0
torch.nn.init.constant_(param, 0)

# case 2: explicit resume path provided, restore hpc anyway
trainer = Trainer(default_root_dir=tmpdir, max_steps=3, resume_from_checkpoint="not existing")
trainer.fit(model)

for param in model.parameters():
assert param.abs().sum() > 0


def test_hpc_max_ckpt_version(tmpdir):
""" Test that the CheckpointConnector is able to find the hpc checkpoint file with the highest version. """
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
)
trainer.fit(model)
trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt")
trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt")
trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")

assert trainer.checkpoint_connector.hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33
assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None