3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -380,6 +380,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed resetting device after `fitting/evaluating/predicting` ([#7188](https://github.com/PyTorchLightning/pytorch-lightning/pull/7188))


- Fixed bug where `trainer.tuner.scale_batch_size(max_trials=0)` would not return the correct batch size result ([#7262](https://github.com/PyTorchLightning/pytorch-lightning/pull/7262))


- Fixed metrics not being properly logged with `precision=16` and `manual_optimization` ([#7228](https://github.com/PyTorchLightning/pytorch-lightning/pull/7228))


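For context, a minimal sketch of the tuner call that the `max_trials=0` changelog entry above refers to; `BoringModel` is a placeholder module (not part of this PR), and the expectation that the call returns the already-configured batch size rather than `None` is inferred from the changelog wording:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    """Placeholder module exposing a `batch_size` attribute for the tuner to scale."""

    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=self.batch_size)


trainer = pl.Trainer(max_epochs=1)
# With max_trials=0 no scaling trials run; after this fix the call is expected to
# return the current batch size (an int) instead of None.
new_size = trainer.tuner.scale_batch_size(BoringModel(), max_trials=0)
```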
30 changes: 18 additions & 12 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -11,39 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden


class ConfigValidator(object):
class ConfigValidator:

def __init__(self, trainer):
def __init__(self, trainer: 'pl.Trainer') -> None:
self.trainer = trainer

def verify_loop_configurations(self, model: LightningModule) -> None:
def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
r"""
Checks that the model is configured correctly before the run is started.

Args:
model: The model to check the configuration.

"""
if self.trainer.state == TrainerState.FITTING:
if self.trainer.state in (TrainerState.FITTING, TrainerState.TUNING):
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state == TrainerState.TUNING:
self.__verify_train_loop_configuration(model)
elif self.trainer.state == TrainerState.VALIDATING:
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state == TrainerState.TESTING:
self.__verify_eval_loop_configuration(model, 'test')
elif self.trainer.state == TrainerState.PREDICTING:
self.__verify_predict_loop_configuration(model)
self.__verify_dp_batch_transfer_support(model)

def __verify_train_loop_configuration(self, model):
def __verify_train_loop_configuration(self, model: 'pl.LightningModule') -> None:
# -----------------------------------
# verify model has a training step
# -----------------------------------
@@ -82,14 +81,14 @@ def __verify_train_loop_configuration(self, model):
going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization:
if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
raise MisconfigurationException(
'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,'
' `accumulate_grad_batches` in `Trainer` should be 1.'
' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
)

def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None:
def __verify_eval_loop_configuration(self, model: 'pl.LightningModule', stage: str) -> None:
loader_name = f'{stage}_dataloader'
step_name = 'validation_step' if stage == 'val' else 'test_step'

@@ -101,8 +100,15 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -
if has_step and not has_loader:
rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop')

def __verify_predict_loop_configuration(self, model: LightningModule) -> None:

def __verify_predict_loop_configuration(self, model: 'pl.LightningModule') -> None:
has_predict_dataloader = is_overridden('predict_dataloader', model)
if not has_predict_dataloader:
raise MisconfigurationException('Dataloader not found for `Trainer.predict`')

def __verify_dp_batch_transfer_support(self, model: 'pl.LightningModule') -> None:
"""Raise Misconfiguration exception since these hooks are not supported in DP mode"""
# TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')
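As a rough illustration of the configuration this relocated check rejects, a hedged sketch follows; the module is hypothetical, and a machine with at least two visible GPUs is assumed so that DP is actually selected:

```python
import pytorch_lightning as pl


class DPTransferModel(pl.LightningModule):
    """Hypothetical module that customises batch movement, which DP mode forbids."""

    def transfer_batch_to_device(self, batch, device):
        return super().transfer_batch_to_device(batch, device)


# Assumes >= 2 visible GPUs so that DataParallel is used.
trainer = pl.Trainer(accelerator='dp', gpus=2)
# Calling trainer.fit(DPTransferModel(), ...) would now fail fast with:
#   MisconfigurationException: Overriding `transfer_batch_to_device` is not supported in DP mode.
```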
18 changes: 6 additions & 12 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -16,6 +16,7 @@

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -89,7 +90,6 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
# set up the passed in dataloaders (if needed)
self.attach_dataloaders(model, train_dataloader, val_dataloaders)
self.attach_datamodule(model, datamodule)
self._validate_data_hooks(model)

def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -98,22 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)

def _validate_data_hooks(self, model):
# Raise Misconfiguration exception since these hooks are not supported in DP mode
# TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')

def attach_dataloaders(
self,
model,
model: 'pl.LightningModule',
train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
):
) -> None:
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
@@ -128,7 +120,9 @@ def attach_dataloaders(
if predict_dataloaders is not None:
model.predict_dataloader = _PatchDataLoader(predict_dataloaders)

def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None:
def attach_datamodule(
self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None
) -> None:
# We use datamodule if it's been provided, otherwise we check model for it
datamodule = datamodule or getattr(model, 'datamodule', None)

5 changes: 0 additions & 5 deletions pytorch_lightning/trainer/connectors/model_connector.py
@@ -11,11 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Root module for all distributed operations in Lightning.
Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.

"""
from weakref import proxy


6 changes: 1 addition & 5 deletions pytorch_lightning/trainer/predict_loop.py
@@ -76,11 +76,7 @@ def on_predict_model_eval(self):
model_ref = self.trainer.lightning_module
model_ref.on_predict_model_eval()

def setup(self, model, max_batches, dataloaders):

# copy properties for forward overrides
self.trainer.model_connector.copy_trainer_model_properties(model)

def setup(self, max_batches, dataloaders):
# convert max_batches to list
if isinstance(max_batches, int):
max_batches = [max_batches] * len(dataloaders)
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -775,7 +775,7 @@ def run_predict(self) -> Optional[_PREDICT_OUTPUT]:
return []

# set up the eval loop
self.predict_loop.setup(self.lightning_module, max_batches, dataloaders)
self.predict_loop.setup(max_batches, dataloaders)

# call hook
self.predict_loop.on_predict_start()
@@ -1086,8 +1086,6 @@ def tune(
Runs routines to tune hyperparameters before training.

Args:
datamodule: A instance of :class:`LightningDataModule`.

model: Model to tune.

train_dataloader: A Pytorch DataLoader with training samples. If the model has
@@ -1096,6 +1094,7 @@
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped

datamodule: A instance of :class:`LightningDataModule`.
"""
Trainer._log_api_event("tune")
self.state = TrainerState.TUNING
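A short hedged sketch of the call order the reordered docstring describes; `MyModel` and `MyDataModule` are hypothetical stand-ins rather than classes from this PR:

```python
import pytorch_lightning as pl

trainer = pl.Trainer(auto_scale_batch_size=True)
model = MyModel()        # hypothetical LightningModule
dm = MyDataModule()      # hypothetical LightningDataModule
# The datamodule is documented last and is typically passed by keyword:
trainer.tune(model, datamodule=dm)
```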
8 changes: 4 additions & 4 deletions pytorch_lightning/tuner/auto_gpu_select.py
@@ -17,11 +17,11 @@


def pick_multiple_gpus(nb):
'''
"""
Raises:
MisconfigurationException:
If ``gpus`` is set to 0, when ``auto_select_gpus=True``.
'''
"""
if nb == 0:
raise MisconfigurationException(
r"auto_select_gpus=True, gpus=0 is not a valid configuration.\
@@ -38,11 +38,11 @@ def pick_multiple_gpus(nb):


def pick_single_gpu(exclude_gpus: list):
'''
"""
Raises:
RuntimeError:
If you try to allocate a GPU, when no GPUs are available.
'''
"""
for i in range(torch.cuda.device_count()):
if i in exclude_gpus:
continue
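To make the documented failure modes concrete, a small sketch that exercises these helpers directly; the `RuntimeError` case is left commented because it depends on the GPUs visible on the machine:

```python
import torch
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus, pick_single_gpu
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    # Requesting auto-selection of 0 GPUs is rejected outright.
    pick_multiple_gpus(0)
except MisconfigurationException as err:
    print(err)

# Excluding every visible device leaves nothing to pick, so this would raise RuntimeError:
# pick_single_gpu(exclude_gpus=list(range(torch.cuda.device_count())))
```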
34 changes: 20 additions & 14 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -15,7 +15,7 @@
import os
from typing import Optional, Tuple

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -28,21 +28,22 @@


def scale_batch_size(
trainer,
model: LightningModule,
trainer: 'pl.Trainer',
model: 'pl.LightningModule',
mode: str = 'power',
steps_per_trial: int = 3,
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = 'batch_size',
**fit_kwargs
):
) -> Optional[int]:
Comment on lines 30 to +39

Contributor: This should have some caveats that the tuner doesn't work with things like DeepSpeed or sharded DDP, which have different behavior on multiple GPUs, right?

Collaborator: Agree with this. In general, `scale_batch_size` is not really that well tested in multi-GPU settings. Even in the simplest case, where you are using multiple GPUs of different types (say one with 8 GB of VRAM and one with 16 GB), it will not assign a higher batch size to the second device.

Contributor (Author): @SkafteNicki, since you are the most familiar with the tuner limitations, can you open a PR showing warnings or raising an error for these cases?

Collaborator: @carmocca will do. I basically think that anything other than single CPU/GPU batch scaling is not supported.

r"""
Will iteratively try to find the largest batch size for a given model
that does not give an out of memory (OOM) error.

Args:
trainer: The Trainer

model: Model to fit.

mode: string setting the search mode. Either `power` or `binsearch`.
Expand All @@ -53,7 +54,7 @@ def scale_batch_size(
batch size that failed.

steps_per_trial: number of steps to run with a given batch size.
Idealy 1 should be enough to test if a OOM error occurs,
Ideally 1 should be enough to test if a OOM error occurs,
however in practise a few are needed

init_val: initial batch size to start the search with
@@ -113,7 +114,7 @@ def scale_batch_size(
trainer.progress_bar_callback.disable()

# Initially we just double in size until an OOM is encountered
new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val
new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) # initially set to init_val
if mode == 'power':
new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
elif mode == 'binsearch':
@@ -139,7 +140,7 @@
return new_size
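The change to `new_size, _ = _adjust_batch_size(...)` above suggests the helper now returns a pair rather than a bare size. Below is a hedged sketch of how a doubling loop would consume such a pair; the second element is assumed to mean "the value actually changed", which is not shown in this diff, and the helpers are passed in as parameters rather than taken from Lightning:

```python
from typing import Callable, Tuple


def power_scaling_sketch(
    adjust: Callable[..., Tuple[int, bool]],
    run_trial: Callable[[int], None],
    max_trials: int,
    init_val: int = 2,
) -> int:
    """Double the batch size until a trial fails or the size stops changing."""
    new_size, _ = adjust(value=init_val)        # start from the initial value
    for _ in range(max_trials):
        try:
            run_trial(new_size)                 # e.g. fit for a few steps
        except RuntimeError:                    # stands in for a CUDA OOM error
            new_size, _ = adjust(factor=0.5)    # back off and stop searching
            break
        new_size, changed = adjust(factor=2.0)  # double for the next trial
        if not changed:                         # capped, e.g. by the dataset size
            break
    return new_size
```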


def __scale_batch_dump_params(trainer):
def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None:
# Prevent going into infinite loop
trainer.__dumped_params = {
'auto_lr_find': trainer.auto_lr_find,
@@ -155,7 +156,7 @@
}


def __scale_batch_reset_params(trainer, model, steps_per_trial):
def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None:
trainer.auto_scale_batch_size = None # prevent recursion
trainer.auto_lr_find = False # avoid lr find being called multiple times
trainer.current_epoch = 0
@@ -168,7 +169,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
trainer.model = model # required for saving


def __scale_batch_restore_params(trainer):
def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None:
trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
trainer.current_epoch = trainer.__dumped_params['current_epoch']
trainer.max_steps = trainer.__dumped_params['max_steps']
@@ -181,9 +182,11 @@ def __scale_batch_restore_params(trainer):
del trainer.__dumped_params


def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
""" Batch scaling mode where the size is doubled at each iteration until an
OOM error is encountered. """
def _run_power_scaling(
trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int,
**fit_kwargs
) -> int:
""" Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """
for _ in range(max_trials):
garbage_collection_cuda()
trainer.global_step = 0 # reset after each try
@@ -207,7 +210,10 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f
return new_size


def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
def _run_binsearch_scaling(
trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int,
**fit_kwargs
) -> int:
""" Batch scaling mode where the size is initially is doubled at each iteration
until an OOM error is encountered. Hereafter, the batch size is further
refined using a binary search """
@@ -252,7 +258,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,


def _adjust_batch_size(
trainer,
trainer: 'pl.Trainer',
batch_arg_name: str = 'batch_size',
factor: float = 1.0,
value: Optional[int] = None,