Commit 189ed25

SkafteNicki authored and SeanNaren committed
Skip tuner algorithms on fast dev (#3903)
* skip on fast dev
* fix error
* changelog
* fix recursive issue
* combine tests
* pep8
* move logic to base funcs
* fix mistake
* Update pytorch_lightning/tuner/lr_finder.py
  Co-authored-by: Rohit Gupta <[email protected]>
* pep

Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: chaton <[email protected]>
(cherry picked from commit 4f3160b)
1 parent 7b8931f commit 189ed25

File tree: 4 files changed, +43 −5 lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -30,12 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458))
 
 
-- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
+- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
 
 
 ### Changed
 
-
+- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
 
 ### Deprecated
 
```

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -68,6 +68,10 @@ def scale_batch_size(trainer,
         **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
             or datamodule.
     """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
+        return
+
     if not lightning_hasattr(model, batch_arg_name):
         raise MisconfigurationException(
             f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
```
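
For orientation (not part of the diff): a minimal sketch of how this guard surfaces to a user. It assumes the repository's test environment, where `tests.base.EvalModelTemplate` is importable; `Trainer.tune` and the warning text come from the changes and the test added in this commit.

```python
# Minimal sketch, assuming the repo's test environment (tests.base importable).
import warnings

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate

model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
trainer = Trainer(auto_scale_batch_size=True, fast_dev_run=True)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    trainer.tune(model)  # scale_batch_size() now returns early and only warns

assert any('Skipping batch size scaler' in str(w.message) for w in caught)
```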

pytorch_lightning/tuner/lr_finder.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -29,6 +29,8 @@
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -41,6 +43,10 @@
 def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
     lr_finder = lr_find(trainer, model)
+
+    if lr_finder is None:
+        return
+
     lr = lr_finder.suggestion()
 
     # TODO: log lr.results to self.logger
@@ -130,7 +136,11 @@ def lr_find(
             trainer.fit(model)
 
     """
-    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp.ckpt')
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
+        return
+
+    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')
 
     __lr_finder_dump_params(trainer, model)
 
@@ -181,8 +191,11 @@ def lr_find(
     lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose
 
     # Reset model state
-    trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
-    os.remove(save_path)
+    if trainer.is_global_zero:
+        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
+        fs = get_filesystem(str(save_path))
+        if fs.exists(save_path):
+            fs.rm(save_path)
 
     # Finish by resetting variables so trainer is ready to fit model
     __lr_finder_restore_params(trainer, model)
```
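
Since `lr_find` can now return `None` under `fast_dev_run=True`, a manual caller needs the same `None` check that `_run_lr_finder_internally` gains above. A small illustrative sketch (the module-level `lr_find` and `EvalModelTemplate` come from this diff and the new test; the surrounding script is assumed):

```python
# Sketch only: guard a direct lr_find call against the new early return.
from pytorch_lightning import Trainer
from pytorch_lightning.tuner.lr_finder import lr_find
from tests.base import EvalModelTemplate

model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
trainer = Trainer(fast_dev_run=True)

lr_finder = lr_find(trainer, model)
if lr_finder is None:
    # fast_dev_run=True: the finder was skipped, so there is no suggestion to read
    print('learning rate finder skipped')
else:
    print('suggested lr:', lr_finder.suggestion())
```
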
Lines changed: 21 additions & 0 deletions
```diff
@@ -0,0 +1,21 @@
+import pytest
+from pytorch_lightning import Trainer
+from tests.base import EvalModelTemplate
+
+
+@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
+def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
+    """ Test that tuner algorithms are skipped if fast dev run is enabled """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        auto_scale_batch_size=True if tuner_alg == 'batch size scaler' else False,
+        auto_lr_find=True if tuner_alg == 'learning rate finder' else False,
+        fast_dev_run=True
+    )
+    expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
+    with pytest.warns(UserWarning, match=expected_message):
+        trainer.tune(model)
```
