From 172fe768ae89fb87b159de839355386da4795f5a Mon Sep 17 00:00:00 2001
From: Tarepan
Date: Tue, 5 Jan 2021 11:47:13 +0900
Subject: [PATCH 01/30] Refactor unused argument - model

---
 pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +-
 pytorch_lightning/trainer/training_loop.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index f13765ac28ce4..e1f5f9c691d31 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -42,7 +42,7 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False
 
-    def restore_weights(self, model: LightningModule) -> None:
+    def restore_weights(self) -> None:
         """
         Attempt to restore a checkpoint (e.g. weights) in this priority:
         1. from HPC weights
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 64eb224a428f1..817de4e2d26a9 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -185,7 +185,7 @@ def setup_training(self, model: LightningModule):
         self.trainer.model = model
 
         # restore training state and model weights before hpc is called
-        self.trainer.checkpoint_connector.restore_weights(model)
+        self.trainer.checkpoint_connector.restore_weights()
 
         # on pretrain routine end
         self.trainer.on_pretrain_routine_end(ref_model)

From 7b1605c4e4a93e6684ac22570520df45407640df Mon Sep 17 00:00:00 2001
From: Tarepan
Date: Tue, 5 Jan 2021 11:57:23 +0900
Subject: [PATCH 02/30] Refactor method description

---
 pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index e1f5f9c691d31..8e80919096068 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -44,7 +44,7 @@ def __init__(self, trainer):
 
     def restore_weights(self) -> None:
         """
-        Attempt to restore a checkpoint (e.g. weights) in this priority:
+        Attempt to restore model/training states in this priority:
         1. from HPC weights
         2. from `resume_from_checkpoint` file
         3. don't restore

From 9e1344b1a24cb1e388d113d94bb5f2ca46f5eab5 Mon Sep 17 00:00:00 2001
From: Tarepan
Date: Tue, 5 Jan 2021 12:32:46 +0900
Subject: [PATCH 03/30] Refactor method name with its actual functionality

---
 pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +-
 pytorch_lightning/trainer/training_loop.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 8e80919096068..daf60b8dbf102 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -42,7 +42,7 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False
 
