Skip to content

Commit f6cae56

Browse files
Revert "Misleading exception raised during batch scaling (#1973)"
This reverts commit f8103f9.
1 parent f8103f9 commit f6cae56

File tree

1 file changed

+4
-13
lines changed

1 file changed

+4
-13
lines changed

pytorch_lightning/trainer/training_tricks.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,7 @@ def scale_batch_size(self,
136136
137137
"""
138138
if not hasattr(model, batch_arg_name):
139-
if not hasattr(model.hparams, batch_arg_name):
140-
raise MisconfigurationException(
141-
'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
142-
)
139+
raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
143140

144141
if hasattr(model.train_dataloader, 'patch_loader_code'):
145142
raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -248,23 +245,17 @@ def _adjust_batch_size(trainer,
248245
249246
"""
250247
model = trainer.get_model()
251-
if hasattr(model, batch_arg_name):
252-
batch_size = getattr(model, batch_arg_name)
253-
else:
254-
batch_size = getattr(model.hparams, batch_arg_name)
248+
batch_size = getattr(model, batch_arg_name)
255249
if value:
256-
if hasattr(model, batch_arg_name):
257-
setattr(model, batch_arg_name, value)
258-
else:
259-
setattr(model.hparams, batch_arg_name, value)
250+
setattr(model, batch_arg_name, value)
260251
new_size = value
261252
if desc:
262253
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
263254
else:
264255
new_size = int(batch_size * factor)
265256
if desc:
266257
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
267-
setattr(model.hparams, batch_arg_name, new_size)
258+
setattr(model, batch_arg_name, new_size)
268259
return new_size
269260

270261

0 commit comments

Comments
 (0)