 import os
 from typing import Optional, Tuple

-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem


 def scale_batch_size(
-    trainer,
-    model: LightningModule,
+    trainer: 'pl.Trainer',
+    model: 'pl.LightningModule',
     mode: str = 'power',
     steps_per_trial: int = 3,
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = 'batch_size',
     **fit_kwargs
-):
+) -> Optional[int]:
     r"""
     Will iteratively try to find the largest batch size for a given model
     that does not give an out of memory (OOM) error.

     Args:
         trainer: The Trainer
+
         model: Model to fit.

         mode: string setting the search mode. Either `power` or `binsearch`.
@@ -53,7 +54,7 @@ def scale_batch_size(
             batch size that failed.

         steps_per_trial: number of steps to run with a given batch size.
-            Idealy 1 should be enough to test if a OOM error occurs,
+            Ideally 1 should be enough to test if a OOM error occurs,
             however in practise a few are needed

         init_val: initial batch size to start the search with
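The arguments documented above are normally reached through the Trainer rather than by calling this function directly. A minimal usage sketch, assuming the Trainer-level entry points of this Lightning version (auto_scale_batch_size and trainer.tune); ToyModel is a hypothetical module that exposes the batch_size attribute named by batch_arg_name:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class ToyModel(pl.LightningModule):
    """Hypothetical model: exposes `batch_size` so the scaler can adjust it."""

    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        ds = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
        return DataLoader(ds, batch_size=self.batch_size)

model = ToyModel()
trainer = pl.Trainer(auto_scale_batch_size='binsearch', max_epochs=1)
trainer.tune(model)        # invokes scale_batch_size() with the defaults above
print(model.batch_size)    # updated in place through `batch_arg_name`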
@@ -113,7 +114,7 @@ def scale_batch_size(
     trainer.progress_bar_callback.disable()

     # Initially we just double in size until an OOM is encountered
-    new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
+    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
     if mode == 'power':
         new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
     elif mode == 'binsearch':
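The 'power' branch selected here keeps doubling the batch size until a trial run hits an out-of-memory error. A standalone sketch of that strategy, deliberately independent of the Trainer internals (run_trial is a hypothetical callable standing in for the short fit of steps_per_trial steps):

def power_scale(run_trial, init_val: int = 2, max_trials: int = 25) -> int:
    """Double the batch size until `run_trial` raises an OOM-style error."""
    size = init_val
    for _ in range(max_trials):
        try:
            run_trial(size)      # a few training steps at this batch size
        except RuntimeError:     # treated as OOM: back off to the last good size
            size //= 2
            break
        else:
            size *= 2            # trial succeeded, try twice the size next
    return size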
@@ -139,7 +140,7 @@ def scale_batch_size(
     return new_size


-def __scale_batch_dump_params(trainer):
+def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None:
     # Prevent going into infinite loop
     trainer.__dumped_params = {
         'auto_lr_find': trainer.auto_lr_find,
@@ -155,7 +156,7 @@ def __scale_batch_dump_params(trainer):
     }


-def __scale_batch_reset_params(trainer, model, steps_per_trial):
+def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None:
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
     trainer.current_epoch = 0
@@ -168,7 +169,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
     trainer.model = model  # required for saving


-def __scale_batch_restore_params(trainer):
+def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None:
     trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
     trainer.current_epoch = trainer.__dumped_params['current_epoch']
     trainer.max_steps = trainer.__dumped_params['max_steps']
@@ -181,9 +182,11 @@ def __scale_batch_restore_params(trainer):
     del trainer.__dumped_params


-def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
-    """ Batch scaling mode where the size is doubled at each iteration until an
-    OOM error is encountered. """
+def _run_power_scaling(
+    trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int,
+    **fit_kwargs
+) -> int:
+    """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """
     for _ in range(max_trials):
         garbage_collection_cuda()
         trainer.global_step = 0  # reset after each try
@@ -207,7 +210,10 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f
     return new_size


-def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
+def _run_binsearch_scaling(
+    trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int,
+    **fit_kwargs
+) -> int:
211217 """ Batch scaling mode where the size is initially is doubled at each iteration
212218 until an OOM error is encountered. Hereafter, the batch size is further
213219 refined using a binary search """
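For reference, the binary-search refinement described in this docstring can be sketched in isolation: once doubling has produced a failing size, bisect between the last size that worked and the first one that did not (run_trial is again a hypothetical stand-in for the short fit):

def binsearch_refine(run_trial, low: int, high: int, max_trials: int = 25) -> int:
    """Bisect between a known-good size `low` and a known-failing size `high`."""
    best = low
    for _ in range(max_trials):
        mid = (low + high) // 2
        if mid <= low:           # interval exhausted, nothing left to try
            break
        try:
            run_trial(mid)
        except RuntimeError:     # OOM at mid: shrink the upper bound
            high = mid
        else:                    # mid fits: raise the lower bound
            best = low = mid
    return best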
@@ -252,7 +258,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,


 def _adjust_batch_size(
-    trainer,
+    trainer: 'pl.Trainer',
     batch_arg_name: str = 'batch_size',
     factor: float = 1.0,
     value: Optional[int] = None,
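The call-site change earlier in the diff (`new_size, _ = _adjust_batch_size(...)`) implies this helper now returns a pair rather than a bare int. One plausible reading, sketched purely as an assumption since the rest of the function is not shown here, is that the second element reports whether the requested change actually took effect, for example because growth is capped by the dataset length, so the scaling loops know when to stop:

from typing import Tuple

def adjust(current: int, factor: float, cap: int) -> Tuple[int, bool]:
    """Hypothetical stand-in: grow `current` by `factor`, never past `cap`,
    and report whether the value really changed."""
    new_size = min(int(current * factor), cap)
    return new_size, new_size != current

size, changed = adjust(current=256, factor=2.0, cap=320)
# size == 320 and changed is True; a second call returns (320, False),
# which would be the signal to stop growing.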