diff --git a/CHANGELOG.md b/CHANGELOG.md
index 623de5b04b652..146be157e2db8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -110,6 +110,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))
 
+- Refactored `hpc_load` and entangled logic in `CheckpointConnector` ([#5371](https://github.com/PyTorchLightning/pytorch-lightning/pull/5371))
+
+
 - Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))
 
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 903d1521d97f6..1144d3e342da2 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -15,7 +15,7 @@
 import os
 import re
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 
@@ -49,28 +49,16 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False
 
-    def restore_weights(self) -> None:
-        """
-        Attempt to restore a checkpoint (e.g. weights) in this priority:
-        1. from HPC weights
-        2. from `resume_from_checkpoint` file
-        3. don't restore
+    def attempt_to_restore(self) -> None:
+        """Attempt to restore model/training states.
         """
         # clear cache before restore
         if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()
 
-        # 1. Attempt to restore states from HPC checkpoint
-        dir_path_hpc = str(self.trainer.weights_save_path)
-        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
-        if max_suffix is not None:
-            checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
-            self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU)
-            rank_zero_info(f'restored hpc model from: {checkpoint_path}')
-
-        # 2. Attempt to restore states from `resume_from_checkpoint` file
-        elif self.trainer.resume_from_checkpoint is not None:
-            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)
+        # attempt to restore states
+        model: LightningModule = self.trainer.get_model()
+        self.attempt_to_apply_checkpoint(model)
 
         # wait for all to catch up
         self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
@@ -79,53 +67,95 @@ def restore_weights(self) -> None:
         if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()
 
-    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
-        """
-        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
-        All restored states are listed in return value description of `dump_checkpoint`.
+    def attempt_to_apply_checkpoint(self, model: LightningModule) -> bool:
+        """Attempt to apply checkpoint states to model/training with priority.
+
+        Priority:
+            1. from HPC weights
+            2. from `resume_from_checkpoint` file
+            3. don't apply
+
+        Returns:
+            True if a checkpoint was applied, else False.
         """
