@@ -150,7 +150,10 @@ def scale_batch_size(self,
150150
151151 """
152152 if not hasattr(model, batch_arg_name):
153- raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
153+ if not hasattr(model.hparams, batch_arg_name):
154+ raise MisconfigurationException(
155+ 'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
156+ )
154157
155158 if hasattr(model.train_dataloader, 'patch_loader_code'):
156159 raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -256,17 +259,23 @@ def _adjust_batch_size(trainer,
256259
257260 """
258261 model = trainer.get_model()
259- batch_size = getattr(model, batch_arg_name)
262+ if hasattr(model, batch_arg_name):
263+ batch_size = getattr(model, batch_arg_name)
264+ else:
265+ batch_size = getattr(model.hparams, batch_arg_name)
260266 if value:
261- setattr(model, batch_arg_name, value)
267+ if hasattr(model, batch_arg_name):
268+ setattr(model, batch_arg_name, value)
269+ else:
270+ setattr(model.hparams, batch_arg_name, value)
262271 new_size = value
263272 if desc:
264273 log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
265274 else:
266275 new_size = int(batch_size * factor)
267276 if desc:
268277 log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
269- setattr(model, batch_arg_name, new_size)
278+ setattr(model.hparams, batch_arg_name, new_size)
270279 return new_size
271280
272281
0 commit comments