From f6cae56b3381d5bf0fefce5a18b03732898d1849 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 17 Jun 2020 08:01:41 -0400
Subject: [PATCH] Revert "Misleading exception raised during batch scaling
 (#1973)"

This reverts commit f8103f9c7dfc35b4198e951a1789cae534c8b1db.
---
 pytorch_lightning/trainer/training_tricks.py | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index f69b6230aa116..817215202992f 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -136,10 +136,7 @@ def scale_batch_size(self,
         """
         if not hasattr(model, batch_arg_name):
-            if not hasattr(model.hparams, batch_arg_name):
-                raise MisconfigurationException(
-                    'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
-                )
+            raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
 
         if hasattr(model.train_dataloader, 'patch_loader_code'):
             raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -248,15 +245,9 @@ def _adjust_batch_size(trainer,
     """
     model = trainer.get_model()
-    if hasattr(model, batch_arg_name):
-        batch_size = getattr(model, batch_arg_name)
-    else:
-        batch_size = getattr(model.hparams, batch_arg_name)
+    batch_size = getattr(model, batch_arg_name)
     if value:
-        if hasattr(model, batch_arg_name):
-            setattr(model, batch_arg_name, value)
-        else:
-            setattr(model.hparams, batch_arg_name, value)
+        setattr(model, batch_arg_name, value)
         new_size = value
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
@@ -264,7 +255,7 @@ def _adjust_batch_size(trainer,
         new_size = int(batch_size * factor)
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-        setattr(model.hparams, batch_arg_name, new_size)
+        setattr(model, batch_arg_name, new_size)
     return new_size
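
Illustrative note (not part of the patch): after this revert, the batch-size
finder reads and writes `batch_arg_name` directly on the model object, so a
module that exposes `batch_size` only through `hparams` is no longer reachable.
A minimal Python sketch of the restored lookup behavior, using hypothetical
stand-in classes rather than real LightningModules:

    def adjust(model, batch_arg_name='batch_size', factor=2.0):
        # Mirrors the restored `_adjust_batch_size` logic from the hunks above.
        batch_size = getattr(model, batch_arg_name)   # direct attribute read
        new_size = int(batch_size * factor)
        setattr(model, batch_arg_name, new_size)      # direct attribute write
        return new_size

    class ModelWithAttr:                  # hypothetical stand-in
        batch_size = 32                   # found via getattr(model, 'batch_size')

    class ModelWithHparamsOnly:           # hypothetical stand-in
        class hparams:
            batch_size = 32               # unreachable after the revert

    print(adjust(ModelWithAttr()))        # prints 64
    # adjust(ModelWithHparamsOnly())      # raises AttributeError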