From d8b0bf57258a6c1e62020620c2dc8daaf09fa27a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 00:40:49 +0200 Subject: [PATCH 1/2] Code cleaning in preparation for 7258 --- .../trainer/configuration_validator.py | 30 +- .../trainer/connectors/data_connector.py | 18 +- .../trainer/connectors/model_connector.py | 5 - pytorch_lightning/trainer/predict_loop.py | 6 +- pytorch_lightning/trainer/trainer.py | 5 +- pytorch_lightning/tuner/auto_gpu_select.py | 8 +- pytorch_lightning/tuner/batch_size_scaling.py | 34 +- pytorch_lightning/tuner/lr_finder.py | 290 +++++++++--------- pytorch_lightning/tuner/tuning.py | 21 +- tests/trainer/test_trainer_tricks.py | 212 ------------- tests/tuner/test_lr_finder.py | 17 +- tests/tuner/test_scale_batch_size.py | 217 +++++++++++++ 12 files changed, 427 insertions(+), 436 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 55b4ea7fe7692..215fd1353e3f0 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -class ConfigValidator(object): +class ConfigValidator: - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer') -> None: self.trainer = trainer - def verify_loop_configurations(self, model: LightningModule) -> None: + def verify_loop_configurations(self, model: 'pl.LightningModule') -> None: r""" Checks that the model is configured correctly before the run is started. @@ -31,19 +31,18 @@ def verify_loop_configurations(self, model: LightningModule) -> None: model: The model to check the configuration. """ - if self.trainer.state == TrainerState.FITTING: + if self.trainer.state in (TrainerState.FITTING, TrainerState.TUNING): self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, 'val') - elif self.trainer.state == TrainerState.TUNING: - self.__verify_train_loop_configuration(model) elif self.trainer.state == TrainerState.VALIDATING: self.__verify_eval_loop_configuration(model, 'val') elif self.trainer.state == TrainerState.TESTING: self.__verify_eval_loop_configuration(model, 'test') elif self.trainer.state == TrainerState.PREDICTING: self.__verify_predict_loop_configuration(model) + self.__verify_dp_batch_transfer_support(model) - def __verify_train_loop_configuration(self, model): + def __verify_train_loop_configuration(self, model: 'pl.LightningModule') -> None: # ----------------------------------- # verify model has a training step # ----------------------------------- @@ -82,14 +81,14 @@ def __verify_train_loop_configuration(self, model): going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches() has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad - if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization: + if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization: raise MisconfigurationException( 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,' ' `accumulate_grad_batches` in `Trainer` should be 1.' ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' ) - def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None: + def __verify_eval_loop_configuration(self, model: 'pl.LightningModule', stage: str) -> None: loader_name = f'{stage}_dataloader' step_name = 'validation_step' if stage == 'val' else 'test_step' @@ -101,8 +100,15 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) - if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') - def __verify_predict_loop_configuration(self, model: LightningModule) -> None: - + def __verify_predict_loop_configuration(self, model: 'pl.LightningModule') -> None: has_predict_dataloader = is_overridden('predict_dataloader', model) if not has_predict_dataloader: raise MisconfigurationException('Dataloader not found for `Trainer.predict`') + + def __verify_dp_batch_transfer_support(self, model: 'pl.LightningModule') -> None: + """Raise Misconfiguration exception since these hooks are not supported in DP mode""" + # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. + batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model): + raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.') diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 5d2f141dc64a8..fd6c9ea32891c 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -16,6 +16,7 @@ from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -89,7 +90,6 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): # set up the passed in dataloaders (if needed) self.attach_dataloaders(model, train_dataloader, val_dataloaders) self.attach_datamodule(model, datamodule) - self._validate_data_hooks(model) def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): # If you supply a datamodule you can't supply train_dataloader or val_dataloaders @@ -98,22 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' ) - def _validate_data_hooks(self, model): - # Raise Misconfiguration exception since these hooks are not supported in DP mode - # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. - batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') - for hook in batch_transfer_hooks: - if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model): - raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.') - def attach_dataloaders( self, - model, + model: 'pl.LightningModule', train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ): + ) -> None: # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations if train_dataloader is not None: @@ -128,7 +120,9 @@ def attach_dataloaders( if predict_dataloaders is not None: model.predict_dataloader = _PatchDataLoader(predict_dataloaders) - def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None: + def attach_datamodule( + self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None + ) -> None: # We use datamodule if it's been provided, otherwise we check model for it datamodule = datamodule or getattr(model, 'datamodule', None) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 23f8d36a7ba83..d4bdedd31e0f4 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -11,11 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Root module for all distributed operations in Lightning. -Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU. - -""" from weakref import proxy diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 4815987e26240..fb1ad3b054c9e 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -76,11 +76,7 @@ def on_predict_model_eval(self): model_ref = self.trainer.lightning_module model_ref.on_predict_model_eval() - def setup(self, model, max_batches, dataloaders): - - # copy properties for forward overrides - self.trainer.model_connector.copy_trainer_model_properties(model) - + def setup(self, max_batches, dataloaders): # convert max_batches to list if isinstance(max_batches, int): max_batches = [max_batches] * len(dataloaders) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1177c5f4ace7e..a2a7da13985f8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -775,7 +775,7 @@ def run_predict(self) -> Optional[_PREDICT_OUTPUT]: return [] # set up the eval loop - self.predict_loop.setup(self.lightning_module, max_batches, dataloaders) + self.predict_loop.setup(max_batches, dataloaders) # call hook self.predict_loop.on_predict_start() @@ -1086,8 +1086,6 @@ def tune( Runs routines to tune hyperparameters before training. Args: - datamodule: A instance of :class:`LightningDataModule`. - model: Model to tune. train_dataloader: A Pytorch DataLoader with training samples. If the model has @@ -1096,6 +1094,7 @@ def tune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped + datamodule: A instance of :class:`LightningDataModule`. """ Trainer._log_api_event("tune") self.state = TrainerState.TUNING diff --git a/pytorch_lightning/tuner/auto_gpu_select.py b/pytorch_lightning/tuner/auto_gpu_select.py index 3bd1ce52b52f4..8e0b5ad68b689 100644 --- a/pytorch_lightning/tuner/auto_gpu_select.py +++ b/pytorch_lightning/tuner/auto_gpu_select.py @@ -17,11 +17,11 @@ def pick_multiple_gpus(nb): - ''' + """ Raises: MisconfigurationException: If ``gpus`` is set to 0, when ``auto_select_gpus=True``. - ''' + """ if nb == 0: raise MisconfigurationException( r"auto_select_gpus=True, gpus=0 is not a valid configuration.\ @@ -38,11 +38,11 @@ def pick_multiple_gpus(nb): def pick_single_gpu(exclude_gpus: list): - ''' + """ Raises: RuntimeError: If you try to allocate a GPU, when no GPUs are available. - ''' + """ for i in range(torch.cuda.device_count()): if i in exclude_gpus: continue diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 7e9dc524099de..45b0ac426e803 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -15,7 +15,7 @@ import os from typing import Optional, Tuple -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -28,21 +28,22 @@ def scale_batch_size( - trainer, - model: LightningModule, + trainer: 'pl.Trainer', + model: 'pl.LightningModule', mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', **fit_kwargs -): +) -> Optional[int]: r""" Will iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. Args: trainer: The Trainer + model: Model to fit. mode: string setting the search mode. Either `power` or `binsearch`. @@ -53,7 +54,7 @@ def scale_batch_size( batch size that failed. steps_per_trial: number of steps to run with a given batch size. - Idealy 1 should be enough to test if a OOM error occurs, + Ideally 1 should be enough to test if a OOM error occurs, however in practise a few are needed init_val: initial batch size to start the search with @@ -113,7 +114,7 @@ def scale_batch_size( trainer.progress_bar_callback.disable() # Initially we just double in size until an OOM is encountered - new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val if mode == 'power': new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs) elif mode == 'binsearch': @@ -139,7 +140,7 @@ def scale_batch_size( return new_size -def __scale_batch_dump_params(trainer): +def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None: # Prevent going into infinite loop trainer.__dumped_params = { 'auto_lr_find': trainer.auto_lr_find, @@ -155,7 +156,7 @@ def __scale_batch_dump_params(trainer): } -def __scale_batch_reset_params(trainer, model, steps_per_trial): +def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None: trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times trainer.current_epoch = 0 @@ -168,7 +169,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial): trainer.model = model # required for saving -def __scale_batch_restore_params(trainer): +def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None: trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find'] trainer.current_epoch = trainer.__dumped_params['current_epoch'] trainer.max_steps = trainer.__dumped_params['max_steps'] @@ -181,9 +182,11 @@ def __scale_batch_restore_params(trainer): del trainer.__dumped_params -def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): - """ Batch scaling mode where the size is doubled at each iteration until an - OOM error is encountered. """ +def _run_power_scaling( + trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int, + **fit_kwargs +) -> int: + """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """ for _ in range(max_trials): garbage_collection_cuda() trainer.global_step = 0 # reset after each try @@ -207,7 +210,10 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f return new_size -def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs): +def _run_binsearch_scaling( + trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int, + **fit_kwargs +) -> int: """ Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. Hereafter, the batch size is further refined using a binary search """ @@ -252,7 +258,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, def _adjust_batch_size( - trainer, + trainer: 'pl.Trainer', batch_arg_name: str = 'batch_size', factor: float = 1.0, value: Optional[int] = None, diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 14f21da856145..df51637dc9520 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -23,9 +23,8 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -42,7 +41,7 @@ log = logging.getLogger(__name__) -def _determine_lr_attr_name(trainer, model: LightningModule) -> str: +def _determine_lr_attr_name(trainer: 'pl.Trainer', model: 'pl.LightningModule') -> str: if isinstance(trainer.auto_lr_find, str): if not lightning_hasattr(model, trainer.auto_lr_find): raise MisconfigurationException( @@ -62,9 +61,143 @@ def _determine_lr_attr_name(trainer, model: LightningModule) -> str: ) +class _LRFinder(object): + """ LR finder object. This object stores the results of Trainer.lr_find(). + + Args: + mode: either `linear` or `exponential`, how to increase lr after each step + + lr_min: lr to start search from + + lr_max: lr to stop search + + num_training: number of steps to take between lr_min and lr_max + + Example:: + # Run lr finder + lr_finder = trainer.lr_find(model) + + # Results stored in + lr_finder.results + + # Plot using + lr_finder.plot() + + # Get suggestion + lr = lr_finder.suggestion() + """ + + def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): + assert mode in ('linear', 'exponential'), \ + 'mode should be either `linear` or `exponential`' + + self.mode = mode + self.lr_min = lr_min + self.lr_max = lr_max + self.num_training = num_training + + self.results = {} + self._total_batch_idx = 0 # for debug purpose + + def _exchange_scheduler(self, configure_optimizers: Callable): + """ Decorate configure_optimizers methods such that it returns the users + originally specified optimizer together with a new scheduler that + that takes care of the learning rate search. + """ + + @wraps(configure_optimizers) + def func(): + # Decide the structure of the output from configure_optimizers + # Same logic as method `init_optimizers` in trainer/optimizers.py + optim_conf = configure_optimizers() + if isinstance(optim_conf, Optimizer): + optimizers = [optim_conf] + elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \ + and isinstance(optim_conf[0], list): + optimizers, _ = optim_conf + elif isinstance(optim_conf, dict): + optimizers = [optim_conf["optimizer"]] + elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): + optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] + elif isinstance(optim_conf, (list, tuple)): + optimizers = [optim_conf] + + if len(optimizers) != 1: + raise MisconfigurationException( + f'`model.configure_optimizers()` returned {len(optimizers)}, but' + ' learning rate finder only works with single optimizer' + ) + + optimizer = optimizers[0] + + new_lrs = [self.lr_min] * len(optimizer.param_groups) + for param_group, new_lr in zip(optimizer.param_groups, new_lrs): + param_group["lr"] = new_lr + param_group["initial_lr"] = new_lr + + args = (optimizer, self.lr_max, self.num_training) + scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args) + + return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}] + + return func + + def plot(self, suggest: bool = False, show: bool = False): + """ Plot results from lr_find run + Args: + suggest: if True, will mark suggested lr to use with a red point + + show: if True, will show figure + """ + import matplotlib.pyplot as plt + + lrs = self.results["lr"] + losses = self.results["loss"] + + fig, ax = plt.subplots() + + # Plot loss as a function of the learning rate + ax.plot(lrs, losses) + if self.mode == 'exponential': + ax.set_xscale("log") + ax.set_xlabel("Learning rate") + ax.set_ylabel("Loss") + + if suggest: + _ = self.suggestion() + if self._optimal_idx: + ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker='o', color='red') + + if show: + plt.show() + + return fig + + def suggestion(self, skip_begin: int = 10, skip_end: int = 1): + """ This will propose a suggestion for choice of initial learning rate + as the point with the steepest negative gradient. + + Returns: + lr: suggested initial learning rate to use + skip_begin: how many samples to skip in the beginning. Prevent too naive estimates + skip_end: how many samples to skip in the end. Prevent too optimistic estimates + + """ + try: + loss = np.array(self.results["loss"][skip_begin:-skip_end]) + loss = loss[np.isfinite(loss)] + min_grad = np.gradient(loss).argmin() + self._optimal_idx = min_grad + skip_begin + return self.results["lr"][self._optimal_idx] + # todo: specify the possible exception + except Exception: + log.exception('Failed to compute suggesting for `lr`. There might not be enough points.') + self._optimal_idx = None + + def lr_find( - trainer, - model: LightningModule, + trainer: 'pl.Trainer', + model: 'pl.LightningModule', train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, min_lr: float = 1e-8, @@ -72,14 +205,16 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional[LightningDataModule] = None, + datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, -): +) -> Optional[_LRFinder]: r""" ``lr_find`` enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. Args: + trainer: The Trainer + model: Model to do range testing for train_dataloader: A PyTorch @@ -232,140 +367,6 @@ def __lr_finder_restore_params(trainer, model): del trainer.__dumped_params -class _LRFinder(object): - """ LR finder object. This object stores the results of Trainer.lr_find(). - - Args: - mode: either `linear` or `exponential`, how to increase lr after each step - - lr_min: lr to start search from - - lr_max: lr to stop search - - num_training: number of steps to take between lr_min and lr_max - - Example:: - # Run lr finder - lr_finder = trainer.lr_find(model) - - # Results stored in - lr_finder.results - - # Plot using - lr_finder.plot() - - # Get suggestion - lr = lr_finder.suggestion() - """ - - def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): - assert mode in ('linear', 'exponential'), \ - 'mode should be either `linear` or `exponential`' - - self.mode = mode - self.lr_min = lr_min - self.lr_max = lr_max - self.num_training = num_training - - self.results = {} - self._total_batch_idx = 0 # for debug purpose - - def _exchange_scheduler(self, configure_optimizers: Callable): - """ Decorate configure_optimizers methods such that it returns the users - originally specified optimizer together with a new scheduler that - that takes care of the learning rate search. - """ - - @wraps(configure_optimizers) - def func(): - # Decide the structure of the output from configure_optimizers - # Same logic as method `init_optimizers` in trainer/optimizers.py - optim_conf = configure_optimizers() - if isinstance(optim_conf, Optimizer): - optimizers = [optim_conf] - elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \ - and isinstance(optim_conf[0], list): - optimizers, _ = optim_conf - elif isinstance(optim_conf, dict): - optimizers = [optim_conf["optimizer"]] - elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): - optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] - elif isinstance(optim_conf, (list, tuple)): - optimizers = [optim_conf] - - if len(optimizers) != 1: - raise MisconfigurationException( - f'`model.configure_optimizers()` returned {len(optimizers)}, but' - ' learning rate finder only works with single optimizer' - ) - - optimizer = optimizers[0] - - new_lrs = [self.lr_min] * len(optimizer.param_groups) - for param_group, new_lr in zip(optimizer.param_groups, new_lrs): - param_group["lr"] = new_lr - param_group["initial_lr"] = new_lr - - args = (optimizer, self.lr_max, self.num_training) - scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args) - - return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}] - - return func - - def plot(self, suggest: bool = False, show: bool = False): - """ Plot results from lr_find run - Args: - suggest: if True, will mark suggested lr to use with a red point - - show: if True, will show figure - """ - import matplotlib.pyplot as plt - - lrs = self.results["lr"] - losses = self.results["loss"] - - fig, ax = plt.subplots() - - # Plot loss as a function of the learning rate - ax.plot(lrs, losses) - if self.mode == 'exponential': - ax.set_xscale("log") - ax.set_xlabel("Learning rate") - ax.set_ylabel("Loss") - - if suggest: - _ = self.suggestion() - if self._optimal_idx: - ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker='o', color='red') - - if show: - plt.show() - - return fig - - def suggestion(self, skip_begin: int = 10, skip_end: int = 1): - """ This will propose a suggestion for choice of initial learning rate - as the point with the steepest negative gradient. - - Returns: - lr: suggested initial learning rate to use - skip_begin: how many samples to skip in the beginning. Prevent too naive estimates - skip_end: how many samples to skip in the end. Prevent too optimistic estimates - - """ - try: - loss = np.array(self.results["loss"][skip_begin:-skip_end]) - loss = loss[np.isfinite(loss)] - min_grad = np.gradient(loss).argmin() - self._optimal_idx = min_grad + skip_begin - return self.results["lr"][self._optimal_idx] - # todo: specify the possible exception - except Exception: - log.exception('Failed to compute suggesting for `lr`. There might not be enough points.') - self._optimal_idx = None - - class _LRCallback(Callback): """ Special callback used by the learning rate finder. This callbacks log the learning rate before each batch and log the corresponding loss after @@ -441,9 +442,10 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data class _LinearLR(_LRScheduler): - """Linearly increases the learning rate between two boundaries - over a number of iterations. - Arguments: + """ + Linearly increases the learning rate between two boundaries over a number of iterations. + + Args: optimizer: wrapped optimizer. diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 9d471e2c5cbca..9822008f07a4f 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -16,20 +16,20 @@ from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size -from pytorch_lightning.tuner.lr_finder import lr_find +from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find class Tuner: - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer') -> None: self.trainer = trainer - def on_trainer_init(self, auto_lr_find, auto_scale_batch_size): + def on_trainer_init(self, auto_lr_find: Union[str, bool], auto_scale_batch_size: Union[str, bool]) -> None: self.trainer.auto_lr_find = auto_lr_find self.trainer.auto_scale_batch_size = auto_scale_batch_size @@ -80,14 +80,14 @@ def _launch(self, *args: Any, **kwargs: Any) -> None: def scale_batch_size( self, - model, + model: 'pl.LightningModule', mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', **fit_kwargs - ): + ) -> Optional[int]: r""" Will iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. @@ -138,7 +138,7 @@ def scale_batch_size( def lr_find( self, - model: LightningModule, + model: 'pl.LightningModule', train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, min_lr: float = 1e-8, @@ -146,9 +146,9 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, - datamodule: Optional[LightningDataModule] = None, + datamodule: Optional['pl.LightningDataModule'] = None, update_attr: bool = False, - ): + ) -> Optional[_LRFinder]: self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) return lr_find( self.trainer, @@ -163,6 +163,3 @@ def lr_find( datamodule, update_attr, ) - - def pick_multiple_gpus(self, num_gpus: int): - return pick_multiple_gpus(num_gpus) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 7206d225ab5cd..85aa7aa937740 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -11,21 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from copy import deepcopy - -import pytest import torch from torch.utils.data import DataLoader, RandomSampler, SequentialSampler -import tests.helpers.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate -from tests.helpers import BoringModel -from tests.helpers.datamodules import MNISTDataModule -from tests.helpers.runif import RunIf def test_num_training_batches(tmpdir): @@ -166,205 +156,3 @@ def test_overfit_batch_limits(tmpdir): loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(model, split) assert loader_num_batches[0] == 10 - - -def test_model_reset_correctly(tmpdir): - """ Check that model weights are correctly reset after scaling batch size. """ - tutils.reset_seed() - - model = EvalModelTemplate() - - # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - ) - - before_state_dict = deepcopy(model.state_dict()) - - trainer.tuner.scale_batch_size(model, max_trials=5) - - after_state_dict = model.state_dict() - - for key in before_state_dict.keys(): - assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), \ - 'Model was not reset correctly after scaling batch size' - - -def test_trainer_reset_correctly(tmpdir): - """ Check that all trainer parameters are reset correctly after scaling batch size. """ - tutils.reset_seed() - - model = EvalModelTemplate() - - # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - ) - - changed_attributes = [ - 'max_steps', - 'weights_summary', - 'logger', - 'callbacks', - 'checkpoint_callback', - 'limit_train_batches', - 'current_epoch', - ] - - attributes_before = {} - for ca in changed_attributes: - attributes_before[ca] = getattr(trainer, ca) - - trainer.tuner.scale_batch_size(model, max_trials=5) - - attributes_after = {} - for ca in changed_attributes: - attributes_after[ca] = getattr(trainer, ca) - - for key in changed_attributes: - assert attributes_before[key] == attributes_after[key], \ - f'Attribute {key} was not reset correctly after learning rate finder' - - -@RunIf(min_gpus=1) -@pytest.mark.parametrize('scale_arg', ['power', 'binsearch', True]) -def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg): - """ Test possible values for 'batch size auto scaling' Trainer argument. """ - tutils.reset_seed() - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - before_batch_size = hparams.get('batch_size') - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - auto_scale_batch_size=scale_arg, - gpus=1, - ) - trainer.tune(model) - after_batch_size = model.batch_size - assert before_batch_size != after_batch_size, \ - 'Batch size was not altered after running auto scaling of batch size' - - assert not os.path.exists(tmpdir / 'scale_batch_size_temp_model.ckpt') - - -@RunIf(min_gpus=1) -@pytest.mark.parametrize('use_hparams', [True, False]) -def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams): - """ Test that new batch size gets written to the correct hyperparameter attribute. """ - tutils.reset_seed() - - hparams = EvalModelTemplate.get_default_hparams() - before_batch_size = hparams.get('batch_size') - - class HparamsEvalModelTemplate(EvalModelTemplate): - - def dataloader(self, *args, **kwargs): - # artificially set batch_size so we can get a dataloader - # remove it immediately after, because we want only self.hparams.batch_size - setattr(self, "batch_size", before_batch_size) - dataloader = super().dataloader(*args, **kwargs) - del self.batch_size - return dataloader - - datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111) # this datamodule should get ignored! - datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) - - model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate - model = model_class(**hparams) - model.datamodule = datamodule_model # unused when another module gets passed to .tune() / .fit() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - auto_scale_batch_size=True, - gpus=1, - ) - trainer.tune(model, datamodule_fit) - after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size - assert trainer.datamodule == datamodule_fit - assert before_batch_size != after_batch_size - assert after_batch_size <= len(trainer.train_dataloader.dataset) - assert datamodule_fit.batch_size == after_batch_size - # should be left unchanged, since it was not passed to .tune() - assert datamodule_model.batch_size == 111 - - -def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir): - """ Test for a warning when model.batch_size and model.hparams.batch_size both present. """ - - class TestModel(BoringModel): - - def __init__(self, batch_size=1): - super().__init__() - # now we have model.batch_size and model.hparams.batch_size - self.batch_size = 1 - self.save_hyperparameters() - - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000, auto_scale_batch_size=True) - expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!" - with pytest.warns(UserWarning, match=expected_message): - trainer.tune(model) - - -@pytest.mark.parametrize('scale_method', ['power', 'binsearch']) -def test_call_to_trainer_method(tmpdir, scale_method): - """ Test that calling the trainer method itself works. """ - tutils.reset_seed() - - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - - before_batch_size = hparams.get('batch_size') - # logger file to get meta - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - ) - - after_batch_size = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5) - model.batch_size = after_batch_size - trainer.fit(model) - - assert before_batch_size != after_batch_size, \ - 'Batch size was not altered after running auto scaling of batch size' - - -def test_error_on_dataloader_passed_to_fit(tmpdir): - """Verify that when the auto scale batch size feature raises an error - if a train dataloader is passed to fit """ - - # only train passed to fit - model = EvalModelTemplate() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - auto_scale_batch_size='power', - ) - fit_options = dict(train_dataloader=model.dataloader(train=True)) - - with pytest.raises(MisconfigurationException): - trainer.tune(model, **fit_options) - - -@RunIf(min_gpus=1, amp_native=True) -def test_auto_scale_batch_size_with_amp(tmpdir): - model = EvalModelTemplate() - batch_size_before = model.batch_size - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=1, - auto_scale_batch_size=True, - gpus=1, - precision=16, - ) - trainer.tune(model) - batch_size_after = model.batch_size - assert trainer.amp_backend == AMPType.NATIVE - assert trainer.scaler is not None - assert batch_size_after != batch_size_before diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index 9834c1c8ad09b..e6b530752407f 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -17,7 +17,7 @@ import pytest import torch -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate from tests.helpers import BoringModel @@ -79,20 +79,11 @@ def test_trainer_reset_correctly(tmpdir): changed_attributes = [ 'callbacks', 'logger', 'max_steps', 'auto_lr_find', 'accumulate_grad_batches', 'checkpoint_callback' ] - attributes_before = {} - for ca in changed_attributes: - attributes_before[ca] = getattr(trainer, ca) - + expected = {ca: getattr(trainer, ca) for ca in changed_attributes} _ = trainer.tuner.lr_find(model, num_training=5) + actual = {ca: getattr(trainer, ca) for ca in changed_attributes} - attributes_after = {} - for ca in changed_attributes: - attributes_after[ca] = getattr(trainer, ca) - - for key in changed_attributes: - assert attributes_before[key] == attributes_after[key], \ - f'Attribute {key} was not reset correctly after learning rate finder' - + assert actual == expected assert model.trainer == trainer diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index ad7fc57092f32..e61cafec568ef 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -11,12 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +from copy import deepcopy + import pytest +import torch from torch.utils.data import DataLoader +import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.tuner.tuning import Tuner +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import EvalModelTemplate from tests.helpers import BoringDataModule, BoringModel +from tests.helpers.datamodules import MNISTDataModule +from tests.helpers.runif import RunIf class BatchSizeDataModule(BoringDataModule): @@ -63,3 +73,210 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamod assert model.batch_size == 16 if datamodule is not None and hasattr(datamodule, "batch_size"): assert datamodule.batch_size == 16 + + +def test_model_reset_correctly(tmpdir): + """ Check that model weights are correctly reset after scaling batch size. """ + tutils.reset_seed() + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + + before_state_dict = deepcopy(model.state_dict()) + + trainer.tuner.scale_batch_size(model, max_trials=5) + + after_state_dict = model.state_dict() + + for key in before_state_dict.keys(): + assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), \ + 'Model was not reset correctly after scaling batch size' + + +def test_trainer_reset_correctly(tmpdir): + """ Check that all trainer parameters are reset correctly after scaling batch size. """ + tutils.reset_seed() + + model = EvalModelTemplate() + + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + + changed_attributes = [ + 'max_steps', + 'weights_summary', + 'logger', + 'callbacks', + 'checkpoint_callback', + 'limit_train_batches', + 'current_epoch', + ] + expected = {ca: getattr(trainer, ca) for ca in changed_attributes} + trainer.tuner.scale_batch_size(model, max_trials=5) + actual = {ca: getattr(trainer, ca) for ca in changed_attributes} + + assert actual == expected + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize('scale_arg', ['power', 'binsearch', True]) +def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg): + """ Test possible values for 'batch size auto scaling' Trainer argument. """ + tutils.reset_seed() + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + before_batch_size = hparams.get('batch_size') + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_scale_batch_size=scale_arg, + gpus=1, + ) + trainer.tune(model) + after_batch_size = model.batch_size + assert before_batch_size != after_batch_size, \ + 'Batch size was not altered after running auto scaling of batch size' + + assert not os.path.exists(tmpdir / 'scale_batch_size_temp_model.ckpt') + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize('use_hparams', [True, False]) +def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams): + """ Test that new batch size gets written to the correct hyperparameter attribute. """ + tutils.reset_seed() + + hparams = EvalModelTemplate.get_default_hparams() + before_batch_size = hparams.get('batch_size') + + class HparamsEvalModelTemplate(EvalModelTemplate): + + def dataloader(self, *args, **kwargs): + # artificially set batch_size so we can get a dataloader + # remove it immediately after, because we want only self.hparams.batch_size + setattr(self, "batch_size", before_batch_size) + dataloader = super().dataloader(*args, **kwargs) + del self.batch_size + return dataloader + + datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111) # this datamodule should get ignored! + datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size) + + model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate + model = model_class(**hparams) + model.datamodule = datamodule_model # unused when another module gets passed to .tune() / .fit() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_scale_batch_size=True, + gpus=1, + ) + trainer.tune(model, datamodule_fit) + after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size + assert trainer.datamodule == datamodule_fit + assert before_batch_size != after_batch_size + assert after_batch_size <= len(trainer.train_dataloader.dataset) + assert datamodule_fit.batch_size == after_batch_size + # should be left unchanged, since it was not passed to .tune() + assert datamodule_model.batch_size == 111 + + +def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir): + """ Test for a warning when model.batch_size and model.hparams.batch_size both present. """ + + class TestModel(BoringModel): + + def __init__(self, batch_size=1): + super().__init__() + # now we have model.batch_size and model.hparams.batch_size + self.batch_size = 1 + self.save_hyperparameters() + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000, auto_scale_batch_size=True) + expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!" + with pytest.warns(UserWarning, match=expected_message): + trainer.tune(model) + + +@pytest.mark.parametrize('scale_method', ['power', 'binsearch']) +def test_call_to_trainer_method(tmpdir, scale_method): + """ Test that calling the trainer method itself works. """ + tutils.reset_seed() + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + before_batch_size = hparams.get('batch_size') + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + ) + + after_batch_size = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5) + model.batch_size = after_batch_size + trainer.fit(model) + + assert before_batch_size != after_batch_size, \ + 'Batch size was not altered after running auto scaling of batch size' + + +def test_error_on_dataloader_passed_to_fit(tmpdir): + """Verify that when the auto scale batch size feature raises an error + if a train dataloader is passed to fit """ + + # only train passed to fit + model = EvalModelTemplate() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + auto_scale_batch_size='power', + ) + fit_options = dict(train_dataloader=model.dataloader(train=True)) + + with pytest.raises(MisconfigurationException): + trainer.tune(model, **fit_options) + + +@RunIf(min_gpus=1, amp_native=True) +def test_auto_scale_batch_size_with_amp(tmpdir): + model = EvalModelTemplate() + batch_size_before = model.batch_size + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + auto_scale_batch_size=True, + gpus=1, + precision=16, + ) + trainer.tune(model) + batch_size_after = model.batch_size + assert trainer.amp_backend == AMPType.NATIVE + assert trainer.scaler is not None + assert batch_size_after != batch_size_before + + +def test_scale_batch_size_no_trials(tmpdir): + """Check the result is correct even when no trials are run""" + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=1, + auto_scale_batch_size='power', + ) + model = BatchSizeModel(batch_size=2) + result = trainer.tuner.scale_batch_size(model, max_trials=0) + assert result == 2 From 1eeedac22126fe8c131c90a442caa95100ca311e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 00:45:59 +0200 Subject: [PATCH 2/2] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee81e4f6514a3..b2777e49ada3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -380,6 +380,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed resetting device after `fitting/evaluating/predicting` ([#7188](https://github.com/PyTorchLightning/pytorch-lightning/pull/7188)) +- Fixed bug where `trainer.tuner.scale_batch_size(max_trials=0)` would not return the correct batch size result ([#7262](https://github.com/PyTorchLightning/pytorch-lightning/pull/7262)) + + - Fixed metrics not being properly logged with `precision=16` and `manual_optimization` ([#7228](https://github.com/PyTorchLightning/pytorch-lightning/pull/7228))