2424from pytorch_lightning .utilities import AMPType
2525from pytorch_lightning .utilities .exceptions import MisconfigurationException
2626from tests .base import EvalModelTemplate
27- from tests .helpers import BoringDataModule , BoringModel
27+ from tests .helpers import BoringDataModule , BoringModel , RandomDataset
2828from tests .helpers .datamodules import MNISTDataModule
2929from tests .helpers .runif import RunIf
3030
3131
3232class BatchSizeDataModule (BoringDataModule ):
3333
34- def __init__ (self , batch_size = None ):
34+ def __init__ (self , batch_size ):
3535 super ().__init__ ()
3636 if batch_size is not None :
3737 self .batch_size = batch_size
@@ -42,21 +42,23 @@ def train_dataloader(self):
4242
4343class BatchSizeModel (BoringModel ):
4444
45- def __init__ (self , batch_size = None ):
45+ def __init__ (self , batch_size ):
4646 super ().__init__ ()
4747 if batch_size is not None :
4848 self .batch_size = batch_size
4949
50+ def train_dataloader (self ):
51+ return DataLoader (RandomDataset (32 , 64 ), batch_size = getattr (self , "batch_size" , 1 ))
5052
51- @ pytest . mark . parametrize (
52- "model,datamodule" , [
53- ( BatchSizeModel ( 2 ), None ),
54- ( BatchSizeModel ( 2 ), BatchSizeDataModule ( 2 ) ),
55- ( BatchSizeModel ( 2 ), BatchSizeDataModule ( None ) ),
56- ( BatchSizeModel ( None ), BatchSizeDataModule ( 2 ) ),
57- ]
58- )
59- def test_scale_batch_size_method_with_model_or_datamodule (tmpdir , model , datamodule ):
53+
54+ @ pytest . mark . parametrize ([ "model_bs" , "dm_bs" ] , [
55+ ( 2 , - 1 ),
56+ ( 2 , 2 ),
57+ ( 2 , None ),
58+ ( None , 2 ),
59+ ( 16 , 16 ),
60+ ] )
61+ def test_scale_batch_size_method_with_model_or_datamodule (tmpdir , model_bs , dm_bs ):
6062 """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
6163 trainer = Trainer (
6264 default_root_dir = tmpdir ,
@@ -65,14 +67,21 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamod
6567 max_epochs = 1 ,
6668 )
6769 tuner = Tuner (trainer )
68- new_batch_size = tuner .scale_batch_size (
69- model = model , mode = "binsearch" , init_val = 4 , max_trials = 2 , datamodule = datamodule
70- )
70+
71+ model = BatchSizeModel (model_bs )
72+ datamodule = BatchSizeDataModule (dm_bs ) if dm_bs != - 1 else None
73+
74+ new_batch_size = tuner .scale_batch_size (model , mode = "binsearch" , init_val = 4 , max_trials = 2 , datamodule = datamodule )
7175 assert new_batch_size == 16
72- if hasattr (model , "batch_size" ):
73- assert model .batch_size == 16
74- if datamodule is not None and hasattr (datamodule , "batch_size" ):
75- assert datamodule .batch_size == 16
76+
77+ if model_bs is not None :
78+ assert model .batch_size == new_batch_size
79+ if dm_bs == - 1 :
80+ # datamodule batch size takes precedence
81+ assert trainer .train_dataloader .loaders .batch_size == new_batch_size
82+ if dm_bs not in (- 1 , None ):
83+ assert datamodule .batch_size == new_batch_size
84+ assert trainer .train_dataloader .loaders .batch_size == new_batch_size
7685
7786
7887def test_model_reset_correctly (tmpdir ):
0 commit comments