36 commits
172fe76  Refactor unused argument - model  (tarepan, Jan 5, 2021)
7b1605c  Refactor method discription  (tarepan, Jan 5, 2021)
9e1344b  Refactor method name with its actual functionality  (tarepan, Jan 5, 2021)
5a6d275  Refactor unused argument `on_gpu`  (tarepan, Jan 5, 2021)
883d95a  Add intent commentary  (tarepan, Jan 5, 2021)
97e7eb1  Refactor common restore  (tarepan, Jan 5, 2021)
51ae049  Refactor too much function nest  (tarepan, Jan 5, 2021)
c79d0b0  Refactor too much function nest  (tarepan, Jan 5, 2021)
5c85b21  Refactor function name  (tarepan, Jan 5, 2021)
9a9c217  Fix missing argument  (tarepan, Jan 5, 2021)
9adf999  Refactor hpc load with commons  (tarepan, Jan 5, 2021)
7b6272e  Fix pep8  (tarepan, Jan 5, 2021)
0db7f62  Refactor checkpoint test  (tarepan, Jan 5, 2021)
4bfb232  Refactor for easy test  (tarepan, Jan 5, 2021)
6d52468  Fix pep8  (tarepan, Jan 5, 2021)
2e6b277  Fix trainer setup outside the fit  (tarepan, Jan 5, 2021)
fa03af4  Refactor big method with responsibility  (tarepan, Jan 5, 2021)
c166876  Fix pip8  (tarepan, Jan 5, 2021)
430b5f8  Refactor for diff alignment  (tarepan, Jan 5, 2021)
d9063be  Link upstream issue: #5370  (tarepan, Jan 6, 2021)
5387267  Fix pep8  (tarepan, Jan 6, 2021)
e894f98  Fix type description without functional change  (tarepan, Jan 7, 2021)
87023eb  Merge branch 'release/1.2-dev' into refactor/checkpoint  (tarepan, Jan 10, 2021)
abbfb05  Merge remote-tracking branch 'upstream/release/1.2-dev' into refactor…  (tarepan, Jan 19, 2021)
ae13e0c  Refactor with_gpu type with simple typing  (tarepan, Jan 19, 2021)
1f1d817  Refactor comment format  (tarepan, Jan 19, 2021)
86f9c78  Fix isort  (tarepan, Jan 19, 2021)
c020cb4  Fix pep8  (tarepan, Jan 19, 2021)
2cf5488  Refactor too much type guard  (tarepan, Jan 19, 2021)
0c1f1f5  Merge branch 'release/1.2-dev' into refactor/checkpoint  (tarepan, Jan 24, 2021)
d08bfca  chlog  (Borda, Jan 29, 2021)
5497d3c  Merge branch 'release/1.2-dev' into refactor/checkpoint  (awaelchli, Jan 31, 2021)
54296c5  amp  (Borda, Feb 1, 2021)
740a3ea  Merge branch 'release/1.2-dev' into refactor/checkpoint  (Borda, Feb 1, 2021)
59b1bba  Merge branch 'release/1.2-dev' into refactor/checkpoint  (tarepan, Feb 4, 2021)
d10fba8  Fix yapf  (tarepan, Feb 4, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -110,6 +110,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))


- Refactored `hpc_load` and entangled logics in `CheckpointConnector` ([#5371](https://github.com/PyTorchLightning/pytorch-lightning/pull/5371))


- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))


146 changes: 76 additions & 70 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -15,7 +15,7 @@
import os
import re
from pathlib import Path
from typing import Optional, Union
from typing import Any, Dict, Optional, Union

import torch

@@ -49,28 +49,16 @@ def __init__(self, trainer):
# used to validate checkpointing logic
self.has_trained = False

def restore_weights(self) -> None:
"""
Attempt to restore a checkpoint (e.g. weights) in this priority:
1. from HPC weights
2. from `resume_from_checkpoint` file
3. don't restore
def attempt_to_restore(self) -> None:
"""Attempt to restore model/training states.
"""
# clear cache before restore
if self.trainer._device_type == DeviceType.GPU:
torch.cuda.empty_cache()

# 1. Attempt to restore states from HPC checkpoint
dir_path_hpc = str(self.trainer.weights_save_path)
max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_suffix is not None:
checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU)
rank_zero_info(f'restored hpc model from: {checkpoint_path}')

# 2. Attempt to restore states from `resume_from_checkpoint` file
elif self.trainer.resume_from_checkpoint is not None:
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)
# attempt to restore states
model: LightningModule = self.trainer.get_model()
self.attempt_to_apply_checkpoint(model)

# wait for all to catch up
self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
@@ -79,53 +67,95 @@ def restore_weights(self) -> None:
if self.trainer._device_type == DeviceType.GPU:
torch.cuda.empty_cache()

def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
"""
Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
All restored states are listed in return value description of `dump_checkpoint`.
def attempt_to_apply_checkpoint(self, model: LightningModule) -> bool:
"""Attempt to apply checkpoint states to model/training with priority.

Priority:
1. from HPC weights
2. from `resume_from_checkpoint` file
3. don't apply

Returns:
True if applied else False
"""
# Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
fs = get_filesystem(checkpoint_path)
if not fs.exists(checkpoint_path):
rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
return False
# Design Note:
# `attempt_to_restore` has responsibility to whole state restoration flow (e.g. OOM, parallel processing).
# This method has responsibility to applying/assigning state value from nullable checkpoint.

# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
restored: bool = False

# acquire the model
model = self.trainer.get_model()
# 1. Attempt to apply HPC checkpoint.
dir_path_hpc = str(self.trainer.weights_save_path)
max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_suffix is not None:
checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
checkpoint = self.restore_states(model, checkpoint_path, self.trainer._device_type == DeviceType.GPU)
model.on_hpc_load(checkpoint)
restored = True
rank_zero_info(f'restored hpc model from: {checkpoint_path}')

# restore model and datamodule state
self.restore_model_state(model, checkpoint)
# 2. Attempt to apply `resume_from_checkpoint` file.
elif self.trainer.resume_from_checkpoint is not None:
adress_checkpoint: str = self.trainer.resume_from_checkpoint
if get_filesystem(adress_checkpoint).exists(adress_checkpoint):
self.restore_states(model, adress_checkpoint, self.trainer._device_type == DeviceType.GPU)
restored = True
rank_zero_info(f"States restored from the checkpoint file at {adress_checkpoint}")
else:
rank_zero_warn(f"checkpoint file at {adress_checkpoint} does not exist.")

if on_gpu:
model.cuda(self.trainer.root_gpu)
# 3. Do not apply, start from scratch.
else:
rank_zero_info("Start from scratch.")

# restore training state
self.restore_training_state(checkpoint)
return restored

rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
return True
def restore_states(
self,
model: LightningModule,
checkpoint_path: str,
on_gpu: bool,
) -> Dict[str, Any]:
"""Restore all states from checkpoint in the specified path.

def restore_model_state(self, model: LightningModule, checkpoint) -> None:
"""
Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
All restored states are listed in return value description of `dump_checkpoint`.

Args:
on_gpu: Whether trainer is on GPU or not.
"""
# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
checkpoint: Dict[str, Any] = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# restore datamodule states
# restore states
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_load_checkpoint(checkpoint)
self.restore_model_state(checkpoint, model, on_gpu)
self.restore_training_state(checkpoint)

return checkpoint

def restore_model_state(
self,
checkpoint: Dict[str, Any],
model: LightningModule,
on_gpu: bool,
) -> None:
"""Restore model state.
"""
# hook: give user access to checkpoint if needed.
model.on_load_checkpoint(checkpoint)

# restore model state_dict
model.load_state_dict(checkpoint['state_dict'])

def restore_training_state(self, checkpoint):
"""
Restore trainer state.
# moves the model to the GPU
if on_gpu:
model.cuda(self.trainer.root_gpu)

def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
"""Restore trainer state.

Model will get its change to update
:param checkpoint:
:return:
@@ -329,30 +359,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

return checkpoint

def hpc_load(self, checkpoint_path: str, on_gpu: bool):
"""
Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
All restored states are listed in return value description of `dump_checkpoint`.
"""

# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# acquire the model
model = self.trainer.get_model()

# restore model and datamodule state
self.restore_model_state(model, checkpoint)

if self.trainer.root_gpu is not None:
model.cuda(self.trainer.root_gpu)

# restore training state
self.restore_training_state(checkpoint)

# call hpc specific hook
model.on_hpc_load(checkpoint)

def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.

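
In short, the old `restore_weights` / `restore` / `hpc_load` trio becomes `attempt_to_restore` -> `attempt_to_apply_checkpoint` -> `restore_states`, with `on_hpc_load` now invoked by the caller. The core of `attempt_to_apply_checkpoint` is a three-way priority: the newest `hpc_ckpt_<n>.ckpt` under `weights_save_path`, then `resume_from_checkpoint` if that file exists, otherwise start from scratch. Below is a minimal, self-contained sketch of that selection logic only; it performs no state loading, and the helper names are illustrative rather than the connector's actual API.

import os
import re
from pathlib import Path
from typing import Optional


def max_hpc_suffix(dir_path: Path, name_key: str = "hpc_ckpt_") -> Optional[int]:
    """Return the largest <n> among files named '<name_key><n>.ckpt' in `dir_path`, or None."""
    if not dir_path.is_dir():
        return None
    suffixes = []
    for name in os.listdir(dir_path):
        match = re.match(rf"{name_key}(\d+)\.ckpt$", name)
        if match:
            suffixes.append(int(match.group(1)))
    return max(suffixes) if suffixes else None


def select_checkpoint(weights_save_path: Path, resume_from_checkpoint: Optional[str]) -> Optional[str]:
    """Pick a checkpoint path with the same priority as `attempt_to_apply_checkpoint` (sketch)."""
    # 1. newest HPC checkpoint in the save directory
    max_suffix = max_hpc_suffix(weights_save_path)
    if max_suffix is not None:
        return str(weights_save_path / f"hpc_ckpt_{max_suffix}.ckpt")
    # 2. `resume_from_checkpoint`, but only if the file actually exists
    if resume_from_checkpoint is not None and Path(resume_from_checkpoint).exists():
        return resume_from_checkpoint
    # 3. nothing to apply: start from scratch
    return None
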
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -158,8 +158,8 @@ def setup_training(self):
if self.trainer.is_global_zero:
ref_model.summarize(mode=self.trainer.weights_summary)

# restore training state and model weights before hpc is called
self.trainer.checkpoint_connector.restore_weights()
# restore model/training states before hpc is called
self.trainer.checkpoint_connector.attempt_to_restore()

# on pretrain routine end
self.trainer.on_pretrain_routine_end(ref_model)
4 changes: 2 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -113,9 +113,9 @@ def scale_batch_size(trainer,
garbage_collection_cuda()
log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

# Restore initial state of model
# Restore initial state of model from temporary checkpoint, which is deleted after restore.
if trainer.is_global_zero:
trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
trainer.checkpoint_connector.restore_states(model, str(save_path), trainer._device_type == DeviceType.GPU)
fs = get_filesystem(str(save_path))
if fs.exists(save_path):
fs.rm(save_path)
4 changes: 2 additions & 2 deletions pytorch_lightning/tuner/lr_finder.py
@@ -190,9 +190,9 @@ def lr_find(
'loss': trainer.callbacks[0].losses})
lr_finder._total_batch_idx = trainer.total_batch_idx # for debug purpose

# Reset model state
# Restore initial state of model from temporary checkpoint, which is deleted after restore.
if trainer.is_global_zero:
trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
trainer.checkpoint_connector.restore_states(model, str(save_path), trainer._device_type == DeviceType.GPU)
fs = get_filesystem(str(save_path))
if fs.exists(save_path):
fs.rm(save_path)
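
Both tuner changes above follow the same pattern: dump a temporary checkpoint before the search, then hand the model back to `restore_states` and delete the file. The outline below condenses that pattern; the temporary file name and the `search_fn` callable are placeholders, and the imports assume the 1.2-dev helpers already used in these files (`get_filesystem`, `DeviceType`).

from pathlib import Path

from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.cloud_io import get_filesystem


def run_with_temporary_checkpoint(trainer: Trainer, model: LightningModule, search_fn) -> None:
    """Save a throwaway checkpoint, run a search, then restore the model and clean up (sketch)."""
    save_path = Path(trainer.default_root_dir) / "temp_tuner_model.ckpt"  # placeholder file name
    trainer.save_checkpoint(str(save_path))

    search_fn()  # stands in for the batch-size / learning-rate search body

    # Restore initial state of model from temporary checkpoint, which is deleted after restore.
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore_states(
            model, str(save_path), trainer._device_type == DeviceType.GPU
        )
        fs = get_filesystem(str(save_path))
        if fs.exists(str(save_path)):
            fs.rm(str(save_path))
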
6 changes: 4 additions & 2 deletions tests/base/develop_pipelines.py
@@ -14,6 +14,7 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import DistributedType
from tests.base import BoringModel
@@ -50,7 +51,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50
trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()


def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
def run_model_test(trainer_options, model: LightningModule, on_gpu: bool = True, version=None,
with_hpc: bool = True, min_acc: float = 0.25):

reset_seed()
@@ -93,7 +94,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
trainer.checkpoint_connector.hpc_save(save_dir, logger)
# test HPC loading
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
checkpoint = trainer.checkpoint_connector.restore_states(model, checkpoint_path, trainer.root_gpu)
trainer.get_model().on_hpc_load(checkpoint)


def run_prediction(trained_model, dataloader, dp=False, min_acc=0.25):
3 changes: 2 additions & 1 deletion tests/models/data/horovod/train_default_model.py
@@ -79,7 +79,8 @@ def run_test_from_config(trainer_options):
trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
# test HPC loading
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=args.on_gpu)
checkpoint = trainer.checkpoint_connector.restore_states(model, checkpoint_path, trainer.root_gpu)
trainer.get_model().on_hpc_load(checkpoint)

if args.on_gpu:
trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)
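
With `hpc_load` removed, both test helpers above rebuild the HPC round-trip from the remaining pieces: `hpc_save`, `get_max_ckpt_path_from_folder`, `restore_states`, and the model's `on_hpc_load` hook. A minimal sketch of that sequence, assuming a fitted `trainer`/`model` pair as in the tests; the wrapper function itself is illustrative, not part of the PR.

def hpc_save_and_load(trainer, model, save_dir: str) -> None:
    """Round-trip an HPC checkpoint the way the updated tests do (illustrative helper)."""
    # write an hpc_ckpt_<n>.ckpt file into save_dir
    trainer.checkpoint_connector.hpc_save(save_dir, trainer.logger)

    # locate the newest HPC checkpoint and restore model/training states from it
    checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
    checkpoint = trainer.checkpoint_connector.restore_states(model, checkpoint_path, trainer.root_gpu)

    # the HPC-specific hook is now the caller's responsibility
    trainer.get_model().on_hpc_load(checkpoint)
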
33 changes: 20 additions & 13 deletions tests/models/test_restore.py
@@ -16,6 +16,8 @@
import os
import pickle
from copy import deepcopy
from pathlib import Path
from typing import Optional

import cloudpickle
import pytest
@@ -70,23 +72,28 @@ def test_model_properties_resume_from_checkpoint(tmpdir):
trainer.fit(model)


def test_try_resume_from_non_existing_checkpoint(tmpdir):
def test_try_resume_from_non_existing_checkpoint(tmpdir: Path):
""" Test that trying to resume from non-existing `resume_from_checkpoint` fail without error."""
model = BoringModel()
checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=False,
callbacks=[checkpoint_cb],
limit_train_batches=0.1,
limit_val_batches=0.1,
)

def gen_trainer(name_ckpt: Optional[str]) -> Trainer:
path_dir_saved = tmpdir
path_file_loaded = None if name_ckpt is None else str(tmpdir / name_ckpt)
checkpoint_cb = ModelCheckpoint(dirpath=path_dir_saved, monitor="early_stop_on", save_last=True)
return Trainer(
resume_from_checkpoint=path_file_loaded,
max_epochs=1,
logger=False,
callbacks=[checkpoint_cb],
limit_train_batches=0.1,
limit_val_batches=0.1,
)

# Generate checkpoint `last.ckpt` with BoringModel
trainer.fit(model)
gen_trainer(None).fit(model)
# `True` if resume/restore successfully else `False`
assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu)
assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)
assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_apply_checkpoint(model)
assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_apply_checkpoint(model)


class CaptureCallbacksBeforeTraining(Callback):