From 3b3cd7eec2dfda7941d45746057a5e9e630bd9b0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 18 Jan 2022 00:24:45 +0100 Subject: [PATCH 01/11] Reset the tuner state with the checkpoint connector --- .../connectors/checkpoint_connector.py | 4 +- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/tuner/batch_size_scaling.py | 75 ++++++----------- pytorch_lightning/tuner/lr_finder.py | 83 ++++++++----------- pytorch_lightning/tuner/tuning.py | 2 + tests/tuner/test_lr_finder.py | 1 + 6 files changed, 65 insertions(+), 102 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 125548471b529..50bf730dd180b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -222,8 +222,8 @@ def restore_loops(self) -> None: assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") - if state_dict is not None and self.trainer.state.fn != TrainerFn.TUNING: - if self.trainer.state.fn == TrainerFn.FITTING: + if state_dict is not None: + if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) elif self.trainer.state.fn == TrainerFn.VALIDATING: self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 39cadb7f9e7ef..4830b3a33c3c1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2075,7 +2075,7 @@ def model(self, model: torch.nn.Module) -> None: model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending on the backend. """ - self.strategy.model = model + self.strategy.connect(model) """ General properties diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 84467310568f7..aed098fecb683 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License import logging -import os -import uuid -from typing import Optional, Tuple +from copy import deepcopy +from typing import Any, Dict, Optional, Tuple from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error @@ -59,15 +56,15 @@ def scale_batch_size( " Please disable the feature or incorporate the dataloader into the model." 
) - # Arguments we adjust during the batch size finder, save for restoring - __scale_batch_dump_params(trainer) + trainer.fit_loop.current_epoch -= 1 + trainer.fit_loop.global_step -= 1 + state_dict = deepcopy(trainer.checkpoint_connector.dump_checkpoint()) + trainer.fit_loop.current_epoch += 1 + trainer.fit_loop.global_step += 1 + params = __scale_batch_dump_params(trainer) # Set to values that are required by the algorithm - __scale_batch_reset_params(trainer, model, steps_per_trial) - - # Save initial model, that is loaded after batch size is found - save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt") - trainer.save_checkpoint(str(save_path)) + __scale_batch_reset_params(trainer, steps_per_trial) if trainer.progress_bar_callback: trainer.progress_bar_callback.disable() @@ -85,59 +82,35 @@ def scale_batch_size( log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}") # Restore initial state of model - if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path)) - fs = get_filesystem(str(save_path)) - if fs.exists(save_path): - fs.rm(save_path) - - # Finish by resetting variables so trainer is ready to fit model - __scale_batch_restore_params(trainer) - if trainer.progress_bar_callback: - trainer.progress_bar_callback.enable() + trainer.checkpoint_connector._loaded_checkpoint = state_dict + trainer.checkpoint_connector.restore(None) + __scale_batch_restore_params(trainer, params) return new_size -def __scale_batch_dump_params(trainer: "pl.Trainer") -> None: - # Prevent going into infinite loop - trainer.__dumped_params = { - "auto_lr_find": trainer.auto_lr_find, - "current_epoch": trainer.current_epoch, - "global_step": trainer.global_step, - "max_steps": trainer.max_steps, - "logger": trainer.logger, - "callbacks": trainer.callbacks, - "checkpoint_callback": trainer.checkpoint_callback, +def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: + return { "auto_scale_batch_size": trainer.auto_scale_batch_size, + "auto_lr_find": trainer.auto_lr_find, + "max_steps": trainer.fit_loop.max_steps, "limit_train_batches": trainer.limit_train_batches, - "model": trainer.model, } -def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule", steps_per_trial: int) -> None: +def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None: trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times trainer.fit_loop.current_epoch = 0 trainer.fit_loop.max_steps = steps_per_trial # take few steps - trainer.logger = DummyLogger() if trainer.logger is not None else None - trainer.callbacks = [] # not needed before full run trainer.limit_train_batches = 1.0 - trainer.optimizers, trainer.lr_schedulers = [], [] # required for saving - trainer.model = model # required for saving - - -def __scale_batch_restore_params(trainer: "pl.Trainer") -> None: - trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"] - trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"] - trainer.fit_loop.global_step = trainer.__dumped_params["global_step"] - trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"] - trainer.logger = trainer.__dumped_params["logger"] - trainer.callbacks = trainer.__dumped_params["callbacks"] - trainer.auto_scale_batch_size = trainer.__dumped_params["auto_scale_batch_size"] - trainer.limit_train_batches = 
trainer.__dumped_params["limit_train_batches"] - trainer.model = trainer.__dumped_params["model"] - del trainer.__dumped_params + + +def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: + trainer.auto_scale_batch_size = params["auto_scale_batch_size"] + trainer.auto_lr_find = params["auto_lr_find"] + trainer.fit_loop.max_steps = params["max_steps"] + trainer.limit_train_batches = params["limit_train_batches"] def _run_power_scaling( diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 7bf1bcf34ed96..265afb825e60c 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -13,10 +13,9 @@ # limitations under the License. import importlib import logging -import os -import uuid +from copy import deepcopy from functools import wraps -from typing import Optional, Sequence +from typing import Any, Dict, Optional, Sequence import numpy as np import torch @@ -31,7 +30,6 @@ ) from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr @@ -208,36 +206,23 @@ def lr_find( if update_attr: lr_attr_name = _determine_lr_attr_name(trainer, model) - save_path = os.path.join(trainer.default_root_dir, f"lr_find_temp_model_{uuid.uuid4()}.ckpt") + trainer.fit_loop.current_epoch -= 1 + trainer.fit_loop.global_step -= 1 + state_dict = deepcopy(trainer.checkpoint_connector.dump_checkpoint()) + trainer.fit_loop.current_epoch += 1 + trainer.fit_loop.global_step += 1 + params = __lr_finder_dump_params(trainer) - __lr_finder_dump_params(trainer, model) - - # Prevent going into infinite loop - trainer.auto_lr_find = False + # Set to values that are required by the algorithm + __lr_finder_reset_params(trainer, num_training, early_stop_threshold) # Initialize lr finder object (stores results) lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) - # Use special lr logger callback - trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] - - # No logging - trainer.logger = DummyLogger() if trainer.logger is not None else None - - # Max step set to number of iterations - trainer.fit_loop.max_steps = num_training - # Disable standard progress bar for fit if trainer.progress_bar_callback: trainer.progress_bar_callback.disable() - # Required for saving the model - trainer.optimizers, trainer.lr_schedulers = [], [] - trainer.model = model - - # Dump model checkpoint - trainer.save_checkpoint(str(save_path)) - # Configure optimizer and scheduler trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model) @@ -252,15 +237,11 @@ def lr_find( lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses}) lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose - # Reset model state - if trainer.is_global_zero: - trainer.checkpoint_connector.restore(str(save_path)) - fs = get_filesystem(str(save_path)) - if fs.exists(save_path): - fs.rm(save_path) + # Restore initial state of model + trainer.checkpoint_connector._loaded_checkpoint = state_dict + trainer.checkpoint_connector.restore(None) + __lr_finder_restore_params(trainer, params) - # Finish by resetting variables so trainer is ready to fit model - 
__lr_finder_restore_params(trainer, model) if trainer.progress_bar_callback: trainer.progress_bar_callback.enable() @@ -275,27 +256,33 @@ def lr_find( return lr_finder -def __lr_finder_dump_params(trainer, model): - # Prevent going into infinite loop - trainer.__dumped_params = { +def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: + return { "auto_lr_find": trainer.auto_lr_find, "callbacks": trainer.callbacks, "logger": trainer.logger, - "global_step": trainer.global_step, - "max_steps": trainer.max_steps, - "checkpoint_callback": trainer.checkpoint_callback, - "current_epoch": trainer.current_epoch, + "max_steps": trainer.fit_loop.max_steps, } -def __lr_finder_restore_params(trainer, model): - trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"] - trainer.logger = trainer.__dumped_params["logger"] - trainer.callbacks = trainer.__dumped_params["callbacks"] - trainer.fit_loop.global_step = trainer.__dumped_params["global_step"] - trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"] - trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"] - del trainer.__dumped_params +def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None: + # avoid lr find being called multiple times + trainer.auto_lr_find = False + # Use special lr logger callback + trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] + # No logging + trainer.logger = DummyLogger() if trainer.logger is not None else None + # Max step set to number of iterations + trainer.fit_loop.max_steps = num_training + # Required for saving the model + trainer.optimizers, trainer.lr_schedulers = [], [] + + +def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: + trainer.auto_lr_find = params["auto_lr_find"] + trainer.callbacks = params["callbacks"] + trainer.logger = params["logger"] + trainer.fit_loop.max_steps = params["max_steps"] class _LRCallback(Callback): diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 6a3a6b66d0644..f64183a92bc1c 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -41,6 +41,8 @@ def _tune( # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added result = {} + self.trainer.strategy.connect(model) + # Run auto batch size scaling if self.trainer.auto_scale_batch_size: if isinstance(self.trainer.auto_scale_batch_size, str): diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index cad4bb2a786d2..b866f877d1c19 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -343,6 +343,7 @@ def training_step_end(self, outputs): trainer.tuner.lr_find(model=model, num_training=num_training) +@pytest.mark.xfail def test_multiple_lr_find_calls_gives_same_results(tmpdir): """Tests that lr_finder gives same results if called multiple times.""" seed_everything(1) From 28515251a7265146e5fc96d07b3ddbfbc5f1cdcc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 18 Jan 2022 00:29:48 +0100 Subject: [PATCH 02/11] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 63ce5cf8a6574..943ec1e765996 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270)) +- The tuner now uses the checkpoint connector to copy and restore its state ([#11518](https://github.com/PyTorchLightning/pytorch-lightning/pull/11518)) + + - Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360)) From 33a472d3a97cb41271a3d98a70bbf49b3318d153 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 18 Jan 2022 17:27:22 +0100 Subject: [PATCH 03/11] Save to file --- pytorch_lightning/tuner/batch_size_scaling.py | 9 +++++---- pytorch_lightning/tuner/lr_finder.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index aed098fecb683..fcc987ca02d78 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License import logging -from copy import deepcopy +import os +import uuid from typing import Any, Dict, Optional, Tuple from torch.utils.data import DataLoader @@ -56,9 +57,10 @@ def scale_batch_size( " Please disable the feature or incorporate the dataloader into the model." ) + ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt") trainer.fit_loop.current_epoch -= 1 trainer.fit_loop.global_step -= 1 - state_dict = deepcopy(trainer.checkpoint_connector.dump_checkpoint()) + trainer.save_checkpoint(ckpt_path) trainer.fit_loop.current_epoch += 1 trainer.fit_loop.global_step += 1 params = __scale_batch_dump_params(trainer) @@ -82,8 +84,7 @@ def scale_batch_size( log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}") # Restore initial state of model - trainer.checkpoint_connector._loaded_checkpoint = state_dict - trainer.checkpoint_connector.restore(None) + trainer.checkpoint_connector.restore(ckpt_path) __scale_batch_restore_params(trainer, params) return new_size diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 265afb825e60c..11476b07b8314 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -13,7 +13,8 @@ # limitations under the License. 
import importlib import logging -from copy import deepcopy +import os +import uuid from functools import wraps from typing import Any, Dict, Optional, Sequence @@ -206,9 +207,10 @@ def lr_find( if update_attr: lr_attr_name = _determine_lr_attr_name(trainer, model) + ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt") trainer.fit_loop.current_epoch -= 1 trainer.fit_loop.global_step -= 1 - state_dict = deepcopy(trainer.checkpoint_connector.dump_checkpoint()) + trainer.save_checkpoint(ckpt_path) trainer.fit_loop.current_epoch += 1 trainer.fit_loop.global_step += 1 params = __lr_finder_dump_params(trainer) @@ -238,8 +240,7 @@ def lr_find( lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose # Restore initial state of model - trainer.checkpoint_connector._loaded_checkpoint = state_dict - trainer.checkpoint_connector.restore(None) + trainer.checkpoint_connector.restore(ckpt_path) __lr_finder_restore_params(trainer, params) if trainer.progress_bar_callback: From f84ed775b9ca6de51498429e8930399a8583be60 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 18 Jan 2022 17:30:03 +0100 Subject: [PATCH 04/11] Self review --- pytorch_lightning/tuner/batch_size_scaling.py | 4 ++++ pytorch_lightning/tuner/lr_finder.py | 1 + 2 files changed, 5 insertions(+) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index fcc987ca02d78..f431f6e50d203 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -57,6 +57,7 @@ def scale_batch_size( " Please disable the feature or incorporate the dataloader into the model." ) + # Save initial model, that is loaded after batch size is found ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt") trainer.fit_loop.current_epoch -= 1 trainer.fit_loop.global_step -= 1 @@ -87,6 +88,9 @@ def scale_batch_size( trainer.checkpoint_connector.restore(ckpt_path) __scale_batch_restore_params(trainer, params) + if trainer.progress_bar_callback: + trainer.progress_bar_callback.enable() + return new_size diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 11476b07b8314..f9847a0a70bc4 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -207,6 +207,7 @@ def lr_find( if update_attr: lr_attr_name = _determine_lr_attr_name(trainer, model) + # Save initial model, that is loaded after learning rate is found ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt") trainer.fit_loop.current_epoch -= 1 trainer.fit_loop.global_step -= 1 From e9d95a64a6c320bf82e11b653c381dde7ca591b7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 18 Jan 2022 17:35:11 +0100 Subject: [PATCH 05/11] Forgot to remove --- pytorch_lightning/tuner/batch_size_scaling.py | 1 + pytorch_lightning/tuner/lr_finder.py | 1 + tests/tuner/test_lr_finder.py | 2 +- tests/tuner/test_scale_batch_size.py | 2 +- 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index f431f6e50d203..e1da86705c118 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -86,6 +86,7 @@ def scale_batch_size( # Restore initial state of model trainer.checkpoint_connector.restore(ckpt_path) + trainer.strategy.remove_checkpoint(ckpt_path) 
__scale_batch_restore_params(trainer, params) if trainer.progress_bar_callback: diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index f9847a0a70bc4..36e89d21cba57 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -242,6 +242,7 @@ def lr_find( # Restore initial state of model trainer.checkpoint_connector.restore(ckpt_path) + trainer.strategy.remove_checkpoint(ckpt_path) __lr_finder_restore_params(trainer, params) if trainer.progress_bar_callback: diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index b866f877d1c19..146af4d66172d 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -65,7 +65,7 @@ def test_model_reset_correctly(tmpdir): torch.eq(before_state_dict[key], after_state_dict[key]) ), "Model was not reset correctly after learning rate finder" - assert not any(f for f in os.listdir(tmpdir) if f.startswith("lr_find_temp_model")) + assert not any(f for f in os.listdir(tmpdir) if f.startswith(".lr_find")) def test_trainer_reset_correctly(tmpdir): diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index c1d1de052d9c6..31d3dd3dd3c19 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -92,7 +92,7 @@ def test_model_reset_correctly(tmpdir): torch.eq(before_state_dict[key], after_state_dict[key]) ), "Model was not reset correctly after scaling batch size" - assert not any(f for f in os.listdir(tmpdir) if f.startswith("scale_batch_size_temp_model")) + assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size")) def test_trainer_reset_correctly(tmpdir): From 057edb0d398b16c265247644b65a696beb63a6d5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 18 Jan 2022 17:38:30 +0100 Subject: [PATCH 06/11] Self review --- pytorch_lightning/tuner/batch_size_scaling.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index e1da86705c118..c15a010ddcaa8 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -19,6 +19,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl +from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -97,6 +98,8 @@ def scale_batch_size( def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: return { + "logger": trainer.logger, + "callbacks": trainer.callbacks, "auto_scale_batch_size": trainer.auto_scale_batch_size, "auto_lr_find": trainer.auto_lr_find, "max_steps": trainer.fit_loop.max_steps, @@ -108,6 +111,8 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times trainer.fit_loop.current_epoch = 0 + trainer.logger = DummyLogger() if trainer.logger is not None else None + trainer.callbacks = [] # not needed before full run trainer.fit_loop.max_steps = steps_per_trial # take few steps trainer.limit_train_batches = 1.0 @@ -116,6 +121,8 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) trainer.auto_scale_batch_size = params["auto_scale_batch_size"] 
trainer.auto_lr_find = params["auto_lr_find"] trainer.fit_loop.max_steps = params["max_steps"] + trainer.logger = params["logger"] + trainer.callbacks = params["callbacks"] trainer.limit_train_batches = params["limit_train_batches"] From a63cdff5cf9f2fa76682f374919d85ebb772075e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 18 Jan 2022 17:40:21 +0100 Subject: [PATCH 07/11] Self review --- pytorch_lightning/tuner/batch_size_scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index c15a010ddcaa8..54fb227f06014 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -100,9 +100,9 @@ def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: return { "logger": trainer.logger, "callbacks": trainer.callbacks, + "max_steps": trainer.fit_loop.max_steps, "auto_scale_batch_size": trainer.auto_scale_batch_size, "auto_lr_find": trainer.auto_lr_find, - "max_steps": trainer.fit_loop.max_steps, "limit_train_batches": trainer.limit_train_batches, } @@ -111,9 +111,9 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times trainer.fit_loop.current_epoch = 0 + trainer.fit_loop.max_steps = steps_per_trial # take few steps trainer.logger = DummyLogger() if trainer.logger is not None else None trainer.callbacks = [] # not needed before full run - trainer.fit_loop.max_steps = steps_per_trial # take few steps trainer.limit_train_batches = 1.0 From ca830ac628cc85a6d89809260579467e25fc2d15 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 18 Jan 2022 20:26:39 +0100 Subject: [PATCH 08/11] Keep order --- pytorch_lightning/tuner/batch_size_scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 54fb227f06014..788395f676027 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -98,9 +98,9 @@ def scale_batch_size( def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: return { + "max_steps": trainer.fit_loop.max_steps, "logger": trainer.logger, "callbacks": trainer.callbacks, - "max_steps": trainer.fit_loop.max_steps, "auto_scale_batch_size": trainer.auto_scale_batch_size, "auto_lr_find": trainer.auto_lr_find, "limit_train_batches": trainer.limit_train_batches, From 985c9bdb9a505024bf34ac45f39de41b382a7ed9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 18 Jan 2022 20:33:22 +0100 Subject: [PATCH 09/11] Fix xfail --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- pytorch_lightning/tuner/lr_finder.py | 2 +- tests/tuner/test_lr_finder.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index b551e1cf1a014..5c437bfd889b2 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -223,7 +223,7 @@ def restore_loops(self) -> None: assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") if state_dict is not None: - if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if 
self.trainer.state.fn == TrainerFn.FITTING: self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) elif self.trainer.state.fn == TrainerFn.VALIDATING: self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index b540e2142e314..d87926add3cb0 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -273,7 +273,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto # Max step set to number of iterations trainer.fit_loop.max_steps = num_training # Required for saving the model - trainer.optimizers, trainer.lr_schedulers = [], [] + trainer.optimizers, trainer.strategy.lr_schedulers = [], [] def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index 146af4d66172d..62d729d3d414d 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -343,7 +343,6 @@ def training_step_end(self, outputs): trainer.tuner.lr_find(model=model, num_training=num_training) -@pytest.mark.xfail def test_multiple_lr_find_calls_gives_same_results(tmpdir): """Tests that lr_finder gives same results if called multiple times.""" seed_everything(1) From a7e5a88697e85ca12636273c4dcbdb2441be2a59 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 19 Jan 2022 17:33:17 +0100 Subject: [PATCH 10/11] Remove deletion of optimizers and lr schedulers --- pytorch_lightning/tuner/lr_finder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index d87926add3cb0..0be49535e0513 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -272,8 +272,6 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto trainer.logger = DummyLogger() if trainer.logger is not None else None # Max step set to number of iterations trainer.fit_loop.max_steps = num_training - # Required for saving the model - trainer.optimizers, trainer.strategy.lr_schedulers = [], [] def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: From b2e6d13792b196582c665f07994f621f0433dcec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 20 Jan 2022 19:19:32 +0100 Subject: [PATCH 11/11] Update pytorch_lightning/trainer/trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1687504d17967..4e19f16b29c6d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2074,7 +2074,7 @@ def model(self, model: torch.nn.Module) -> None: model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending on the backend. """ - self.strategy.connect(model) + self.strategy.model = model """ General properties
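
Note (illustration, not part of the patches): the series replaces the tuners' ad-hoc `trainer.__dumped_params` bookkeeping with the checkpoint connector. Below is a minimal sketch of the save/restore flow that patches 03-05 converge on. The helper name `_run_with_restore` and the `run_trial` callable are hypothetical stand-ins for the actual power-scaling and LR-range-search loops; the trainer calls themselves (`save_checkpoint`, `checkpoint_connector.restore`, `strategy.remove_checkpoint`, and the fit-loop counter adjustment) come straight from the diffs above.

    import os
    import uuid
    from typing import Callable

    import pytorch_lightning as pl


    def _run_with_restore(trainer: "pl.Trainer", run_trial: Callable[["pl.Trainer"], None]) -> None:
        # Save the initial model/loop state to a hidden temporary checkpoint.
        # The counters are decremented before saving because the connector
        # dumps `current_epoch + 1` / `global_step + 1` (it assumes saving
        # happens at the end of a step); without the adjustment, restoring
        # would advance the fit loop by one epoch and one step.
        ckpt_path = os.path.join(trainer.default_root_dir, f".tuner_{uuid.uuid4()}.ckpt")
        trainer.fit_loop.current_epoch -= 1
        trainer.fit_loop.global_step -= 1
        trainer.save_checkpoint(ckpt_path)
        trainer.fit_loop.current_epoch += 1
        trainer.fit_loop.global_step += 1

        run_trial(trainer)  # e.g. run the batch size trials or the LR search

        # Restore the initial state and delete the temporary file so the
        # trainer is ready for the subsequent full `fit()` run.
        trainer.checkpoint_connector.restore(ckpt_path)
        trainer.strategy.remove_checkpoint(ckpt_path)

Patch 09 then reverts the `TrainerFn.TUNING` branch in `restore_loops` (so loop progress is not reloaded while tuning) and resets the schedulers through `trainer.strategy.lr_schedulers`; with those fixes the previously `xfail`ed `test_multiple_lr_find_calls_gives_same_results` passes again and the marker is removed.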
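
The guarantee the touched tests assert is that tuning leaves the model bit-identical and cleans up its hidden checkpoint. A hedged usage sketch follows, mirroring `test_model_reset_correctly`; `BoringModel` is the Lightning test helper, and its import path here is assumed, so treat it as illustrative.

    import os
    from copy import deepcopy

    import torch
    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel  # import path assumed

    model = BoringModel()
    trainer = Trainer(default_root_dir=".", max_epochs=1)

    before = deepcopy(model.state_dict())
    trainer.tuner.lr_find(model=model, num_training=5)
    after = model.state_dict()

    # Weights are untouched and the temporary `.lr_find_*.ckpt` was removed.
    assert all(torch.equal(before[k], after[k]) for k in before)
    assert not any(f.startswith(".lr_find") for f in os.listdir(trainer.default_root_dir))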