Skip to content

Commit 8ab5bcd

Browse files
tejasviBorda
andauthored
Misleading exception raised during batch scaling (#2223)
* Misleading exception raised during batch scaling Use batch_size from `model.hparams.batch_size` instead of `model.batch_size` * Improvements considering #1896 * Apply suggestions from code review Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 9edda9a commit 8ab5bcd

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

pytorch_lightning/trainer/training_tricks.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ def scale_batch_size(self,
150150
151151
"""
152152
if not hasattr(model, batch_arg_name):
153-
raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
153+
if not hasattr(model.hparams, batch_arg_name):
154+
raise MisconfigurationException(
155+
'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
156+
)
154157

155158
if hasattr(model.train_dataloader, 'patch_loader_code'):
156159
raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -256,17 +259,23 @@ def _adjust_batch_size(trainer,
256259
257260
"""
258261
model = trainer.get_model()
259-
batch_size = getattr(model, batch_arg_name)
262+
if hasattr(model, batch_arg_name):
263+
batch_size = getattr(model, batch_arg_name)
264+
else:
265+
batch_size = getattr(model.hparams, batch_arg_name)
260266
if value:
261-
setattr(model, batch_arg_name, value)
267+
if hasattr(model, batch_arg_name):
268+
setattr(model, batch_arg_name, value)
269+
else:
270+
setattr(model.hparams, batch_arg_name, value)
262271
new_size = value
263272
if desc:
264273
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
265274
else:
266275
new_size = int(batch_size * factor)
267276
if desc:
268277
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
269-
setattr(model, batch_arg_name, new_size)
278+
setattr(model.hparams, batch_arg_name, new_size)
270279
return new_size
271280

272281

0 commit comments

Comments
 (0)