Skip to content
23 changes: 23 additions & 0 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,26 @@ def test_lr_finder_fails_fast_on_bad_config(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True)
with pytest.raises(MisconfigurationException, match='should have one of these fields'):
trainer.tune(BoringModel())


def test_lr_find_with_bs_scale(tmpdir):
""" Test that lr_find runs with batch_size_scaling """

class BoringModelTune(BoringModel):
def __init__(self, learning_rate=0.1, batch_size=2):
super().__init__()
self.save_hyperparameters()

model = BoringModelTune()
before_lr = model.hparams.learning_rate

# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
)
bs = trainer.tuner.scale_batch_size(model)
lr = trainer.tuner.lr_find(model).suggestion()

assert lr != before_lr
assert isinstance(bs, int)