Merged
1 change: 0 additions & 1 deletion pyproject.toml
@@ -88,7 +88,6 @@ module = [
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.tuner.tuning",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
     "pytorch_lightning.utilities.distributed",
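(Note: the module = [...] list presumably sits under a [[tool.mypy.overrides]] block that suppresses type errors for the named modules; dropping "pytorch_lightning.tuner.tuning" from it means mypy now checks that module in full, which the typed _TunerResult introduced below makes possible.)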
5 changes: 2 additions & 3 deletions src/pytorch_lightning/trainer/trainer.py
@@ -85,8 +85,7 @@
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.tuner.lr_finder import _LRFinder
-from pytorch_lightning.tuner.tuning import Tuner
+from pytorch_lightning.tuner.tuning import _TunerResult, Tuner
 from pytorch_lightning.utilities import (
     _HPU_AVAILABLE,
     _IPU_AVAILABLE,
@@ -1015,7 +1014,7 @@ def tune(
         datamodule: Optional[LightningDataModule] = None,
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
+    ) -> _TunerResult:
         r"""
         Runs routines to tune hyperparameters before training.

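For context, a minimal sketch of how a caller sees the new return type. This is not part of the PR: LitModel is a hypothetical LightningModule assumed to define the lr and batch_size attributes the tuning routines look for, and the Trainer flags shown are the standard auto-tuning ones from this release line.

from pytorch_lightning import Trainer

trainer = Trainer(auto_lr_find=True, auto_scale_batch_size=True)
result = trainer.tune(LitModel())  # LitModel: hypothetical LightningModule

# `result` is a `_TunerResult`, so a type checker knows the value type per key:
# `result["scale_batch_size"]` is Optional[int] and `result["lr_find"]` is
# Optional[_LRFinder]. Both keys are NotRequired, so guard before use.
lr_finder = result.get("lr_find")
if lr_finder is not None:
    print(lr_finder.suggestion())  # suggested initial learning rate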
11 changes: 9 additions & 2 deletions src/pytorch_lightning/tuner/tuning.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 from typing import Any, Dict, Optional, Union

+from typing_extensions import NotRequired, TypedDict
+
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.states import TrainerStatus
 from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
@@ -21,6 +23,11 @@
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


+class _TunerResult(TypedDict):
+    lr_find: NotRequired[Optional[_LRFinder]]
+    scale_batch_size: NotRequired[Optional[int]]
+
+
 class Tuner:
     """Tuner class to tune your model."""

@@ -36,11 +43,11 @@ def _tune(
         model: "pl.LightningModule",
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
+    ) -> _TunerResult:
         scale_batch_size_kwargs = scale_batch_size_kwargs or {}
         lr_find_kwargs = lr_find_kwargs or {}
         # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
-        result = {}
+        result = _TunerResult()

         self.trainer.strategy.connect(model)

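As a standalone illustration of the pattern (not part of the PR), here is a minimal runnable sketch of how a TypedDict with NotRequired keys behaves: at runtime it is a plain dict, while a type checker enforces the key set and the value type per key.

from typing import Optional

from typing_extensions import NotRequired, TypedDict


class Result(TypedDict):  # stand-in for _TunerResult
    lr_find: NotRequired[Optional[str]]  # str stands in for _LRFinder here
    scale_batch_size: NotRequired[Optional[int]]


result = Result()  # at runtime this constructs a plain (empty) dict
result["scale_batch_size"] = 64
assert result == {"scale_batch_size": 64}
assert type(result) is dict

# A type checker (e.g. mypy or pyright) would reject both of these:
# result["batch_size"] = 64          # unknown key
# result["scale_batch_size"] = "64"  # wrong value type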