Code cleaning in preparation for #7258 [3/n] #7262
Changes from all commits
```diff
@@ -15,7 +15,7 @@
 import os
 from typing import Optional, Tuple
 
-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
```
```diff
@@ -28,21 +28,22 @@
 
 
 def scale_batch_size(
-    trainer,
-    model: LightningModule,
+    trainer: 'pl.Trainer',
+    model: 'pl.LightningModule',
     mode: str = 'power',
     steps_per_trial: int = 3,
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = 'batch_size',
     **fit_kwargs
-):
+) -> Optional[int]:
     r"""
     Will iteratively try to find the largest batch size for a given model
     that does not give an out of memory (OOM) error.
 
     Args:
         trainer: The Trainer
 
         model: Model to fit.
 
         mode: string setting the search mode. Either `power` or `binsearch`.
```

Review thread on lines 30 to +39 (marked as resolved by carmocca):

Contributor: This should have some caveats that the tuner doesn't work with things like DeepSpeed or sharded DDP, which have different behavior on multiple GPUs, right?

Collaborator: Agree with this. In general

Contributor (author): @SkafteNicki since you are the most familiar with the tuner limitations, can you open a PR showing warnings or raising an error for these cases?

Collaborator: @carmocca will do. I basically think that anything other than single CPU/GPU batch scaling is not supported.
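The follow-up the reviewers agree on would presumably be a guard along these lines. This is only a sketch of the idea, not the change that eventually landed: the function name, the strategy strings, and the `num_devices`/`strategy_name` parameters are stand-ins for whatever the Trainer actually exposes.

```python
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _check_scaling_supported(num_devices: int, strategy_name: str) -> None:
    """Reject or warn about configurations the batch size finder cannot handle."""
    if strategy_name in ('deepspeed', 'ddp_sharded'):
        # plugins that shard or partition state behave differently per device,
        # so a batch size found here would not be meaningful
        raise MisconfigurationException(
            f'Batch size scaling is not supported with the `{strategy_name}` plugin.'
        )
    if num_devices > 1:
        rank_zero_warn(
            'Batch size scaling assumes a single CPU/GPU; results found on '
            f'{num_devices} devices may not be optimal.'
        )
```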
```diff
@@ -53,7 +54,7 @@ def scale_batch_size(
             batch size that failed.
 
         steps_per_trial: number of steps to run with a given batch size.
-            Idealy 1 should be enough to test if a OOM error occurs,
+            Ideally 1 should be enough to test if a OOM error occurs,
             however in practise a few are needed
 
         init_val: initial batch size to start the search with
```
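The docstring above describes the search modes; in normal use the finder is not called directly but driven through the Trainer's tuning entry point. A minimal usage sketch, assuming the PyTorch Lightning API of this era (`Trainer(auto_scale_batch_size=...)` plus `trainer.tune`); `ToyModel` is an illustrative stand-in, not part of this PR:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    """Minimal module exposing `batch_size` so the tuner can overwrite it."""

    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
        return DataLoader(data, batch_size=self.batch_size)


# 'power' doubles the size until OOM; 'binsearch' refines further with a binary search
trainer = pl.Trainer(auto_scale_batch_size='binsearch', max_epochs=1)
trainer.tune(ToyModel())  # runs the finder and writes the result back to `batch_size`
```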
```diff
@@ -113,7 +114,7 @@ def scale_batch_size(
     trainer.progress_bar_callback.disable()
 
     # Initially we just double in size until an OOM is encountered
-    new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
+    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
     if mode == 'power':
         new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
     elif mode == 'binsearch':
```
```diff
@@ -139,7 +140,7 @@ def scale_batch_size(
     return new_size
 
 
-def __scale_batch_dump_params(trainer):
+def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None:
     # Prevent going into infinite loop
     trainer.__dumped_params = {
         'auto_lr_find': trainer.auto_lr_find,
```
```diff
@@ -155,7 +156,7 @@ def __scale_batch_dump_params(trainer):
     }
 
 
-def __scale_batch_reset_params(trainer, model, steps_per_trial):
+def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None:
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
     trainer.current_epoch = 0
```
```diff
@@ -168,7 +169,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
     trainer.model = model  # required for saving
 
 
-def __scale_batch_restore_params(trainer):
+def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None:
     trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
     trainer.current_epoch = trainer.__dumped_params['current_epoch']
     trainer.max_steps = trainer.__dumped_params['max_steps']
```
```diff
@@ -181,9 +182,11 @@ def __scale_batch_restore_params(trainer):
     del trainer.__dumped_params
 
 
-def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
-    """ Batch scaling mode where the size is doubled at each iteration until an
-    OOM error is encountered. """
+def _run_power_scaling(
+    trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int,
+    **fit_kwargs
+) -> int:
+    """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """
     for _ in range(max_trials):
         garbage_collection_cuda()
         trainer.global_step = 0  # reset after each try
```
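Stripped of the Trainer bookkeeping (checkpointing, resetting `global_step`, calling `garbage_collection_cuda()`), the `power` mode reduces to a doubling loop. The sketch below shows only that control flow; `try_fit` and `is_oom` are stand-ins for the real fitting step and OOM detection, not names from this PR.

```python
def power_scale(init_size: int, max_trials: int, try_fit, is_oom) -> int:
    """Double the batch size until fitting raises an out-of-memory error."""
    size = init_size
    for _ in range(max_trials):
        try:
            try_fit(size)                  # run a few training steps at this size
        except RuntimeError as err:
            if is_oom(err):
                return max(1, size // 2)   # back off to the last size that fit
            raise                          # unrelated errors are not swallowed
        size *= 2                          # it fit: double and try again
    return size
```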
```diff
@@ -207,7 +210,10 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f
     return new_size
 
 
-def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
+def _run_binsearch_scaling(
+    trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int,
+    **fit_kwargs
+) -> int:
     """ Batch scaling mode where the size is initially is doubled at each iteration
     until an OOM error is encountered. Hereafter, the batch size is further
     refined using a binary search """
```
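The `binsearch` mode starts with the same doubling phase but, after the first OOM, narrows the answer between the last passing size and the first failing one. Again a control-flow sketch only, under the same `try_fit`/`is_oom` assumptions as above:

```python
def binsearch_scale(init_size: int, max_trials: int, try_fit, is_oom) -> int:
    """Double until OOM, then binary-search between the last success and first failure."""
    low, high, size = 1, None, init_size
    for _ in range(max_trials):
        try:
            try_fit(size)
            low = size                     # largest size known to fit
        except RuntimeError as err:
            if not is_oom(err):
                raise
            high = size                    # smallest size known to fail
        if high is None:
            size *= 2                      # still in the doubling phase
        else:
            mid = (low + high) // 2
            if mid in (low, high):         # interval can no longer shrink
                break
            size = mid
    return low
```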
```diff
@@ -252,7 +258,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,
 
 
 def _adjust_batch_size(
-    trainer,
+    trainer: 'pl.Trainer',
     batch_arg_name: str = 'batch_size',
     factor: float = 1.0,
     value: Optional[int] = None,
```
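The new `new_size, _ = _adjust_batch_size(...)` unpacking earlier in the diff implies this helper now returns a pair rather than a bare size, presumably the adjusted value plus a flag for whether it actually changed. The sketch below is a guess at that contract; the function name and the capping-at-dataset-length rule are assumptions, not taken from this PR.

```python
from typing import Optional, Tuple


def adjust_batch_size_sketch(
    current: int,
    dataset_len: int,
    factor: float = 1.0,
    value: Optional[int] = None,
) -> Tuple[int, bool]:
    """Return the next batch size and whether it differs from the current one.

    Either an explicit `value` is used, or the current size is scaled by `factor`;
    the result is capped so a single batch never exceeds the dataset (assumption).
    """
    new_size = value if value is not None else int(current * factor)
    new_size = min(new_size, dataset_len)
    return new_size, new_size != current
```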