3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -216,6 +216,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270))


- The tuner now uses the checkpoint connector to copy and restore its state ([#11518](https://github.com/PyTorchLightning/pytorch-lightning/pull/11518))


- Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360))


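The checkpoint-connector entry above is what the rest of this diff implements: both tuner algorithms snapshot the trainer into a hidden temporary checkpoint, restore it on every rank through `trainer.checkpoint_connector.restore`, and clean up via `trainer.strategy.remove_checkpoint`, replacing the previous rank-zero-only restore and manual filesystem removal. A minimal usage sketch, assuming the Trainer API of this development branch; `ToyModel` is a hypothetical module defined only for illustration:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    """Tiny module exposing the two attributes the finders tune."""

    def __init__(self, batch_size: int = 2, learning_rate: float = 1e-3):
        super().__init__()
        self.batch_size = batch_size        # adjusted by auto_scale_batch_size
        self.learning_rate = learning_rate  # suggested by auto_lr_find
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        x, y = torch.randn(256, 32), torch.randint(0, 2, (256,))
        return DataLoader(TensorDataset(x, y), batch_size=self.batch_size)


trainer = pl.Trainer(auto_scale_batch_size="power", auto_lr_find=True, max_epochs=1)
# Each finder writes a hidden ".scale_batch_size_*" / ".lr_find_*" checkpoint,
# runs its short trials, then restores the trainer via the checkpoint connector.
trainer.tune(ToyModel())
```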
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -222,7 +222,7 @@ def restore_loops(self) -> None:

assert self.trainer.state.fn is not None
state_dict = self._loaded_checkpoint.get("loops")
if state_dict is not None and self.trainer.state.fn != TrainerFn.TUNING:
if state_dict is not None:
if self.trainer.state.fn == TrainerFn.FITTING:
self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
elif self.trainer.state.fn == TrainerFn.VALIDATING:
70 changes: 28 additions & 42 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -14,14 +14,13 @@
import logging
import os
import uuid
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
@@ -59,15 +58,17 @@ def scale_batch_size(
" Please disable the feature or incorporate the dataloader into the model."
)

# Arguments we adjust during the batch size finder, save for restoring
__scale_batch_dump_params(trainer)
# Save initial model, that is loaded after batch size is found
ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
trainer.fit_loop.current_epoch -= 1
trainer.fit_loop.global_step -= 1
trainer.save_checkpoint(ckpt_path)
trainer.fit_loop.current_epoch += 1
trainer.fit_loop.global_step += 1
params = __scale_batch_dump_params(trainer)

# Set to values that are required by the algorithm
__scale_batch_reset_params(trainer, model, steps_per_trial)

# Save initial model, that is loaded after batch size is found
save_path = os.path.join(trainer.default_root_dir, f"scale_batch_size_temp_model_{uuid.uuid4()}.ckpt")
trainer.save_checkpoint(str(save_path))
__scale_batch_reset_params(trainer, steps_per_trial)

if trainer.progress_bar_callback:
trainer.progress_bar_callback.disable()
@@ -85,59 +86,44 @@ def scale_batch_size(
log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")

# Restore initial state of model
if trainer.is_global_zero:
trainer.checkpoint_connector.restore(str(save_path))
fs = get_filesystem(str(save_path))
if fs.exists(save_path):
fs.rm(save_path)

# Finish by resetting variables so trainer is ready to fit model
__scale_batch_restore_params(trainer)
trainer.checkpoint_connector.restore(ckpt_path)
trainer.strategy.remove_checkpoint(ckpt_path)
__scale_batch_restore_params(trainer, params)

if trainer.progress_bar_callback:
trainer.progress_bar_callback.enable()

return new_size


def __scale_batch_dump_params(trainer: "pl.Trainer") -> None:
# Prevent going into infinite loop
trainer.__dumped_params = {
"auto_lr_find": trainer.auto_lr_find,
"current_epoch": trainer.current_epoch,
"global_step": trainer.global_step,
"max_steps": trainer.max_steps,
def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
return {
"max_steps": trainer.fit_loop.max_steps,
"logger": trainer.logger,
"callbacks": trainer.callbacks,
"checkpoint_callback": trainer.checkpoint_callback,
"auto_scale_batch_size": trainer.auto_scale_batch_size,
"auto_lr_find": trainer.auto_lr_find,
"limit_train_batches": trainer.limit_train_batches,
"model": trainer.model,
}


def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule", steps_per_trial: int) -> None:
def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None:
trainer.auto_scale_batch_size = None # prevent recursion
trainer.auto_lr_find = False # avoid lr find being called multiple times
trainer.fit_loop.current_epoch = 0
trainer.fit_loop.max_steps = steps_per_trial # take few steps
trainer.logger = DummyLogger() if trainer.logger is not None else None
trainer.callbacks = [] # not needed before full run
trainer.limit_train_batches = 1.0
trainer.optimizers, trainer.strategy.lr_schedulers = [], [] # required for saving
trainer.model = model # required for saving


def __scale_batch_restore_params(trainer: "pl.Trainer") -> None:
trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
trainer.logger = trainer.__dumped_params["logger"]
trainer.callbacks = trainer.__dumped_params["callbacks"]
trainer.auto_scale_batch_size = trainer.__dumped_params["auto_scale_batch_size"]
trainer.limit_train_batches = trainer.__dumped_params["limit_train_batches"]
trainer.model = trainer.__dumped_params["model"]
del trainer.__dumped_params


def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
trainer.auto_scale_batch_size = params["auto_scale_batch_size"]
trainer.auto_lr_find = params["auto_lr_find"]
trainer.fit_loop.max_steps = params["max_steps"]
trainer.logger = params["logger"]
trainer.callbacks = params["callbacks"]
trainer.limit_train_batches = params["limit_train_batches"]


def _run_power_scaling(
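Two patterns change in `scale_batch_size` above: the trainer attributes clobbered by the trial run are carried in a plain dict returned by `__scale_batch_dump_params` rather than stashed on a private `trainer.__dumped_params` attribute, and the weight/loop snapshot goes through `trainer.save_checkpoint` → `checkpoint_connector.restore` → `strategy.remove_checkpoint`. A condensed sketch of that save/mutate/restore flow, with the trial fit itself elided; `scale_with_rollback` and its two stashed attributes are illustrative, not the full set the finder saves:

```python
import os
import uuid
from typing import Any, Dict

import pytorch_lightning as pl


def scale_with_rollback(trainer: "pl.Trainer", steps_per_trial: int = 3) -> None:
    # 1. Snapshot: weights and loop counters go into a hidden checkpoint,
    #    mutated trainer attributes into a local dict.
    ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
    trainer.save_checkpoint(ckpt_path)
    params: Dict[str, Any] = {"max_steps": trainer.fit_loop.max_steps, "logger": trainer.logger}

    # 2. Mutate: keep each trial short and quiet.
    trainer.fit_loop.max_steps = steps_per_trial
    trainer.logger = None
    # ... run one power/binsearch trial here (see _run_power_scaling) ...

    # 3. Restore: roll back weights and loops on every rank, put the stashed
    #    attributes back, and delete the temporary checkpoint.
    trainer.checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)
    trainer.fit_loop.max_steps = params["max_steps"]
    trainer.logger = params["logger"]
```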
80 changes: 34 additions & 46 deletions pytorch_lightning/tuner/lr_finder.py
@@ -16,7 +16,7 @@
import os
import uuid
from functools import wraps
from typing import Optional, Sequence
from typing import Any, Dict, Optional, Sequence

import numpy as np
import torch
@@ -27,7 +27,6 @@
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.types import LRSchedulerConfig
@@ -203,36 +202,25 @@ def lr_find(
if update_attr:
lr_attr_name = _determine_lr_attr_name(trainer, model)

save_path = os.path.join(trainer.default_root_dir, f"lr_find_temp_model_{uuid.uuid4()}.ckpt")
# Save initial model, that is loaded after learning rate is found
ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
trainer.fit_loop.current_epoch -= 1
trainer.fit_loop.global_step -= 1
trainer.save_checkpoint(ckpt_path)
trainer.fit_loop.current_epoch += 1
trainer.fit_loop.global_step += 1
params = __lr_finder_dump_params(trainer)

__lr_finder_dump_params(trainer, model)

# Prevent going into infinite loop
trainer.auto_lr_find = False
# Set to values that are required by the algorithm
__lr_finder_reset_params(trainer, num_training, early_stop_threshold)

# Initialize lr finder object (stores results)
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

# Use special lr logger callback
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

# No logging
trainer.logger = DummyLogger() if trainer.logger is not None else None

# Max step set to number of iterations
trainer.fit_loop.max_steps = num_training

# Disable standard progress bar for fit
if trainer.progress_bar_callback:
trainer.progress_bar_callback.disable()

# Required for saving the model
trainer.optimizers, trainer.strategy.lr_schedulers = [], []
trainer.model = model

# Dump model checkpoint
trainer.save_checkpoint(str(save_path))

# Configure optimizer and scheduler
trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)

@@ -247,15 +235,11 @@ def lr_find(
lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose

# Reset model state
if trainer.is_global_zero:
trainer.checkpoint_connector.restore(str(save_path))
fs = get_filesystem(str(save_path))
if fs.exists(save_path):
fs.rm(save_path)
# Restore initial state of model
trainer.checkpoint_connector.restore(ckpt_path)
trainer.strategy.remove_checkpoint(ckpt_path)
__lr_finder_restore_params(trainer, params)

# Finish by resetting variables so trainer is ready to fit model
__lr_finder_restore_params(trainer, model)
if trainer.progress_bar_callback:
trainer.progress_bar_callback.enable()

@@ -270,27 +254,31 @@ def lr_find(
return lr_finder


def __lr_finder_dump_params(trainer, model):
# Prevent going into infinite loop
trainer.__dumped_params = {
def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
return {
"auto_lr_find": trainer.auto_lr_find,
"callbacks": trainer.callbacks,
"logger": trainer.logger,
"global_step": trainer.global_step,
"max_steps": trainer.max_steps,
"checkpoint_callback": trainer.checkpoint_callback,
"current_epoch": trainer.current_epoch,
"max_steps": trainer.fit_loop.max_steps,
}


def __lr_finder_restore_params(trainer, model):
trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
trainer.logger = trainer.__dumped_params["logger"]
trainer.callbacks = trainer.__dumped_params["callbacks"]
trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
del trainer.__dumped_params
def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
# avoid lr find being called multiple times
trainer.auto_lr_find = False
# Use special lr logger callback
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
# No logging
trainer.logger = DummyLogger() if trainer.logger is not None else None
# Max step set to number of iterations
trainer.fit_loop.max_steps = num_training


def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
trainer.auto_lr_find = params["auto_lr_find"]
trainer.callbacks = params["callbacks"]
trainer.logger = params["logger"]
trainer.fit_loop.max_steps = params["max_steps"]


class _LRCallback(Callback):
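`lr_find` receives the same restructuring, with the setup moved into `__lr_finder_reset_params` so that dump/reset/restore mirror the batch-size finder. The finder can also be driven directly through the tuner instead of via `auto_lr_find`; a short sketch, reusing the illustrative `ToyModel` from the sketch after the CHANGELOG hunk and assuming matplotlib is available for the plot call:

```python
import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)
model = ToyModel()  # illustrative module from the earlier sketch

# Runs up to `num_training` short steps, then restores the trainer from the
# temporary ".lr_find_*" checkpoint through the checkpoint connector.
lr_finder = trainer.tuner.lr_find(model, min_lr=1e-6, max_lr=1.0, num_training=100)

suggestion = lr_finder.suggestion()  # learning rate at the steepest loss descent (may be None)
fig = lr_finder.plot(suggest=True)   # loss vs. learning rate curve
if suggestion is not None:
    model.learning_rate = suggestion
```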
2 changes: 2 additions & 0 deletions pytorch_lightning/tuner/tuning.py
@@ -41,6 +41,8 @@ def _tune(
# return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
result = {}

self.trainer.strategy.connect(model)

# Run auto batch size scaling
if self.trainer.auto_scale_batch_size:
if isinstance(self.trainer.auto_scale_batch_size, str):
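The single `strategy.connect(model)` call added to `Tuner._tune` is what lets both finders drop their ad-hoc `trainer.model = model` and empty-optimizer bookkeeping: the module has to be attached to the strategy before the finders call `trainer.save_checkpoint`. A minimal sketch of the effect, assuming the Strategy API on this branch; `TinyModule` is defined only for illustration:

```python
import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)


trainer = pl.Trainer(max_epochs=1)
model = TinyModule()

# What Tuner._tune now does up front: attach the module to the strategy so the
# finders can checkpoint it without touching trainer.model or the optimizers.
trainer.strategy.connect(model)
assert trainer.lightning_module is model
```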
2 changes: 1 addition & 1 deletion tests/tuner/test_lr_finder.py
@@ -65,7 +65,7 @@ def test_model_reset_correctly(tmpdir):
torch.eq(before_state_dict[key], after_state_dict[key])
), "Model was not reset correctly after learning rate finder"

assert not any(f for f in os.listdir(tmpdir) if f.startswith("lr_find_temp_model"))
assert not any(f for f in os.listdir(tmpdir) if f.startswith(".lr_find"))


def test_trainer_reset_correctly(tmpdir):
2 changes: 1 addition & 1 deletion tests/tuner/test_scale_batch_size.py
@@ -92,7 +92,7 @@ def test_model_reset_correctly(tmpdir):
torch.eq(before_state_dict[key], after_state_dict[key])
), "Model was not reset correctly after scaling batch size"

assert not any(f for f in os.listdir(tmpdir) if f.startswith("scale_batch_size_temp_model"))
assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size"))


def test_trainer_reset_correctly(tmpdir):