import logging
import os
import uuid
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Union, List, Optional, Tuple

from torch.utils.data import DataLoader

import pytorch_lightning as pl
-from pytorch_lightning.loggers.logger import DummyLogger
+from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.loggers.logger import DummyLogger, Logger
+
from pytorch_lightning.utilities.data import has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
@@ -41,7 +43,7 @@ def scale_batch_size(
4143 """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
4244 if trainer .fast_dev_run :
4345 rank_zero_warn ("Skipping batch size scaler since fast_dev_run is enabled." )
44- return
46+ return None
4547
4648 if not lightning_hasattr (model , batch_arg_name ):
4749 raise MisconfigurationException (f"Field { batch_arg_name } not found in both `model` and `model.hparams`" )
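
Aside (not part of the diff): the explicit `return None` makes the fast_dev_run skip path visible to callers. A minimal sketch, assuming the 1.x `trainer.tuner.scale_batch_size` entry point named in the docstring above, and a hypothetical `model` (any LightningModule exposing a `batch_size` attribute):

import pytorch_lightning as pl

model = ...  # hypothetical: any LightningModule with a `batch_size` attribute
trainer = pl.Trainer(fast_dev_run=True)
# The scaler is skipped under fast_dev_run, so the call yields None.
new_size = trainer.tuner.scale_batch_size(model, mode="power")
assert new_size is None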
@@ -234,18 +236,26 @@ def _adjust_batch_size(
234236 """
235237 model = trainer .lightning_module
236238 batch_size = lightning_getattr (model , batch_arg_name )
237- new_size = value if value is not None else int (batch_size * factor )
239+ if value is not None :
240+ new_size = value
241+ else :
242+ if not isinstance (batch_size , int ):
243+ raise ValueError (f"value is None and batch_size is not int value: { batch_size } " )
244+ new_size = int (batch_size * factor )
245+
238246 if desc :
239247 log .info (f"Batch size { batch_size } { desc } , trying batch size { new_size } " )
240248
241249 if not _is_valid_batch_size (new_size , trainer .train_dataloader , trainer ):
250+ if not isinstance (trainer .train_dataloader , DataLoader ):
251+ raise ValueError ("train_dataloader is not a DataLoader" )
242252 new_size = min (new_size , len (trainer .train_dataloader .dataset ))
243253
244254 changed = new_size != batch_size
245255 lightning_setattr (model , batch_arg_name , new_size )
246256 return new_size , changed
247257
248258
249- def _is_valid_batch_size (batch_size : int , dataloader : DataLoader , trainer : "pl.Trainer" ):
259+ def _is_valid_batch_size (batch_size : int , dataloader : DataLoader , trainer : "pl.Trainer" ) -> bool :
250260 module = trainer .lightning_module or trainer .datamodule
251261 return not has_len_all_ranks (dataloader , trainer .strategy , module ) or batch_size <= len (dataloader )
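
Aside (not part of the diff): the added `isinstance` branches exist so the type checker can narrow the loosely typed values before arithmetic and `len()` calls, while also turning a silent misconfiguration into a clear runtime error. A standalone sketch of the same pattern, with toy names (not the tuner's real helpers):

from typing import Optional

def _next_size(value: Optional[int], batch_size: object, factor: float) -> int:
    # Mirrors _adjust_batch_size: prefer the explicit override, otherwise
    # narrow batch_size to int before scaling so mypy accepts the arithmetic.
    if value is not None:
        return value
    if not isinstance(batch_size, int):
        raise ValueError(f"value is None and batch_size is not int value: {batch_size}")
    return int(batch_size * factor)

assert _next_size(None, 32, 2.0) == 64   # scaled by factor
assert _next_size(16, 32, 2.0) == 16     # explicit value wins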