## 🚀 Feature

Print a better message for `lr_find` when `fast_dev_run=True`.
## Motivation
Currently,

```python
import pytorch_lightning as pl

# LitClassifier and MNISTDataModule as in the Lightning examples
model = LitClassifier()
dm = MNISTDataModule()
trainer = pl.Trainer(fast_dev_run=True)
lr_finder = trainer.lr_find(model, dm)
print(lr_finder.results)
```

prints the following message:

```
LR finder stopped early due to diverging loss.
```
This is misleading: the early stop is caused by `fast_dev_run=True` limiting training to a single batch, not by a diverging loss. The same message is printed whenever the LR finder completes fewer than `num_training` (default: 100) training batches, regardless of the reason: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/lr_finder.py#L205
## Pitch
When running `lr_find`, check whether `fast_dev_run=True`. If it is, print a warning, skip the LR search, and return the initially set learning rate instead of returning `None`.
This way, one can still do a quick `fast_dev_run`, even with `auto_lr_find=True`, as sketched below.
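A minimal sketch of the proposed guard, assuming the check sits at the top of `lr_find` with the `Trainer` in scope; `model.lr` is a placeholder for wherever the user stored the learning rate, and the early return is the proposal, not current Lightning behavior:

```python
import warnings

def lr_find(trainer, model, datamodule=None, num_training=100):
    """Sketch only: proposed early exit when fast_dev_run is enabled."""
    if trainer.fast_dev_run:
        warnings.warn(
            'Skipping learning rate finder since fast_dev_run is enabled; '
            'returning the initially set learning rate instead.',
            UserWarning,
        )
        # Placeholder: return whatever LR the user configured on the model.
        return getattr(model, 'lr', None)
    # ... existing LR range test runs up to num_training batches here ...
```

Returning the configured learning rate rather than `None` means downstream code that consumes the result keeps working during a quick sanity run.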
## Alternatives
Raise a `MisconfigurationException` within `lr_find` when `fast_dev_run=True` (see the sketch below). I don't really like this option, though: sometimes I just want to check that a `fast_dev_run` works without changing anything else (e.g. without flipping `auto_lr_find=True` to `False`) before starting a longer experiment.
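For completeness, a sketch of this stricter alternative, assuming the same spot inside `lr_find`; `MisconfigurationException` is Lightning's existing exception class:

```python
from pytorch_lightning.utilities.exceptions import MisconfigurationException

def lr_find(trainer, model, datamodule=None, num_training=100):
    """Sketch only: raise instead of silently mis-reporting a diverging loss."""
    if trainer.fast_dev_run:
        raise MisconfigurationException(
            'lr_find requires running num_training batches, but '
            'fast_dev_run=True limits training to a single batch. '
            'Disable fast_dev_run to use the learning rate finder.'
        )
    # ... existing LR range test ...
```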