@@ -136,10 +136,7 @@ def scale_batch_size(self,
136136
137137 """
138138 if not hasattr (model , batch_arg_name ):
139- if not hasattr (model .hparams , batch_arg_name ):
140- raise MisconfigurationException (
141- 'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
142- )
139+ raise MisconfigurationException (f'Field { batch_arg_name } not found in `model`' )
143140
144141 if hasattr (model .train_dataloader , 'patch_loader_code' ):
145142 raise MisconfigurationException ('The batch scaling feature cannot be used with dataloaders'
@@ -248,23 +245,17 @@ def _adjust_batch_size(trainer,
248245
249246 """
250247 model = trainer .get_model ()
251- if hasattr (model , batch_arg_name ):
252- batch_size = getattr (model , batch_arg_name )
253- else :
254- batch_size = getattr (model .hparams , batch_arg_name )
248+ batch_size = getattr (model , batch_arg_name )
255249 if value :
256- if hasattr (model , batch_arg_name ):
257- setattr (model , batch_arg_name , value )
258- else :
259- setattr (model .hparams , batch_arg_name , value )
250+ setattr (model , batch_arg_name , value )
260251 new_size = value
261252 if desc :
262253 log .info (f'Batch size { batch_size } { desc } , trying batch size { new_size } ' )
263254 else :
264255 new_size = int (batch_size * factor )
265256 if desc :
266257 log .info (f'Batch size { batch_size } { desc } , trying batch size { new_size } ' )
267- setattr (model . hparams , batch_arg_name , new_size )
258+ setattr (model , batch_arg_name , new_size )
268259 return new_size
269260
270261
0 commit comments