
Commit d649127

SkafteNicki, williamFalcon, rohitgr7, tchaton committed

Skip tuner algorithms on fast dev (#3903)

* skip on fast dev
* fix error
* changelog
* fix recursive issue
* combine tests
* pep8
* move logic to base funcs
* fix mistake
* Update pytorch_lightning/tuner/lr_finder.py
  Co-authored-by: Rohit Gupta <[email protected]>
* pep

Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: chaton <[email protected]>
1 parent 90d692f commit d649127

File tree

4 files changed: +36 -2 lines

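Taken together, the change means that enabling `fast_dev_run` turns both tuner algorithms into no-ops that emit a `UserWarning`. A minimal sketch of the user-facing behaviour, assuming a toy `LightningModule` defined here purely for illustration (the model and its attributes are not part of this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class ToyModel(LightningModule):
    """Tiny stand-in model; `batch_size` and `lr` are the attributes the
    tuner algorithms would normally adjust."""

    def __init__(self, batch_size=32, lr=1e-3):
        super().__init__()
        self.batch_size = batch_size
        self.lr = lr
        self.layer = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)),
                          batch_size=self.batch_size)


model = ToyModel()
trainer = Trainer(
    fast_dev_run=True,           # run only a single batch as a quick sanity check
    auto_scale_batch_size=True,  # would normally invoke the batch size scaler
    auto_lr_find=True,           # would normally invoke the learning rate finder
)

# With this commit both tuner algorithms warn and return early instead of running:
#   UserWarning: Skipping batch size scaler since `fast_dev_run=True`
#   UserWarning: Skipping learning rate finder since `fast_dev_run=True`
trainer.tune(model)
```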

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
@@ -30,12 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458))


-- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
+- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))


 ### Changed

-
+- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))

 ### Deprecated


pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,10 @@ def scale_batch_size(trainer,
         **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
             or datamodule.
     """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
+        return
+
     if not lightning_hasattr(model, batch_arg_name):
         raise MisconfigurationException(
             f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
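Because the new guard returns `None` rather than a scaled batch size, code that calls `scale_batch_size` directly should be prepared for the early exit. A hedged sketch of caller-side handling, assuming the `trainer` and `model` from the earlier example and that the normal (non-skipped) path returns the tuned batch size; the fallback logic is illustrative, not part of this commit:

```python
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size

# With fast_dev_run=True the scaler is skipped and None comes back.
new_size = scale_batch_size(trainer, model)

if new_size is None:
    # Scaler skipped: keep the batch size the model was constructed with.
    print(f'Batch size scaling skipped, keeping batch_size={model.batch_size}')
else:
    # Normal path (fast_dev_run=False): adopt the tuned batch size.
    model.batch_size = new_size
```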

pytorch_lightning/tuner/lr_finder.py

Lines changed: 9 additions & 0 deletions
@@ -29,6 +29,7 @@
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem

 # check if ipywidgets is installed before importing tqdm.auto
@@ -42,6 +43,10 @@
 def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
     lr_finder = lr_find(trainer, model)
+
+    if lr_finder is None:
+        return
+
     lr = lr_finder.suggestion()

     # TODO: log lr.results to self.logger
@@ -131,6 +136,10 @@ def lr_find(
             trainer.fit(model)

     """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
+        return
+
     save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

     __lr_finder_dump_params(trainer, model)
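The same pattern applies to `lr_find`: under `fast_dev_run=True` it now returns `None` instead of the finder object whose `.suggestion()` is used above, which is exactly why `_run_lr_finder_internally` gained the `None` check. A sketch of the equivalent guard for code that calls `lr_find` directly, reusing the `trainer` and `model` from the first example; the fallback and the `lr` attribute are illustrative assumptions:

```python
from pytorch_lightning.tuner.lr_finder import lr_find

# With fast_dev_run=True the finder is skipped and lr_find() returns None.
lr_finder = lr_find(trainer, model)

if lr_finder is None:
    # Finder skipped: keep whatever learning rate the model already uses.
    print(f'LR finder skipped, keeping lr={model.lr}')
else:
    # Normal path (fast_dev_run=False): adopt the suggested learning rate.
    model.lr = lr_finder.suggestion()
```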
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import pytest
+from pytorch_lightning import Trainer
+from tests.base import EvalModelTemplate
+
+
+@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
+def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
+    """ Test that tuner algorithms are skipped if fast dev run is enabled """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        auto_scale_batch_size=True if tuner_alg == 'batch size scaler' else False,
+        auto_lr_find=True if tuner_alg == 'learning rate finder' else False,
+        fast_dev_run=True
+    )
+    expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
+    with pytest.warns(UserWarning, match=expected_message):
+        trainer.tune(model)