-        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
-        fs = get_filesystem(checkpoint_path)
-        if not fs.exists(checkpoint_path):
-            rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
Start from scratch") - return False + # Design Note: + # `attempt_to_restore` has responsibility to whole state restoration flow (e.g. OOM, parallel processing). + # This method has responsibility to applying/assigning state value from nullable checkpoint. - # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + restored: bool = False - # acquire the model - model = self.trainer.get_model() + # 1. Attempt to apply HPC checkpoint. + dir_path_hpc = str(self.trainer.weights_save_path) + max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") + if max_suffix is not None: + checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt' + checkpoint = self.restore_states(model, checkpoint_path, self.trainer._device_type == DeviceType.GPU) + model.on_hpc_load(checkpoint) + restored = True + rank_zero_info(f'restored hpc model from: {checkpoint_path}') - # restore model and datamodule state - self.restore_model_state(model, checkpoint) + # 2. Attempt to apply `resume_from_checkpoint` file. + elif self.trainer.resume_from_checkpoint is not None: + adress_checkpoint: str = self.trainer.resume_from_checkpoint + if get_filesystem(adress_checkpoint).exists(adress_checkpoint): + self.restore_states(model, adress_checkpoint, self.trainer._device_type == DeviceType.GPU) + restored = True + rank_zero_info(f"States restored from the checkpoint file at {adress_checkpoint}") + else: + rank_zero_warn(f"checkpoint file at {adress_checkpoint} does not exist.") - if on_gpu: - model.cuda(self.trainer.root_gpu) + # 3. Do not apply, start from scratch. + else: + rank_zero_info("Start from scratch.") - # restore training state - self.restore_training_state(checkpoint) + return restored - rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}") - return True + def restore_states( + self, + model: LightningModule, + checkpoint_path: str, + on_gpu: bool, + ) -> Dict[str, Any]: + """Restore all states from checkpoint in the specified path. - def restore_model_state(self, model: LightningModule, checkpoint) -> None: - """ - Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object + Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. + All restored states are listed in return value description of `dump_checkpoint`. + + Args: + on_gpu: Whether trainer is on GPU or not. """ + # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` + checkpoint: Dict[str, Any] = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - # restore datamodule states + # restore states if self.trainer.datamodule is not None: self.trainer.datamodule.on_load_checkpoint(checkpoint) + self.restore_model_state(checkpoint, model, on_gpu) + self.restore_training_state(checkpoint) + + return checkpoint + def restore_model_state( + self, + checkpoint: Dict[str, Any], + model: LightningModule, + on_gpu: bool, + ) -> None: + """Restore model state. + """ # hook: give user access to checkpoint if needed. model.on_load_checkpoint(checkpoint) # restore model state_dict model.load_state_dict(checkpoint['state_dict']) - def restore_training_state(self, checkpoint): - """ - Restore trainer state. 
+        # moves the model to the GPU
+        if on_gpu:
+            model.cuda(self.trainer.root_gpu)
+
+    def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
+        """Restore trainer state.
+
         Model will get its change to update
         :param checkpoint:
         :return:
@@ -329,30 +359,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
 
         return checkpoint
 
-    def hpc_load(self, checkpoint_path: str, on_gpu: bool):
-        """
-        Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
-        All restored states are listed in return value description of `dump_checkpoint`.
-        """
-
-        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
-        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
-
-        # acquire the model
-        model = self.trainer.get_model()
-
-        # restore model and datamodule state
-        self.restore_model_state(model, checkpoint)
-
-        if self.trainer.root_gpu is not None:
-            model.cuda(self.trainer.root_gpu)
-
-        # restore training state
-        self.restore_training_state(checkpoint)
-
-        # call hpc specific hook
-        model.on_hpc_load(checkpoint)
-
     def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d49c2d79cbebd..56558177dbbb4 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -158,8 +158,8 @@ def setup_training(self):
         if self.trainer.is_global_zero:
             ref_model.summarize(mode=self.trainer.weights_summary)
 
-        # restore training state and model weights before hpc is called
-        self.trainer.checkpoint_connector.restore_weights()
+        # restore model/training states before hpc is called
+        self.trainer.checkpoint_connector.attempt_to_restore()
 
         # on pretrain routine end
         self.trainer.on_pretrain_routine_end(ref_model)
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 38cb53bbd7ae6..b7d313a935ee9 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -113,9 +113,9 @@ def scale_batch_size(trainer,
     garbage_collection_cuda()
     log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')
 
-    # Restore initial state of model
+    # Restore initial state of model from temporary checkpoint, which is deleted after restore.
     if trainer.is_global_zero:
-        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
+        trainer.checkpoint_connector.restore_states(model, str(save_path), trainer._device_type == DeviceType.GPU)
         fs = get_filesystem(str(save_path))
         if fs.exists(save_path):
             fs.rm(save_path)
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index ac201f7f26afe..4e701136e227b 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -190,9 +190,9 @@ def lr_find(
                                 'loss': trainer.callbacks[0].losses})
     lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose
 
-    # Reset model state
+    # Restore initial state of model from temporary checkpoint, which is deleted after restore.
     if trainer.is_global_zero:
-        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
+        trainer.checkpoint_connector.restore_states(model, str(save_path), trainer._device_type == DeviceType.GPU)
         fs = get_filesystem(str(save_path))
         if fs.exists(save_path):
             fs.rm(save_path)
diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py
index 4949d53fc9a50..f173859093417 100644
--- a/tests/base/develop_pipelines.py
+++ b/tests/base/develop_pipelines.py
@@ -14,6 +14,7 @@
 import torch
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import DistributedType
 from tests.base import BoringModel
@@ -50,7 +51,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50
     trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
 
 
-def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
+def run_model_test(trainer_options, model: LightningModule, on_gpu: bool = True, version=None,
                    with_hpc: bool = True, min_acc: float = 0.25):
 
     reset_seed()
@@ -93,7 +94,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
         trainer.checkpoint_connector.hpc_save(save_dir, logger)
         # test HPC loading
         checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
-        trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
+        checkpoint = trainer.checkpoint_connector.restore_states(model, checkpoint_path, trainer.root_gpu)
+        trainer.get_model().on_hpc_load(checkpoint)
 
 
 def run_prediction(trained_model, dataloader, dp=False, min_acc=0.25):
diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index db1321623dd8a..6ace4662441e6 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -79,7 +79,8 @@ def run_test_from_config(trainer_options):
     trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
     # test HPC loading
     checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
-    trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=args.on_gpu)
+    checkpoint = trainer.checkpoint_connector.restore_states(model, checkpoint_path, trainer.root_gpu)
+    trainer.get_model().on_hpc_load(checkpoint)
 
     if args.on_gpu:
         trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 4dbb6554977b3..3f2a046b2cc84 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -16,6 +16,8 @@
 import os
 import pickle
 from copy import deepcopy
+from pathlib import Path
+from typing import Optional
 
 import cloudpickle
 import pytest
@@ -70,23 +72,28 @@ def test_model_properties_resume_from_checkpoint(tmpdir):
     trainer.fit(model)
 
 
-def test_try_resume_from_non_existing_checkpoint(tmpdir):
+def test_try_resume_from_non_existing_checkpoint(tmpdir: Path):
     """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error."""
     model = BoringModel()
-    checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        logger=False,
-        callbacks=[checkpoint_cb],
-        limit_train_batches=0.1,
-        limit_val_batches=0.1,
-    )
+
+    def gen_trainer(name_ckpt: Optional[str]) -> Trainer:
+        path_dir_saved = tmpdir
+        path_file_loaded = None if name_ckpt is None else str(tmpdir / name_ckpt)
+        checkpoint_cb = ModelCheckpoint(dirpath=path_dir_saved, monitor="early_stop_on", save_last=True)
+        return Trainer(
+            resume_from_checkpoint=path_file_loaded,
+            max_epochs=1,
+            logger=False,
+            callbacks=[checkpoint_cb],
+            limit_train_batches=0.1,
+            limit_val_batches=0.1,
+        )
+
     # Generate checkpoint `last.ckpt` with BoringModel
-    trainer.fit(model)
+    gen_trainer(None).fit(model)
     # `True` if resume/restore successfully else `False`
-    assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu)
-    assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)
+    assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_apply_checkpoint(model)
+    assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_apply_checkpoint(model)
 
 
 class CaptureCallbacksBeforeTraining(Callback):