diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b024ab46cf6d..1eaf950adf070 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -133,6 +133,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `BasePredictionWriter` callback to implement prediction saving ([#7127](https://github.com/PyTorchLightning/pytorch-lightning/pull/7127)) +- Added `trainer.tune(scale_batch_size_kwargs, lr_find_kwargs)` arguments to configure the tuning algorithms ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258)) + + - Added `tpu_distributed` check for TPU Spawn barrier ([#7241](https://github.com/PyTorchLightning/pytorch-lightning/pull/7241)) @@ -178,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937)) +- `trainer.tune()` now returns the tuning result ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258)) + + - `LightningModule.from_datasets()` now accepts `IterableDataset` instances as training datasets. ([#7503](https://github.com/PyTorchLightning/pytorch-lightning/pull/7503)) @@ -325,6 +331,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506)) +- Fixed `trainer.tuner.{lr_find,scale_batch_size}` not setting the `Trainer` state properly ([#7258](https://github.com/PyTorchLightning/pytorch-lightning/pull/7258)) + + - Fixed bug where `BaseFinetuning.flatten_modules()` was duplicating leaf node parameters ([#6879](https://github.com/PyTorchLightning/pytorch-lightning/pull/6879)) diff --git a/docs/source/advanced/lr_finder.rst b/docs/source/advanced/lr_finder.rst index 9a0749b36ad4a..fe2c82c661872 100644 --- a/docs/source/advanced/lr_finder.rst +++ b/docs/source/advanced/lr_finder.rst @@ -73,9 +73,9 @@ If your model is using an arbitrary value instead of ``self.lr`` or ``self.learn trainer.tune(model) -If you want to inspect the results of the learning rate finder or just play around -with the parameters of the algorithm, this can be done by invoking the ``lr_find`` -method of the trainer. A typical example of this would look like +You can also inspect the results of the learning rate finder or just play around +with the parameters of the algorithm. This can be done by invoking the +:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method. A typical example of this would look like: .. code-block:: python diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index c3b232b41c13c..dd16f7c914107 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -112,7 +112,7 @@ search for batch sizes larger than the size of the training dataset. to `.fit()`. The scaling algorithm has a number of parameters that the user can control by -invoking the trainer method `.scale_batch_size` themself (see description below). +invoking the :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size` method: .. 
code-block:: python @@ -123,7 +123,7 @@ invoking the trainer method `.scale_batch_size` themself (see description below) # Invoke method new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here) - # Override old batch size + # Override old batch size (this is done automatically) model.hparams.batch_size = new_batch_size # Fit as normal @@ -142,10 +142,6 @@ The algorithm in short works by: 3. The found batch size is saved to either `model.batch_size` or `model.hparams.batch_size` 4. Restore the initial state of model and trainer -.. autoclass:: pytorch_lightning.tuner.tuning.Tuner - :noindex: - :members: scale_batch_size - .. warning:: Batch size finder is not supported for DDP yet, it is coming soon. diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index c954db735c282..642b11b5bdad4 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -153,14 +153,14 @@ Trainer API Tuner API --------- -.. currentmodule:: pytorch_lightning.tuner +.. currentmodule:: pytorch_lightning.tuner.tuning .. autosummary:: :toctree: api :nosignatures: + :template: classtemplate.rst - batch_size_scaling - lr_finder + Tuner Utilities API ------------- diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index a1f1c02ef498d..9550ceae4a9cc 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -319,7 +319,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): Args: args: The parser or namespace to take arguments from. Only known arguments will be - parsed and passed to the :class:`LightningDataModule`. + parsed and passed to the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. **kwargs: Additional keyword arguments that may override ones in the parser or namespace. These must be valid DataModule arguments. 
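The changelog and docs changes above introduce the `scale_batch_size_kwargs` / `lr_find_kwargs` arguments and the dict returned by `trainer.tune()`. A minimal sketch of how that interface can be driven, consistent with the updated test in `tests/tuner/test_lr_finder.py` later in this diff; `LitModel` and the chosen argument values are placeholders, not part of this change:

# Sketch (assumed user model): exercising the new ``trainer.tune`` signature.
# ``LitModel`` and its hyperparameters are hypothetical; the returned dict keys
# ('scale_batch_size', 'lr_find') follow the Tuner changes in this diff.
import pytorch_lightning as pl

model = LitModel(batch_size=2, learning_rate=1e-3)   # hypothetical LightningModule
trainer = pl.Trainer(auto_scale_batch_size=True, auto_lr_find=True)

result = trainer.tune(
    model,
    scale_batch_size_kwargs={'mode': 'binsearch', 'steps_per_trial': 3, 'max_trials': 10},
    lr_find_kwargs={'min_lr': 1e-6, 'max_lr': 1.0, 'num_training': 100},
)

new_batch_size = result['scale_batch_size']   # int, or None if skipped
lr_finder = result['lr_find']                 # _LRFinder, or None if skipped
if lr_finder is not None:
    model.hparams.learning_rate = lr_finder.suggestion()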
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 476bd3fc14da7..9fb531f8eb67c 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -17,7 +17,6 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.trainer.supporters import prefetch_iterator from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -68,24 +67,26 @@ def can_prepare_data(self): else: return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data - def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): - # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None - - self.__enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule) - + def attach_data( + self, + model: 'pl.LightningModule', + train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional['pl.LightningDataModule'] = None + ) -> None: # set up the passed in dataloaders (if needed) - self.attach_dataloaders(model, train_dataloader, val_dataloaders) - self.attach_datamodule(model, datamodule) - - def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders - if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: - raise MisconfigurationException( - 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' - ) + self.attach_dataloaders( + model, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + test_dataloaders=test_dataloaders, + predict_dataloaders=predict_dataloaders, + ) + self.attach_datamodule(model, datamodule=datamodule) + # set local properties on the model + self.trainer.model_connector.copy_trainer_model_properties(model) def attach_dataloaders( self, diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 113cbab12fa14..b98c1c0c551c2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -56,8 +56,9 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin +from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach @@ -409,21 +410,15 @@ def __init__( # Callback system self.on_init_end() - def _run( - self, - 
model: LightningModule, - train_dataloader: Any = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - datamodule: Optional[LightningDataModule] = None, - ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: - # set local properties on the model - self.model_connector.copy_trainer_model_properties(model) + def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + # clean hparams + if hasattr(model, "hparams"): + parsing.clean_namespace(model.hparams) - # ---------------------------- - # LINK DATA - # ---------------------------- - # setup data, etc... - self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) + self.config_validator.verify_loop_configurations(model) + + # attach model log function to callback + self.callback_connector.attach_model_logging_functions(model) # hook self.data_connector.prepare_data(model) @@ -848,14 +843,29 @@ def fit( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped - datamodule: A instance of :class:`LightningDataModule`. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ Trainer._log_api_event("fit") self.state = TrainerState.FITTING self.training = True - self._run(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule) + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + raise MisconfigurationException( + 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`' + ) + + # links data to the trainer + self.data_connector.attach_data( + model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + ) + + self._run(model) assert self.state.stopped self.training = False @@ -883,7 +893,7 @@ def validate( verbose: If True, prints the validation results. - datamodule: A instance of :class:`LightningDataModule`. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. Returns: The dictionary with final validation results returned by validation_epoch_end. @@ -908,10 +918,8 @@ def validate( model_provided = model is not None model = model or self.lightning_module - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule) - # Attach dataloaders (if given) - self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders) + # links data to the trainer + self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule) if not model_provided: self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) @@ -948,7 +956,7 @@ def test( verbose: If True, prints the test results. - datamodule: A instance of :class:`LightningDataModule`. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. Returns: Returns a list of dictionaries, one for each test dataloader containing their respective metrics. 
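The `fit()` hunk above moves the datamodule fix-up and the mixed-input check out of the data connector and into `fit()` itself. A short sketch of the resulting call patterns, assuming placeholder names (`LitModel`, `MyDataModule`, `train_loader`) that are not part of this diff:

# Sketch of the call patterns handled by the relocated datamodule logic in fit().
model = LitModel()
dm = MyDataModule()
train_loader = dm.train_dataloader()

trainer.fit(model, dm)                # datamodule passed positionally: fixed up for the user
trainer.fit(model, datamodule=dm)     # equivalent, explicit keyword
trainer.fit(model, train_loader)      # plain dataloader still works

# Passing both a dataloader and a datamodule raises MisconfigurationException:
# trainer.fit(model, train_loader, datamodule=dm)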
@@ -969,10 +977,8 @@ def test( model_provided = model is not None model = model or self.lightning_module - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule) - # Attach dataloaders (if given) - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + # links data to the trainer + self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule) if not model_provided: self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) @@ -1063,10 +1069,8 @@ def predict( if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') - # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model, datamodule) - # Attach dataloaders (if given) - self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) + # links data to the trainer + self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) results = self._run(model) @@ -1081,7 +1085,9 @@ def tune( train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[LightningDataModule] = None, - ) -> None: + scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, + lr_find_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Optional[Union[int, _LRFinder]]]: r""" Runs routines to tune hyperparameters before training. @@ -1094,17 +1100,38 @@ def tune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped - datamodule: A instance of :class:`LightningDataModule`. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. 
+ + scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size` + + lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find` """ Trainer._log_api_event("tune") self.state = TrainerState.TUNING self.tuning = True - self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule) + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + raise MisconfigurationException( + 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`' + ) + + # links data to the trainer + self.data_connector.attach_data( + model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + ) + + result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs) assert self.state.stopped self.tuning = False + return result + def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" state = self._setup_state diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 397f5cc5cfca9..489083f796181 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -24,7 +24,7 @@ from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing +from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType from pytorch_lightning.utilities.distributed import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters @@ -91,20 +91,6 @@ def on_train_start(self): # hook self.trainer.call_hook("on_train_start") - def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): - # clean hparams - if hasattr(model, "hparams"): - parsing.clean_namespace(model.hparams) - - # links data to the trainer - self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) - - # check that model is configured correctly - self.trainer.config_validator.verify_loop_configurations(model) - - # attach model log function to callback - self.trainer.callback_connector.attach_model_logging_functions(model) - def on_train_end(self): if self._teardown_already_run: return diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 79e1bde9099ca..da681b0d6db80 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -35,52 +35,8 @@ def scale_batch_size( init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', - **fit_kwargs ) -> Optional[int]: - r""" - Will iteratively try to find the largest batch size for a given model - that does not give an out of memory (OOM) error. - - Args: - trainer: The Trainer - - model: Model to fit. - - mode: string setting the search mode. Either `power` or `binsearch`. 
- If mode is `power` we keep multiplying the batch size by 2, until - we get an OOM error. If mode is 'binsearch', we will initially - also keep multiplying by 2 and after encountering an OOM error - do a binary search between the last successful batch size and the - batch size that failed. - - steps_per_trial: number of steps to run with a given batch size. - Ideally 1 should be enough to test if a OOM error occurs, - however in practise a few are needed - - init_val: initial batch size to start the search with - - max_trials: max number of increase in batch size done before - algorithm is terminated - - batch_arg_name: name of the attribute that stores the batch size. - It is expected that the user has provided a model or datamodule that has a hyperparameter - with that name. We will look for this attribute name in the following places - - - ``model`` - - ``model.hparams`` - - ``model.datamodule`` - - ``trainer.datamodule`` (the datamodule passed to the tune method) - - **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader - or datamodule. - - Raises: - MisconfigurationException: - If field ``batch_arg_name`` is not found in ``model`` and ``model.hparams``, or - if batch scaling feature is used with dataloaders passed directly to ``.fit()``. - ValueError: - If mode in method ``scale_batch_size`` is neither ``power`` nor ``binsearch``. - """ + """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`""" if trainer.fast_dev_run: rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning) return @@ -116,9 +72,9 @@ def scale_batch_size( # Initially we just double in size until an OOM is encountered new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val if mode == 'power': - new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs) + new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials) elif mode == 'binsearch': - new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs) + new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials) else: raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch') @@ -183,8 +139,7 @@ def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None: def _run_power_scaling( - trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int, - **fit_kwargs + trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int ) -> int: """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """ for _ in range(max_trials): @@ -192,7 +147,7 @@ def _run_power_scaling( trainer.global_step = 0 # reset after each try try: # Try fit - trainer.tuner._run(model, **fit_kwargs) + trainer.tuner._run(model) # Double in size new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded') except RuntimeError as exception: @@ -211,8 +166,7 @@ def _run_power_scaling( def _run_binsearch_scaling( - trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int, - **fit_kwargs + trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int ) -> int: """ Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. 
Hereafter, the batch size is further @@ -224,7 +178,7 @@ def _run_binsearch_scaling( trainer.global_step = 0 # reset after each try try: # Try fit - trainer.tuner._run(model, **fit_kwargs) + trainer.tuner._run(model) count += 1 if count > max_trials: break diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index aceacf26e85e1..01f48c66ad201 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -15,13 +15,12 @@ import logging import os from functools import wraps -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, Optional, Sequence import numpy as np import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback @@ -62,7 +61,7 @@ def _determine_lr_attr_name(trainer: 'pl.Trainer', model: 'pl.LightningModule') class _LRFinder(object): - """ LR finder object. This object stores the results of Trainer.lr_find(). + """ LR finder object. This object stores the results of lr_find(). Args: mode: either `linear` or `exponential`, how to increase lr after each step @@ -198,77 +197,14 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1): def lr_find( trainer: 'pl.Trainer', model: 'pl.LightningModule', - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, min_lr: float = 1e-8, max_lr: float = 1, num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, ) -> Optional[_LRFinder]: - r""" - ``lr_find`` enables the user to do a range test of good initial learning rates, - to reduce the amount of guesswork in picking a good starting learning rate. - - Args: - trainer: The Trainer - - model: Model to do range testing for - - train_dataloader: A PyTorch - ``DataLoader`` with training samples. If the model has - a predefined train_dataloader method, this will be skipped. - - min_lr: minimum learning rate to investigate - - max_lr: maximum learning rate to investigate - - num_training: number of learning rates to test - - mode: Search strategy to update learning rate after each batch: - - - ``'exponential'`` (default): Will increase the learning rate exponentially. - - ``'linear'``: Will increase the learning rate linearly. - - early_stop_threshold: threshold for stopping the search. If the - loss at any point is larger than early_stop_threshold*best_loss - then the search is stopped. To disable, set to None. - - datamodule: An optional ``LightningDataModule`` which holds the training - and validation dataloader(s). Note that the ``train_dataloader`` and - ``val_dataloaders`` parameters cannot be used at the same time as - this parameter, or a ``MisconfigurationException`` will be raised. - - update_attr: Whether to update the learning rate attribute or not. - - Raises: - MisconfigurationException: - If learning rate/lr in ``model`` or ``model.hparams`` isn't overriden when ``auto_lr_find=True``, or - if you are using `more than one optimizer` with learning rate finder. - - Example:: - - # Setup model and trainer - model = MyModelClass(hparams) - trainer = pl.Trainer() - - # Run lr finder - lr_finder = trainer.tuner.lr_find(model, ...) 
- - # Inspect results - fig = lr_finder.plot(); fig.show() - suggested_lr = lr_finder.suggestion() - - # Overwrite lr and create new model - hparams.lr = suggested_lr - model = MyModelClass(hparams) - - # Ready to train with new learning rate - trainer.fit(model) - - """ + """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`""" if trainer.fast_dev_run: rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning) return @@ -311,7 +247,7 @@ def lr_find( model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers) # Fit, lr & loss logged in callback - trainer.tuner._run(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule) + trainer.tuner._run(model) # Prompt if we stopped early if trainer.global_step != num_training: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 1b64cd4fccb66..8e3862b195cd6 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -11,20 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find class Tuner: + """Tuner class to tune your model""" def __init__(self, trainer: 'pl.Trainer') -> None: self.trainer = trainer @@ -33,44 +31,30 @@ def on_trainer_init(self, auto_lr_find: Union[str, bool], auto_scale_batch_size: self.trainer.auto_lr_find = auto_lr_find self.trainer.auto_scale_batch_size = auto_scale_batch_size - def setup_trainer( + def _tune( self, - model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - datamodule: LightningDataModule = None, - ): - self.trainer.model_connector.copy_trainer_model_properties(model) - # setup data, etc... 
- self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) - # hook - self.trainer.data_connector.prepare_data(model) - - def tune(self, model, train_dataloader, val_dataloaders, datamodule): + model: 'pl.LightningModule', + scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, + lr_find_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Optional[Union[int, _LRFinder]]]: + scale_batch_size_kwargs = scale_batch_size_kwargs or {} + lr_find_kwargs = lr_find_kwargs or {} + # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added + result = {} + # Run auto batch size scaling if self.trainer.auto_scale_batch_size: - if isinstance(self.trainer.auto_scale_batch_size, bool): - self.trainer.auto_scale_batch_size = 'power' - self.scale_batch_size( - model, - mode=self.trainer.auto_scale_batch_size, - train_dataloader=train_dataloader, - val_dataloaders=val_dataloaders, - datamodule=datamodule, - ) + result['scale_batch_size'] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs) # Run learning rate finder: if self.trainer.auto_lr_find: - self.lr_find( - model, - update_attr=True, - train_dataloader=train_dataloader, - val_dataloaders=val_dataloaders, - datamodule=datamodule, - ) + lr_find_kwargs.setdefault('update_attr', True) + result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs) self.trainer.state = TrainerState.FINISHED + return result + def _run(self, *args: Any, **kwargs: Any) -> None: """`_run` wrapper to set the proper state during tuning, as this can be called multiple times""" self.trainer.state = TrainerState.TUNING # last `_run` call might have set it to `FINISHED` @@ -81,29 +65,38 @@ def _run(self, *args: Any, **kwargs: Any) -> None: def scale_batch_size( self, model: 'pl.LightningModule', + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional['pl.LightningDataModule'] = None, mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', - **fit_kwargs ) -> Optional[int]: - r""" - Will iteratively try to find the largest batch size for a given model + """ + Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. Args: - model: Model to fit. + model: Model to tune. + + train_dataloader: A Pytorch DataLoader with training samples. If the model has + a predefined train_dataloader method this will be skipped. + + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. - mode: string setting the search mode. Either `power` or `binsearch`. - If mode is `power` we keep multiplying the batch size by 2, until - we get an OOM error. If mode is 'binsearch', we will initially - also keep multiplying by 2 and after encountering an OOM error - do a binary search between the last successful batch size and the - batch size that failed. + mode: Search strategy to update the batch size: + + - ``'power'`` (default): Keep multiplying the batch size by 2, until we get an OOM error. + - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error + do a binary search between the last successful batch size and the batch size that failed. 
steps_per_trial: number of steps to run with a given batch size. - Idealy 1 should be enough to test if a OOM error occurs, + Ideally 1 should be enough to test if a OOM error occurs, however in practise a few are needed init_val: initial batch size to start the search with @@ -119,47 +112,88 @@ def scale_batch_size( - ``model.hparams`` - ``model.datamodule`` - ``trainer.datamodule`` (the datamodule passed to the tune method) - - **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader - or datamodule. - """ - self.setup_trainer(model, **fit_kwargs) - return scale_batch_size( - self.trainer, + self.trainer.auto_scale_batch_size = True + result = self.trainer.tune( model, - mode, - steps_per_trial, - init_val, - max_trials, - batch_arg_name, - **fit_kwargs, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + scale_batch_size_kwargs={ + 'mode': mode, + 'steps_per_trial': steps_per_trial, + 'init_val': init_val, + 'max_trials': max_trials, + 'batch_arg_name': batch_arg_name, + } ) + self.trainer.auto_scale_batch_size = False + return result['scale_batch_size'] def lr_find( self, model: 'pl.LightningModule', train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional['pl.LightningDataModule'] = None, min_lr: float = 1e-8, max_lr: float = 1, num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, ) -> Optional[_LRFinder]: - self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) - return lr_find( - self.trainer, + """ + Enables the user to do a range test of good initial learning rates, + to reduce the amount of guesswork in picking a good starting learning rate. + + Args: + model: Model to tune. + + train_dataloader: A Pytorch DataLoader with training samples. If the model has + a predefined train_dataloader method this will be skipped. + + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + + min_lr: minimum learning rate to investigate + + max_lr: maximum learning rate to investigate + + num_training: number of learning rates to test + + mode: Search strategy to update learning rate after each batch: + + - ``'exponential'`` (default): Will increase the learning rate exponentially. + - ``'linear'``: Will increase the learning rate linearly. + + early_stop_threshold: threshold for stopping the search. If the + loss at any point is larger than early_stop_threshold*best_loss + then the search is stopped. To disable, set to None. + + update_attr: Whether to update the learning rate attribute or not. + + Raises: + MisconfigurationException: + If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``, + or if you are using more than one optimizer. 
+ """ + self.trainer.auto_lr_find = True + result = self.trainer.tune( model, - train_dataloader, - val_dataloaders, - min_lr, - max_lr, - num_training, - mode, - early_stop_threshold, - datamodule, - update_attr, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + lr_find_kwargs={ + 'min_lr': min_lr, + 'max_lr': max_lr, + 'num_training': num_training, + 'mode': mode, + 'early_stop_threshold': early_stop_threshold, + 'update_attr': update_attr + } ) + self.trainer.auto_lr_find = False + return result['lr_find'] diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 7997d1efdf0a0..35ce1a4d97034 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1977,6 +1977,8 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + model = BoringModel() + trainer.fit(model) with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"): trainer.validate() diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index e6b530752407f..641196eda466f 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -54,7 +54,7 @@ def test_model_reset_correctly(tmpdir): before_state_dict = deepcopy(model.state_dict()) - _ = trainer.tuner.lr_find(model, num_training=5) + trainer.tuner.lr_find(model, num_training=5) after_state_dict = model.state_dict() @@ -80,7 +80,7 @@ def test_trainer_reset_correctly(tmpdir): 'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback' ] expected = {ca: getattr(trainer, ca) for ca in changed_attributes} - _ = trainer.tuner.lr_find(model, num_training=5) + trainer.tuner.lr_find(model, num_training=5) actual = {ca: getattr(trainer, ca) for ca in changed_attributes} assert actual == expected @@ -278,12 +278,10 @@ def __init__(self, learning_rate=0.1, batch_size=2): before_lr = model.hparams.learning_rate # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=3, - ) - bs = trainer.tuner.scale_batch_size(model) - lr = trainer.tuner.lr_find(model).suggestion() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, auto_lr_find=True, auto_scale_batch_size=True) + result = trainer.tune(model) + bs = result['scale_batch_size'] + lr = result['lr_find'].suggestion() assert lr != before_lr assert isinstance(bs, int) @@ -329,7 +327,7 @@ def training_step_end(self, outputs): model = TestModel() trainer = Trainer(default_root_dir=tmpdir) num_training = 3 - _ = trainer.tuner.lr_find( + trainer.tuner.lr_find( model=model, num_training=num_training, )