-    def restore_weights(self) -> None:
+    def attempt_to_restore(self) -> None:
        """
         Attempt to restore model/training states in this priority:
         1.
from HPC weights diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 817de4e2d26a9..34f5d0e78e3ff 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -184,8 +184,8 @@ def setup_training(self, model: LightningModule): # if cluster resets state, the model will update with the saved weights self.trainer.model = model - # restore training state and model weights before hpc is called - self.trainer.checkpoint_connector.restore_weights() + # restore training and model before hpc is called + self.trainer.checkpoint_connector.attempt_to_restore() # on pretrain routine end self.trainer.on_pretrain_routine_end(ref_model) From 5a6d275795de59103b84bc72b1bc5dd58a2a1b0c Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 5 Jan 2021 12:45:07 +0900 Subject: [PATCH 04/30] Refactor unused argument `on_gpu` --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 4 ++-- tests/base/develop_pipelines.py | 2 +- tests/models/data/horovod/train_default_model.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index daf60b8dbf102..a38ea6297f3b6 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -58,7 +58,7 @@ def attempt_to_restore(self) -> None: max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") if max_suffix is not None: checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt' - self.hpc_load(checkpoint_path, self.trainer.on_gpu) + self.hpc_load(checkpoint_path) rank_zero_info(f'restored hpc model from: {checkpoint_path}') # 2. Attempt to restore states from `resume_from_checkpoint` file @@ -320,7 +320,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint - def hpc_load(self, checkpoint_path: str, on_gpu: bool): + def hpc_load(self, checkpoint_path: str): """ Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. All restored states are listed in return value description of `dump_checkpoint`. 
diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index 24535dc67da8e..4090bf47e0957 100644 --- a/tests/base/develop_pipelines.py +++ b/tests/base/develop_pipelines.py @@ -90,7 +90,7 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi trainer.checkpoint_connector.hpc_save(save_dir, logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir) - trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu) + trainer.checkpoint_connector.hpc_load(checkpoint_path) def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50): diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 62f874902b094..2a600a22ede7b 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -78,7 +78,7 @@ def run_test_from_config(trainer_options): trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path) - trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=args.on_gpu) + trainer.checkpoint_connector.hpc_load(checkpoint_path) if args.on_gpu: trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1) From 883d95a25776e8dae90791788901abb85cd83a7b Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 5 Jan 2021 14:00:30 +0900 Subject: [PATCH 05/30] Add intent commentary --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index a38ea6297f3b6..8b13cfefbf2ce 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -92,6 +92,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: # restore model and datamodule state self.restore_model_state(model, checkpoint) + # moves the model to the GPU if on_gpu: model.cuda(self.trainer.root_gpu) @@ -335,6 +336,7 @@ def hpc_load(self, checkpoint_path: str): # restore model and datamodule state self.restore_model_state(model, checkpoint) + # moves the model to the GPU if self.trainer.root_gpu is not None: model.cuda(self.trainer.root_gpu) From 97e7eb1f8198a912da9b25bf8fa88f7095b8ab06 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 5 Jan 2021 16:05:45 +0900 Subject: [PATCH 06/30] Refactor common restore --- .../connectors/checkpoint_connector.py | 51 ++++++++----------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 8b13cfefbf2ce..53f490b98e311 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -77,12 +77,31 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. All restored states are listed in return value description of `dump_checkpoint`. """ - # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint. fs = get_filesystem(checkpoint_path) if not fs.exists(checkpoint_path): rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. 
Start from scratch") return False + self.restore_from_checkpoint(checkpoint_path, on_gpu) + rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}") + return True + + def hpc_load(self, checkpoint_path: str): + """ + Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. + All restored states are listed in return value description of `dump_checkpoint`. + """ + self.restore_from_checkpoint(checkpoint_path, self.trainer.root_gpu) + # call hpc specific hook + self.trainer.get_model().on_hpc_load(checkpoint) + + def restore_from_checkpoint(self, checkpoint_path: str, with_gpu: Union[bool, Optional[int]]) -> None: + """ + Restore states from existing checkpoint. + `with_gpu=on_gpu` works as normal restore, `with_gpu=trainer.root_gpu` works as hpc restore. + Args: + with_gpu: bool for `on_gpu`, Optional[int] for `trainer.root_gpu`. + """ # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) @@ -93,15 +112,12 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: self.restore_model_state(model, checkpoint) # moves the model to the GPU - if on_gpu: + if (with_gpu is True) or ((not isinstance(with_gpu, bool)) and (with_gpu is not None)): model.cuda(self.trainer.root_gpu) # restore training state self.restore_training_state(checkpoint) - rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}") - return True - def restore_model_state(self, model: LightningModule, checkpoint) -> None: """ Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object @@ -321,31 +337,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint - def hpc_load(self, checkpoint_path: str): - """ - Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. - All restored states are listed in return value description of `dump_checkpoint`. - """ - - # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - - # acquire the model - model = self.trainer.get_model() - - # restore model and datamodule state - self.restore_model_state(model, checkpoint) - - # moves the model to the GPU - if self.trainer.root_gpu is not None: - model.cuda(self.trainer.root_gpu) - - # restore training state - self.restore_training_state(checkpoint) - - # call hpc specific hook - model.on_hpc_load(checkpoint) - def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]: """List up files in `dir_path` with name_key, then yield maximum suffix number. 
From 51ae04917cd2e971ffd6d08a5e4b9c2b32bb4e45 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 5 Jan 2021 16:13:13 +0900 Subject: [PATCH 07/30] Refactor too much function nest --- .../connectors/checkpoint_connector.py | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 53f490b98e311..624a4cdff3cda 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -108,21 +108,6 @@ def restore_from_checkpoint(self, checkpoint_path: str, with_gpu: Union[bool, Op # acquire the model model = self.trainer.get_model() - # restore model and datamodule state - self.restore_model_state(model, checkpoint) - - # moves the model to the GPU - if (with_gpu is True) or ((not isinstance(with_gpu, bool)) and (with_gpu is not None)): - model.cuda(self.trainer.root_gpu) - - # restore training state - self.restore_training_state(checkpoint) - - def restore_model_state(self, model: LightningModule, checkpoint) -> None: - """ - Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object - """ - # restore datamodule states if self.trainer.datamodule is not None: self.trainer.datamodule.on_load_checkpoint(checkpoint) @@ -133,6 +118,13 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None: # restore model state_dict model.load_state_dict(checkpoint['state_dict']) + # moves the model to the GPU + if (with_gpu is True) or ((not isinstance(with_gpu, bool)) and (with_gpu is not None)): + model.cuda(self.trainer.root_gpu) + + # restore training state + self.restore_training_state(checkpoint) + def restore_training_state(self, checkpoint): """ Restore trainer state. From c79d0b088747b82c356fcbe1b8fb4aa25e28ff8a Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 5 Jan 2021 17:36:51 +0900 Subject: [PATCH 08/30] Refactor too much function nest --- .../connectors/checkpoint_connector.py | 47 +++++++++++-------- pytorch_lightning/tuner/batch_size_scaling.py | 4 +- pytorch_lightning/tuner/lr_finder.py | 4 +- tests/models/test_restore.py | 31 +++++++----- 4 files changed, 49 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 624a4cdff3cda..8a2dd4264985f 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -42,28 +42,45 @@ def __init__(self, trainer): # used to validate checkpointing logic self.has_trained = False - def attempt_to_restore(self) -> None: + def attempt_to_restore(self) -> bool: """ Attempt to restore model/training states in this priority: 1. from HPC weights 2. from `resume_from_checkpoint` file 3. don't restore + + Returns: + True if restored else False """ + restored: bool = False + # clear cache before restore if self.trainer.on_gpu: torch.cuda.empty_cache() - # 1. Attempt to restore states from HPC checkpoint + # 1. Attempt to restore states from HPC checkpoint. dir_path_hpc = str(self.trainer.weights_save_path) max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") if max_suffix is not None: checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt' self.hpc_load(checkpoint_path) + restored = True rank_zero_info(f'restored hpc model from: {checkpoint_path}') - # 2. 
Attempt to restore states from `resume_from_checkpoint` file + # 2. Attempt to restore states from `resume_from_checkpoint` file. elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing: - self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu) + adress_checkpoint: str = self.trainer.resume_from_checkpoint + if get_filesystem(adress_checkpoint).exists(adress_checkpoint): + self.restore_from_checkpoint(adress_checkpoint, self.trainer.on_gpu) + restored = True + rank_zero_info(f"States restored from the checkpoint file at {adress_checkpoint}") + else: + rank_zero_warn(f"checkpoint file at {adress_checkpoint} does not exist.") + + # 3. Do not restore, start from scratch. + else: + rank_zero_info("Start from scratch.") + # wait for all to catch up self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') @@ -71,20 +88,9 @@ def attempt_to_restore(self) -> None: # clear cache after restore if self.trainer.on_gpu: torch.cuda.empty_cache() - - def restore(self, checkpoint_path: str, on_gpu: bool) -> bool: - """ - Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. - All restored states are listed in return value description of `dump_checkpoint`. - """ - fs = get_filesystem(checkpoint_path) - if not fs.exists(checkpoint_path): - rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch") - return False - self.restore_from_checkpoint(checkpoint_path, on_gpu) - rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}") - return True - + + return restored + def hpc_load(self, checkpoint_path: str): """ Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. @@ -96,8 +102,9 @@ def hpc_load(self, checkpoint_path: str): def restore_from_checkpoint(self, checkpoint_path: str, with_gpu: Union[bool, Optional[int]]) -> None: """ - Restore states from existing checkpoint. - `with_gpu=on_gpu` works as normal restore, `with_gpu=trainer.root_gpu` works as hpc restore. + Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. + All restored states are listed in return value description of `dump_checkpoint`. + `with_gpu=trainer.on_gpu` works as normal restore, `with_gpu=trainer.root_gpu` works as hpc restore. Args: with_gpu: bool for `on_gpu`, Optional[int] for `trainer.root_gpu`. diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 52662f6172d8d..fd7a0fb62ae56 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -113,9 +113,9 @@ def scale_batch_size(trainer, garbage_collection_cuda() log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}') - # Restore initial state of model + # Restore initial state of model from temporary checkpoint, which is deleted after restore. 
if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu) + trainer.checkpoint_connector.restore_from_checkpoint(str(save_path), trainer.on_gpu) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index e0fab12eec9d3..14f2909cc9bbd 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -190,9 +190,9 @@ def lr_find( 'loss': trainer.callbacks[0].losses}) lr_finder._total_batch_idx = trainer.total_batch_idx # for debug purpose - # Reset model state + # Restore initial state of model from temporary checkpoint, which is deleted after restore. if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu) + trainer.checkpoint_connector.restore_from_checkpoint(str(save_path), trainer.on_gpu) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index a2a9aa6b9042c..ad7339413cc13 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -14,8 +14,10 @@ import glob import logging as log import os +from pathlib import Path import pickle from copy import deepcopy +from typing import Optional import cloudpickle import pytest @@ -71,24 +73,27 @@ def test_model_properties_resume_from_checkpoint(enable_pl_optimizer, tmpdir): trainer.fit(model) -def test_try_resume_from_non_existing_checkpoint(tmpdir): +def test_try_resume_from_non_existing_checkpoint(tmpdir: Path): """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error.""" model = BoringModel() checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - logger=False, - callbacks=[checkpoint_cb], - limit_train_batches=0.1, - limit_val_batches=0.1, - ) + def gen_trainer(name_ckpt: Optional[str]) -> Trainer: + path_ckpt = None if name_ckpt is None else str(tmpdir / name_ckpt) + return Trainer( + default_root_dir=tmpdir, + resume_from_checkpoint=path_ckpt, + max_epochs=1, + logger=False, + callbacks=[checkpoint_cb], + limit_train_batches=0.1, + limit_val_batches=0.1, + ) # Generate checkpoint `last.ckpt` with BoringModel - trainer.fit(model) + gen_trainer(None).fit(model) # `True` if resume/restore successfully else `False` - assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu) - assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu) - + assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_restore() + assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_restore() + class CaptureCallbacksBeforeTraining(Callback): callbacks = [] From 5c85b2175c6d9985d632b9ef107b606edb4e3ee2 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 5 Jan 2021 18:38:26 +0900 Subject: [PATCH 09/30] Refactor function name --- .../trainer/connectors/checkpoint_connector.py | 13 +++++++------ pytorch_lightning/tuner/batch_size_scaling.py | 2 +- pytorch_lightning/tuner/lr_finder.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 8a2dd4264985f..fd084d6c6a3ce 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ 
b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -15,7 +15,7 @@ import os import re from pathlib import Path -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch @@ -71,7 +71,7 @@ def attempt_to_restore(self) -> bool: elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing: adress_checkpoint: str = self.trainer.resume_from_checkpoint if get_filesystem(adress_checkpoint).exists(adress_checkpoint): - self.restore_from_checkpoint(adress_checkpoint, self.trainer.on_gpu) + self.restore_states(adress_checkpoint, self.trainer.on_gpu) restored = True rank_zero_info(f"States restored from the checkpoint file at {adress_checkpoint}") else: @@ -81,7 +81,6 @@ def attempt_to_restore(self) -> bool: else: rank_zero_info("Start from scratch.") - # wait for all to catch up self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') @@ -96,11 +95,11 @@ def hpc_load(self, checkpoint_path: str): Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. All restored states are listed in return value description of `dump_checkpoint`. """ - self.restore_from_checkpoint(checkpoint_path, self.trainer.root_gpu) + self.restore_states(checkpoint_path, self.trainer.root_gpu) # call hpc specific hook self.trainer.get_model().on_hpc_load(checkpoint) - def restore_from_checkpoint(self, checkpoint_path: str, with_gpu: Union[bool, Optional[int]]) -> None: + def restore_states(self, checkpoint_path: str, with_gpu: Union[bool, Optional[int]]) -> Dict[str, Any]: """ Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. All restored states are listed in return value description of `dump_checkpoint`. @@ -110,7 +109,7 @@ def restore_from_checkpoint(self, checkpoint_path: str, with_gpu: Union[bool, Op with_gpu: bool for `on_gpu`, Optional[int] for `trainer.root_gpu`. """ # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + checkpoint: Dict[str, Any] = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) # acquire the model model = self.trainer.get_model() @@ -132,6 +131,8 @@ def restore_from_checkpoint(self, checkpoint_path: str, with_gpu: Union[bool, Op # restore training state self.restore_training_state(checkpoint) + return checkpoint + def restore_training_state(self, checkpoint): """ Restore trainer state. diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index fd7a0fb62ae56..5364fd7ef5402 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -115,7 +115,7 @@ def scale_batch_size(trainer, # Restore initial state of model from temporary checkpoint, which is deleted after restore. if trainer.is_global_zero: - trainer.checkpoint_connector.restore_from_checkpoint(str(save_path), trainer.on_gpu) + trainer.checkpoint_connector.restore_states(str(save_path), trainer.on_gpu) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 14f2909cc9bbd..e540759c155ac 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -192,7 +192,7 @@ def lr_find( # Restore initial state of model from temporary checkpoint, which is deleted after restore. 
if trainer.is_global_zero: - trainer.checkpoint_connector.restore_from_checkpoint(str(save_path), trainer.on_gpu) + trainer.checkpoint_connector.restore_states(str(save_path), trainer.on_gpu) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) From 9a9c2177cc7de008842dc6293eb5109dfd19de4e Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 5 Jan 2021 18:40:41 +0900 Subject: [PATCH 10/30] Fix missing argument --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index fd084d6c6a3ce..3f6dcedb59101 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -95,7 +95,7 @@ def hpc_load(self, checkpoint_path: str): Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. All restored states are listed in return value description of `dump_checkpoint`. """ - self.restore_states(checkpoint_path, self.trainer.root_gpu) + checkpoint = self.restore_states(checkpoint_path, self.trainer.root_gpu) # call hpc specific hook self.trainer.get_model().on_hpc_load(checkpoint) From 9adf99948d0563f070fe5bd2315ed3d18a253961 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 5 Jan 2021 18:51:22 +0900 Subject: [PATCH 11/30] Refactor hpc load with commons --- .../trainer/connectors/checkpoint_connector.py | 12 ++---------- tests/base/develop_pipelines.py | 3 ++- tests/models/data/horovod/train_default_model.py | 3 ++- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 3f6dcedb59101..b08b8d8c9bf90 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -63,7 +63,8 @@ def attempt_to_restore(self) -> bool: max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") if max_suffix is not None: checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt' - self.hpc_load(checkpoint_path) + checkpoint = self.restore_states(checkpoint_path, self.trainer.root_gpu) + self.trainer.get_model().on_hpc_load(checkpoint) restored = True rank_zero_info(f'restored hpc model from: {checkpoint_path}') @@ -90,15 +91,6 @@ def attempt_to_restore(self) -> bool: return restored - def hpc_load(self, checkpoint_path: str): - """ - Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. - All restored states are listed in return value description of `dump_checkpoint`. - """ - checkpoint = self.restore_states(checkpoint_path, self.trainer.root_gpu) - # call hpc specific hook - self.trainer.get_model().on_hpc_load(checkpoint) - def restore_states(self, checkpoint_path: str, with_gpu: Union[bool, Optional[int]]) -> Dict[str, Any]: """ Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. 
diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index 4090bf47e0957..029bdd12aac38 100644 --- a/tests/base/develop_pipelines.py +++ b/tests/base/develop_pipelines.py @@ -90,7 +90,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi trainer.checkpoint_connector.hpc_save(save_dir, logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir) - trainer.checkpoint_connector.hpc_load(checkpoint_path) + checkpoint = trainer.checkpoint_connector.restore_states(checkpoint_path, trainer.root_gpu) + trainer.get_model().on_hpc_load(checkpoint) def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50): diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 2a600a22ede7b..b6ebd8fa37bac 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -78,7 +78,8 @@ def run_test_from_config(trainer_options): trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path) - trainer.checkpoint_connector.hpc_load(checkpoint_path) + checkpoint = trainer.checkpoint_connector.restore_states(checkpoint_path, trainer.root_gpu) + trainer.get_model().on_hpc_load(checkpoint) if args.on_gpu: trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1) From 7b6272e903c02396db936f6c4b235ea27b5950db Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 05:43:44 +0900 Subject: [PATCH 12/30] Fix pep8 --- .../trainer/connectors/checkpoint_connector.py | 6 +++--- tests/models/test_restore.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index b08b8d8c9bf90..83325a1e06a26 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -77,7 +77,7 @@ def attempt_to_restore(self) -> bool: rank_zero_info(f"States restored from the checkpoint file at {adress_checkpoint}") else: rank_zero_warn(f"checkpoint file at {adress_checkpoint} does not exist.") - + # 3. Do not restore, start from scratch. else: rank_zero_info("Start from scratch.") @@ -88,9 +88,9 @@ def attempt_to_restore(self) -> bool: # clear cache after restore if self.trainer.on_gpu: torch.cuda.empty_cache() - + return restored - + def restore_states(self, checkpoint_path: str, with_gpu: Union[bool, Optional[int]]) -> Dict[str, Any]: """ Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. 
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index ad7339413cc13..5ae5a29af1f64 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -77,6 +77,7 @@ def test_try_resume_from_non_existing_checkpoint(tmpdir: Path): """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error.""" model = BoringModel() checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) + def gen_trainer(name_ckpt: Optional[str]) -> Trainer: path_ckpt = None if name_ckpt is None else str(tmpdir / name_ckpt) return Trainer( @@ -88,12 +89,13 @@ def gen_trainer(name_ckpt: Optional[str]) -> Trainer: limit_train_batches=0.1, limit_val_batches=0.1, ) + # Generate checkpoint `last.ckpt` with BoringModel gen_trainer(None).fit(model) # `True` if resume/restore successfully else `False` assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_restore() assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_restore() - + class CaptureCallbacksBeforeTraining(Callback): callbacks = [] From 0db7f62f15760baa9ac25888e1f43e98a588b36d Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 06:22:08 +0900 Subject: [PATCH 13/30] Refactor checkpoint test --- tests/models/test_restore.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 5ae5a29af1f64..11ea4f0af9771 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -76,13 +76,13 @@ def test_model_properties_resume_from_checkpoint(enable_pl_optimizer, tmpdir): def test_try_resume_from_non_existing_checkpoint(tmpdir: Path): """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error.""" model = BoringModel() - checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True) def gen_trainer(name_ckpt: Optional[str]) -> Trainer: - path_ckpt = None if name_ckpt is None else str(tmpdir / name_ckpt) + path_dir_saved = tmpdir + path_file_loaded = None if name_ckpt is None else str(tmpdir / name_ckpt) + checkpoint_cb = ModelCheckpoint(dirpath=path_dir_saved, monitor="early_stop_on", save_last=True) return Trainer( - default_root_dir=tmpdir, - resume_from_checkpoint=path_ckpt, + resume_from_checkpoint=path_file_loaded, max_epochs=1, logger=False, callbacks=[checkpoint_cb], From 4bfb232e1fc1ac929b23570b7216f5db4143289d Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 07:00:41 +0900 Subject: [PATCH 14/30] Refactor for easy test --- .../connectors/checkpoint_connector.py | 51 +++++++++++++------ pytorch_lightning/trainer/training_loop.py | 4 +- pytorch_lightning/tuner/batch_size_scaling.py | 2 +- pytorch_lightning/tuner/lr_finder.py | 2 +- tests/base/develop_pipelines.py | 5 +- .../data/horovod/train_default_model.py | 2 +- tests/models/test_restore.py | 4 +- 7 files changed, 45 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 83325a1e06a26..3dc76f1d22ef2 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -42,16 +42,21 @@ def __init__(self, trainer): # used to validate checkpointing logic self.has_trained = False - def attempt_to_restore(self) -> bool: + def attempt_to_restore(self, model: LightningModule) -> bool: """ Attempt to restore model/training 
states in this priority: 1. from HPC weights 2. from `resume_from_checkpoint` file 3. don't restore + Args: + model: Model to which states from checkpoint are applied. Returns: True if restored else False """ + # Design Note: + # `model` can be acquired with `self.trainer.get_model()`, but it make testing hard. + restored: bool = False # clear cache before restore @@ -63,8 +68,8 @@ def attempt_to_restore(self) -> bool: max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") if max_suffix is not None: checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt' - checkpoint = self.restore_states(checkpoint_path, self.trainer.root_gpu) - self.trainer.get_model().on_hpc_load(checkpoint) + checkpoint = self.restore_states(model, checkpoint_path, self.trainer.root_gpu) + model.on_hpc_load(checkpoint) restored = True rank_zero_info(f'restored hpc model from: {checkpoint_path}') @@ -72,7 +77,7 @@ def attempt_to_restore(self) -> bool: elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing: adress_checkpoint: str = self.trainer.resume_from_checkpoint if get_filesystem(adress_checkpoint).exists(adress_checkpoint): - self.restore_states(adress_checkpoint, self.trainer.on_gpu) + self.restore_states(model, adress_checkpoint, self.trainer.on_gpu) restored = True rank_zero_info(f"States restored from the checkpoint file at {adress_checkpoint}") else: @@ -91,7 +96,12 @@ def attempt_to_restore(self) -> bool: return restored - def restore_states(self, checkpoint_path: str, with_gpu: Union[bool, Optional[int]]) -> Dict[str, Any]: + def restore_states( + self, + model: LightningModule, + checkpoint_path: str, + with_gpu: Union[bool, Optional[int]], + ) -> Dict[str, Any]: """ Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. All restored states are listed in return value description of `dump_checkpoint`. @@ -103,12 +113,26 @@ def restore_states(self, checkpoint_path: str, with_gpu: Union[bool, Optional[in # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` checkpoint: Dict[str, Any] = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - # acquire the model - model = self.trainer.get_model() - - # restore datamodule states + # restore states if self.trainer.datamodule is not None: self.trainer.datamodule.on_load_checkpoint(checkpoint) + self.restore_model_state(model, checkpoint, with_gpu, self.trainer.root_gpu) + self.restore_training_state(checkpoint) + + return checkpoint + + def restore_model_state( + self, + model: LightningModule, + checkpoint: Dict[str, Any], + with_gpu: Union[bool, Optional[int]], + root_gpu: Optional[int], + ) -> None: + """ + Restore model state. + """ + # Design Note: + # model can be acquired with `self.trainer.get_model()`, but it make upstream testing hard. # hook: give user access to checkpoint if needed. model.on_load_checkpoint(checkpoint) @@ -118,14 +142,9 @@ def restore_states(self, checkpoint_path: str, with_gpu: Union[bool, Optional[in # moves the model to the GPU if (with_gpu is True) or ((not isinstance(with_gpu, bool)) and (with_gpu is not None)): - model.cuda(self.trainer.root_gpu) - - # restore training state - self.restore_training_state(checkpoint) - - return checkpoint + model.cuda(root_gpu) - def restore_training_state(self, checkpoint): + def restore_training_state(self, checkpoint: Dict[str, Any]) -> None: """ Restore trainer state. 
Model will get its change to update diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 34f5d0e78e3ff..35aae07688842 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -184,8 +184,8 @@ def setup_training(self, model: LightningModule): # if cluster resets state, the model will update with the saved weights self.trainer.model = model - # restore training and model before hpc is called - self.trainer.checkpoint_connector.attempt_to_restore() + # restore model/training states before hpc is called + self.trainer.checkpoint_connector.attempt_to_restore(self.trainer.model) # on pretrain routine end self.trainer.on_pretrain_routine_end(ref_model) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 5364fd7ef5402..e6c24b127a3d5 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -115,7 +115,7 @@ def scale_batch_size(trainer, # Restore initial state of model from temporary checkpoint, which is deleted after restore. if trainer.is_global_zero: - trainer.checkpoint_connector.restore_states(str(save_path), trainer.on_gpu) + trainer.checkpoint_connector.restore_states(model, str(save_path), trainer.on_gpu) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index e540759c155ac..e51f7532a06ce 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -192,7 +192,7 @@ def lr_find( # Restore initial state of model from temporary checkpoint, which is deleted after restore. if trainer.is_global_zero: - trainer.checkpoint_connector.restore_states(str(save_path), trainer.on_gpu) + trainer.checkpoint_connector.restore_states(model, str(save_path), trainer.on_gpu) fs = get_filesystem(str(save_path)) if fs.exists(save_path): fs.rm(save_path) diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index 029bdd12aac38..582f38fb6126f 100644 --- a/tests/base/develop_pipelines.py +++ b/tests/base/develop_pipelines.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from pytorch_lightning.core.lightning import LightningModule import torch from pytorch_lightning import Trainer @@ -47,7 +48,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50 trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers() -def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, with_hpc: bool = True): +def run_model_test(trainer_options, model: LightningModule, on_gpu: bool = True, version=None, with_hpc: bool = True): reset_seed() save_dir = trainer_options['default_root_dir'] @@ -90,7 +91,7 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi trainer.checkpoint_connector.hpc_save(save_dir, logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir) - checkpoint = trainer.checkpoint_connector.restore_states(checkpoint_path, trainer.root_gpu) + checkpoint = trainer.checkpoint_connector.restore_states(model, checkpoint_path, trainer.root_gpu) trainer.get_model().on_hpc_load(checkpoint) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index b6ebd8fa37bac..6e37af8142c44 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -78,7 +78,7 @@ def run_test_from_config(trainer_options): trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path) - checkpoint = trainer.checkpoint_connector.restore_states(checkpoint_path, trainer.root_gpu) + checkpoint = trainer.checkpoint_connector.restore_states(model, checkpoint_path, trainer.root_gpu) trainer.get_model().on_hpc_load(checkpoint) if args.on_gpu: diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 11ea4f0af9771..0451a1e2a7265 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -93,8 +93,8 @@ def gen_trainer(name_ckpt: Optional[str]) -> Trainer: # Generate checkpoint `last.ckpt` with BoringModel gen_trainer(None).fit(model) # `True` if resume/restore successfully else `False` - assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_restore() - assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_restore() + assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_restore(model) + assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_restore(model) class CaptureCallbacksBeforeTraining(Callback): From 6d52468f4ffb3d6f156a338d5b74ca8a5d765956 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 07:04:53 +0900 Subject: [PATCH 15/30] Fix pep8 --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 3dc76f1d22ef2..55b7a76e31a38 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -55,7 +55,7 @@ def attempt_to_restore(self, model: LightningModule) -> bool: True if restored else False """ # Design Note: - # `model` can be acquired with `self.trainer.get_model()`, but it make testing hard. + # `model` can be acquired with `self.trainer.get_model()`, but it make testing hard. 
restored: bool = False @@ -132,7 +132,7 @@ def restore_model_state( Restore model state. """ # Design Note: - # model can be acquired with `self.trainer.get_model()`, but it make upstream testing hard. + # model can be acquired with `self.trainer.get_model()`, but it make upstream testing hard. # hook: give user access to checkpoint if needed. model.on_load_checkpoint(checkpoint) From 2e6b277178a0b6999ea6a982f5e993ec4172a93d Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 07:32:37 +0900 Subject: [PATCH 16/30] Fix trainer setup outside the fit --- .../trainer/connectors/checkpoint_connector.py | 16 +++++++--------- pytorch_lightning/trainer/training_loop.py | 2 +- tests/models/test_restore.py | 8 +++++--- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 55b7a76e31a38..7d0b42b70989f 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -42,23 +42,24 @@ def __init__(self, trainer): # used to validate checkpointing logic self.has_trained = False - def attempt_to_restore(self, model: LightningModule) -> bool: + def attempt_to_restore(self) -> bool: """ Attempt to restore model/training states in this priority: 1. from HPC weights 2. from `resume_from_checkpoint` file 3. don't restore - Args: - model: Model to which states from checkpoint are applied. Returns: True if restored else False """ - # Design Note: - # `model` can be acquired with `self.trainer.get_model()`, but it make testing hard. + # Development Note: + # prerequisite: + # `trainer.train_loop.setup_training(model)` is needed before this method call. + # It is because Trainer.__init__ do not prepare `model` and `accelerator_backend`. restored: bool = False - + model: LightningModule = self.trainer.get_model() + # clear cache before restore if self.trainer.on_gpu: torch.cuda.empty_cache() @@ -131,9 +132,6 @@ def restore_model_state( """ Restore model state. """ - # Design Note: - # model can be acquired with `self.trainer.get_model()`, but it make upstream testing hard. - # hook: give user access to checkpoint if needed. 
model.on_load_checkpoint(checkpoint) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 35aae07688842..17d53284335b2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -185,7 +185,7 @@ def setup_training(self, model: LightningModule): self.trainer.model = model # restore model/training states before hpc is called - self.trainer.checkpoint_connector.attempt_to_restore(self.trainer.model) + self.trainer.checkpoint_connector.attempt_to_restore() # on pretrain routine end self.trainer.on_pretrain_routine_end(ref_model) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 0451a1e2a7265..d9fd4a81b8afe 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -81,7 +81,7 @@ def gen_trainer(name_ckpt: Optional[str]) -> Trainer: path_dir_saved = tmpdir path_file_loaded = None if name_ckpt is None else str(tmpdir / name_ckpt) checkpoint_cb = ModelCheckpoint(dirpath=path_dir_saved, monitor="early_stop_on", save_last=True) - return Trainer( + trainer = Trainer( resume_from_checkpoint=path_file_loaded, max_epochs=1, logger=False, @@ -89,12 +89,14 @@ def gen_trainer(name_ckpt: Optional[str]) -> Trainer: limit_train_batches=0.1, limit_val_batches=0.1, ) + trainer.train_loop.setup_training(model) + return trainer # Generate checkpoint `last.ckpt` with BoringModel gen_trainer(None).fit(model) # `True` if resume/restore successfully else `False` - assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_restore(model) - assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_restore(model) + assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_restore() + assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_restore() class CaptureCallbacksBeforeTraining(Callback): From fa03af44d2b9cd002ce1a8208cb883d692c465ba Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 08:15:40 +0900 Subject: [PATCH 17/30] Refactor big method with responsibility --- .../connectors/checkpoint_connector.py | 53 +++++++++++-------- tests/models/test_restore.py | 8 ++- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 7d0b42b70989f..fcecc96d3fd15 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -42,29 +42,43 @@ def __init__(self, trainer): # used to validate checkpointing logic self.has_trained = False - def attempt_to_restore(self) -> bool: + def attempt_to_restore(self) -> None: """ - Attempt to restore model/training states in this priority: + Attempt to restore model/training states. + """ + + # clear cache before restore + if self.trainer.on_gpu: + torch.cuda.empty_cache() + + # attempt to restore states + model: LightningModule = self.trainer.get_model() + self.attempt_to_apply_checkpoint(model) + + # wait for all to catch up + self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') + + # clear cache after restore + if self.trainer.on_gpu: + torch.cuda.empty_cache() + + def attempt_to_apply_checkpoint(self, model: LightningModule) -> bool: + """ + Attempt to apply checkpoint states to model/training in this priority: 1. from HPC weights 2. from `resume_from_checkpoint` file - 3. don't restore + 3. 
don't apply Returns: - True if restored else False + True if applied else False """ - # Development Note: - # prerequisite: - # `trainer.train_loop.setup_training(model)` is needed before this method call. - # It is because Trainer.__init__ do not prepare `model` and `accelerator_backend`. + # Design Note: + # `attempt_to_restore` has responsibility to whole state restoration flow (e.g. OOM, parallel processing). + # This method has responsibility to applying/assigning state value from nullable checkpoint. restored: bool = False - model: LightningModule = self.trainer.get_model() - # clear cache before restore - if self.trainer.on_gpu: - torch.cuda.empty_cache() - - # 1. Attempt to restore states from HPC checkpoint. + # 1. Attempt to apply HPC checkpoint. dir_path_hpc = str(self.trainer.weights_save_path) max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") if max_suffix is not None: @@ -74,7 +88,7 @@ def attempt_to_restore(self) -> bool: restored = True rank_zero_info(f'restored hpc model from: {checkpoint_path}') - # 2. Attempt to restore states from `resume_from_checkpoint` file. + # 2. Attempt to apply `resume_from_checkpoint` file. elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing: adress_checkpoint: str = self.trainer.resume_from_checkpoint if get_filesystem(adress_checkpoint).exists(adress_checkpoint): @@ -84,17 +98,10 @@ def attempt_to_restore(self) -> bool: else: rank_zero_warn(f"checkpoint file at {adress_checkpoint} does not exist.") - # 3. Do not restore, start from scratch. + # 3. Do not apply, start from scratch. else: rank_zero_info("Start from scratch.") - # wait for all to catch up - self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') - - # clear cache after restore - if self.trainer.on_gpu: - torch.cuda.empty_cache() - return restored def restore_states( diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index d9fd4a81b8afe..c2202dfad1625 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -81,7 +81,7 @@ def gen_trainer(name_ckpt: Optional[str]) -> Trainer: path_dir_saved = tmpdir path_file_loaded = None if name_ckpt is None else str(tmpdir / name_ckpt) checkpoint_cb = ModelCheckpoint(dirpath=path_dir_saved, monitor="early_stop_on", save_last=True) - trainer = Trainer( + return Trainer( resume_from_checkpoint=path_file_loaded, max_epochs=1, logger=False, @@ -89,14 +89,12 @@ def gen_trainer(name_ckpt: Optional[str]) -> Trainer: limit_train_batches=0.1, limit_val_batches=0.1, ) - trainer.train_loop.setup_training(model) - return trainer # Generate checkpoint `last.ckpt` with BoringModel gen_trainer(None).fit(model) # `True` if resume/restore successfully else `False` - assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_restore() - assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_restore() + assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_apply_checkpoint(model) + assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_apply_checkpoint(model) class CaptureCallbacksBeforeTraining(Callback): From c1668767b6af96da4221155d8aa44071f9799fb3 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 08:20:45 +0900 Subject: [PATCH 18/30] Fix pip8 --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py 
b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index fcecc96d3fd15..83255ef33d53e 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -77,7 +77,7 @@ def attempt_to_apply_checkpoint(self, model: LightningModule) -> bool: # This method has responsibility to applying/assigning state value from nullable checkpoint. restored: bool = False - + # 1. Attempt to apply HPC checkpoint. dir_path_hpc = str(self.trainer.weights_save_path) max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_") From 430b5f8a04c576a912a180920ffca50c96140164 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 08:35:01 +0900 Subject: [PATCH 19/30] Refactor for diff alignment --- .../trainer/connectors/checkpoint_connector.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 83255ef33d53e..62af4d689213e 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -46,7 +46,6 @@ def attempt_to_restore(self) -> None: """ Attempt to restore model/training states. """ - # clear cache before restore if self.trainer.on_gpu: torch.cuda.empty_cache() @@ -124,18 +123,13 @@ def restore_states( # restore states if self.trainer.datamodule is not None: self.trainer.datamodule.on_load_checkpoint(checkpoint) - self.restore_model_state(model, checkpoint, with_gpu, self.trainer.root_gpu) + self.restore_model_state(checkpoint, model, with_gpu) self.restore_training_state(checkpoint) return checkpoint - def restore_model_state( - self, - model: LightningModule, - checkpoint: Dict[str, Any], - with_gpu: Union[bool, Optional[int]], - root_gpu: Optional[int], - ) -> None: + def restore_model_state(self, checkpoint: Dict[str, Any], model: LightningModule, + with_gpu: Union[bool, Optional[int]]) -> None: """ Restore model state. """ @@ -147,7 +141,7 @@ def restore_model_state( # moves the model to the GPU if (with_gpu is True) or ((not isinstance(with_gpu, bool)) and (with_gpu is not None)): - model.cuda(root_gpu) + model.cuda(self.trainer.root_gpu) def restore_training_state(self, checkpoint: Dict[str, Any]) -> None: """ From d9063be1a09eac45b898369b61ddb06636ae5012 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 10:15:59 +0900 Subject: [PATCH 20/30] Link upstream issue: #5370 From 5387267029329a224f3809b47e1de9f85c810c27 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Wed, 6 Jan 2021 10:25:55 +0900 Subject: [PATCH 21/30] Fix pep8 --- .../trainer/connectors/checkpoint_connector.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 62af4d689213e..4de1734e3137c 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -128,8 +128,12 @@ def restore_states( return checkpoint - def restore_model_state(self, checkpoint: Dict[str, Any], model: LightningModule, - with_gpu: Union[bool, Optional[int]]) -> None: + def restore_model_state( + self, + checkpoint: Dict[str, Any], + model: LightningModule, + with_gpu: Union[bool, Optional[int]] + ) -> None: """ Restore model state. 
""" From e894f98b9dee3c64e401c6946226d02e351a77e5 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Thu, 7 Jan 2021 15:27:42 +0900 Subject: [PATCH 22/30] Fix type description without functional change --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 4de1734e3137c..e5e754e7f55e7 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -107,7 +107,7 @@ def restore_states( self, model: LightningModule, checkpoint_path: str, - with_gpu: Union[bool, Optional[int]], + with_gpu: Optional[Union[int, bool]], ) -> Dict[str, Any]: """ Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. From ae13e0c62ca8259c6e9faa80d45f737b200f3643 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 19 Jan 2021 16:00:49 +0900 Subject: [PATCH 23/30] Refactor with_gpu type with simple typing Normal and HPC load now use common GPU type check (#5300). Now that there is no needs of accepting both bool and int. --- .../trainer/connectors/checkpoint_connector.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 2eb6127683401..a0d824dfe515d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -114,15 +114,14 @@ def restore_states( self, model: LightningModule, checkpoint_path: str, - with_gpu: Optional[Union[int, bool]], + on_gpu: bool, ) -> Dict[str, Any]: """ Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. All restored states are listed in return value description of `dump_checkpoint`. - `with_gpu=trainer.on_gpu` works as normal restore, `with_gpu=trainer.root_gpu` works as hpc restore. Args: - with_gpu: bool for `on_gpu`, Optional[int] for `trainer.root_gpu`. + on_gpu: Whether trainer is on GPU or not. """ # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` checkpoint: Dict[str, Any] = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) @@ -130,7 +129,7 @@ def restore_states( # restore states if self.trainer.datamodule is not None: self.trainer.datamodule.on_load_checkpoint(checkpoint) - self.restore_model_state(checkpoint, model, with_gpu) + self.restore_model_state(checkpoint, model, on_gpu) self.restore_training_state(checkpoint) return checkpoint @@ -139,7 +138,7 @@ def restore_model_state( self, checkpoint: Dict[str, Any], model: LightningModule, - with_gpu: Union[bool, Optional[int]] + on_gpu: bool, ) -> None: """ Restore model state. 
@@ -151,7 +150,7 @@ def restore_model_state( model.load_state_dict(checkpoint['state_dict']) # moves the model to the GPU - if (with_gpu is True) or ((not isinstance(with_gpu, bool)) and (with_gpu is not None)): + if (on_gpu is True) or ((not isinstance(on_gpu, bool)) and (on_gpu is not None)): model.cuda(self.trainer.root_gpu) def restore_training_state(self, checkpoint: Dict[str, Any]) -> None: From 1f1d8176d68402d957f9c161e1c132322a052164 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 19 Jan 2021 16:08:49 +0900 Subject: [PATCH 24/30] Refactor comment format --- .../connectors/checkpoint_connector.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index a0d824dfe515d..c518e14c40ebc 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -50,8 +50,7 @@ def __init__(self, trainer): self.has_trained = False def attempt_to_restore(self) -> None: - """ - Attempt to restore model/training states. + """Attempt to restore model/training states. """ # clear cache before restore if self.trainer._device_type == DeviceType.GPU: @@ -69,11 +68,12 @@ def attempt_to_restore(self) -> None: torch.cuda.empty_cache() def attempt_to_apply_checkpoint(self, model: LightningModule) -> bool: - """ - Attempt to apply checkpoint states to model/training in this priority: - 1. from HPC weights - 2. from `resume_from_checkpoint` file - 3. don't apply + """Attempt to apply checkpoint states to model/training with priority. + + Priority: + 1. from HPC weights + 2. from `resume_from_checkpoint` file + 3. don't apply Returns: True if applied else False @@ -116,7 +116,8 @@ def restore_states( checkpoint_path: str, on_gpu: bool, ) -> Dict[str, Any]: - """ + """Restore all states from checkpoint in the specified path. + Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. All restored states are listed in return value description of `dump_checkpoint`. @@ -140,8 +141,7 @@ def restore_model_state( model: LightningModule, on_gpu: bool, ) -> None: - """ - Restore model state. + """Restore model state. """ # hook: give user access to checkpoint if needed. model.on_load_checkpoint(checkpoint) @@ -154,8 +154,8 @@ def restore_model_state( model.cuda(self.trainer.root_gpu) def restore_training_state(self, checkpoint: Dict[str, Any]) -> None: - """ - Restore trainer state. + """Restore trainer state. + Model will get its change to update :param checkpoint: :return: From 86f9c784ea7798512527b8db76cbb749ffa03fc1 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 19 Jan 2021 16:22:36 +0900 Subject: [PATCH 25/30] Fix isort --- tests/base/develop_pipelines.py | 2 +- tests/models/test_restore.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index ed60132613b6d..f173859093417 100644 --- a/tests/base/develop_pipelines.py +++ b/tests/base/develop_pipelines.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from pytorch_lightning.core.lightning import LightningModule import torch from pytorch_lightning import Trainer +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import DistributedType from tests.base import BoringModel diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index cdf9ea7f27520..61dc3bf48e9c2 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -14,9 +14,9 @@ import glob import logging as log import os -from pathlib import Path import pickle from copy import deepcopy +from pathlib import Path from typing import Optional import cloudpickle From c020cb47b0076c51fb593f888479f9f8b62e3218 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 19 Jan 2021 16:25:03 +0900 Subject: [PATCH 26/30] Fix pep8 --- .../trainer/connectors/checkpoint_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c518e14c40ebc..8dbbcb3422cd2 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -69,7 +69,7 @@ def attempt_to_restore(self) -> None: def attempt_to_apply_checkpoint(self, model: LightningModule) -> bool: """Attempt to apply checkpoint states to model/training with priority. - + Priority: 1. from HPC weights 2. from `resume_from_checkpoint` file @@ -117,7 +117,7 @@ def restore_states( on_gpu: bool, ) -> Dict[str, Any]: """Restore all states from checkpoint in the specified path. - + Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore. All restored states are listed in return value description of `dump_checkpoint`. @@ -155,7 +155,7 @@ def restore_model_state( def restore_training_state(self, checkpoint: Dict[str, Any]) -> None: """Restore trainer state. - + Model will get its change to update :param checkpoint: :return: From 2cf5488c37eb09d1ade29d3c960459b6d6d44178 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Tue, 19 Jan 2021 16:36:29 +0900 Subject: [PATCH 27/30] Refactor too much type guard --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 8dbbcb3422cd2..f5d8e441495df 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -150,7 +150,7 @@ def restore_model_state( model.load_state_dict(checkpoint['state_dict']) # moves the model to the GPU - if (on_gpu is True) or ((not isinstance(on_gpu, bool)) and (on_gpu is not None)): + if on_gpu: model.cuda(self.trainer.root_gpu) def restore_training_state(self, checkpoint: Dict[str, Any]) -> None: From d08bfca2c665b36b1f3574c124d609fd864f37dc Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 29 Jan 2021 11:13:15 +0100 Subject: [PATCH 28/30] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f029aeadb6d62..664b1be3c8469 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516)) +- Refactored `hpc_load` and entangled logics in `CheckpointConnector` ([#5371](https://github.com/PyTorchLightning/pytorch-lightning/pull/5371)) + + ### Deprecated - `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) From 54296c5d52a62984e4314ded011c9eb6e6324c0d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 1 Feb 2021 15:29:54 +0100 Subject: [PATCH 29/30] amp --- pytorch_lightning/overrides/data_parallel.py | 4 ++-- pytorch_lightning/plugins/base_plugin.py | 3 ++- pytorch_lightning/plugins/precision/native_amp.py | 10 ++++++++-- .../plugins/precision/precision_plugin.py | 3 ++- .../plugins/precision/sharded_native_amp.py | 1 + 5 files changed, 15 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 8d1710e471197..b027502f99e8a 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -30,8 +30,7 @@ class LightningDataParallel(DataParallel): def __init__(self, module: LightningModule, *args, **kwargs): warnings.warn( "The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4." - " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.", - DeprecationWarning + " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.", DeprecationWarning ) super().__init__(LightningParallelModule(module), *args, **kwargs) @@ -67,6 +66,7 @@ class LightningParallelModule(_LightningModuleWrapperBase): pl_module: the model to wrap """ + def __init__(self, pl_module: LightningModule): super().__init__(pl_module) diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py index 0160afa559496..bca155b750047 100644 --- a/pytorch_lightning/plugins/base_plugin.py +++ b/pytorch_lightning/plugins/base_plugin.py @@ -22,7 +22,8 @@ class Plugin(ABC): """Basic Plugin class to derive precision and training type plugins from.""" @abstractmethod - def connect(self, model: torch.nn.Module, *args: Sequence, **kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]: + def connect(self, model: torch.nn.Module, *args: Sequence, + **kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]: """Connects the plugin with the accelerator (and thereby with trainer and model). Will be called by the accelerator. 
""" diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index daba223169fc6..8cdaba833af85 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -18,11 +18,17 @@ from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin -from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException +if _NATIVE_AMP_AVAILABLE: + from torch.cuda.amp import autocast +else: + autocast = None + class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): + def __init__(self): self.backend = AMPType.NATIVE self.scaler = torch.cuda.amp.GradScaler() @@ -74,6 +80,6 @@ def backward( return closure_loss @contextmanager - def train_step_context(self) -> Generator[torch.cuda.amp.autocast, None, None]: + def train_step_context(self) -> Generator[autocast, None, None]: """Enable autocast context""" yield torch.cuda.amp.autocast() diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 031b588737614..3e74442e92277 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -37,7 +37,8 @@ def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Ten for p in group["params"]: yield p - def connect(self, model: torch.nn.Module, optimizers: Sequence, lr_schedulers: Sequence) -> Tuple[torch.nn.Module, Sequence, Sequence]: + def connect(self, model: torch.nn.Module, optimizers: Sequence, + lr_schedulers: Sequence) -> Tuple[torch.nn.Module, Sequence, Sequence]: """Connects this plugin to the accelerator and the training process""" return model, optimizers, lr_schedulers diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index ef8e1b8a95efe..b3b01fc720d2b 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -26,6 +26,7 @@ class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): """Mixed Precision for Sharded Training """ + def __init__(self): super().__init__() self.scaler = ShardedGradScaler() From d10fba8a377dc44872d481528c7615481ddc4442 Mon Sep 17 00:00:00 2001 From: Tarepan Date: Fri, 5 Feb 2021 08:29:37 +0900 Subject: [PATCH 30/30] Fix yapf --- .../trainer/connectors/checkpoint_connector.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 1c4ce2ce42c03..1144d3e342da2 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -111,10 +111,10 @@ def attempt_to_apply_checkpoint(self, model: LightningModule) -> bool: return restored def restore_states( - self, - model: LightningModule, - checkpoint_path: str, - on_gpu: bool, + self, + model: LightningModule, + checkpoint_path: str, + on_gpu: bool, ) -> Dict[str, Any]: """Restore all states from checkpoint in the specified path.