4 files changed: +36 -2 lines changed
@@ -30,12 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458))


-- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
+- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))


 ### Changed

-
+- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))

 ### Deprecated

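The new `Changed` entry is the user-facing summary of this PR: with `fast_dev_run=True`, `Trainer.tune()` skips both tuner algorithms instead of running them. A minimal sketch of how that looks from user code; everything here except the `Trainer` flags and `trainer.tune(model)` (both taken from this diff) is an illustrative stand-in:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Illustrative stand-in model (not part of this PR) used to exercise the tuner."""

    def __init__(self, batch_size: int = 2, learning_rate: float = 1e-3):
        super().__init__()
        self.batch_size = batch_size        # read by the batch size scaler
        self.learning_rate = learning_rate  # read by the learning rate finder
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=self.batch_size)


trainer = Trainer(
    fast_dev_run=True,           # single-batch smoke test
    auto_scale_batch_size=True,  # would normally trigger the batch size scaler
    auto_lr_find=True,           # would normally trigger the learning rate finder
)

# With this PR, tune() warns once per tuner algorithm and returns early,
# so no batch size or learning rate search is performed.
trainer.tune(TinyModel())
```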
@@ -69,6 +69,10 @@ def scale_batch_size(trainer,
         **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
             or datamodule.
     """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
+        return
+
     if not lightning_hasattr(model, batch_arg_name):
         raise MisconfigurationException(
             f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
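One consequence of the early return added above is that `scale_batch_size` now produces `None` rather than a suggested batch size whenever `fast_dev_run=True`. A hedged sketch of how a caller might tolerate that; the call mirrors the `scale_batch_size(trainer, model, ...)` signature in this hunk, while the handling of the result is purely illustrative:

```python
# Illustrative caller-side handling of the new early return (not code from this PR).
new_size = scale_batch_size(trainer, model)

if new_size is None:
    # fast_dev_run=True: the scaler was skipped and emitted a UserWarning instead.
    print('batch size scaling skipped (fast_dev_run=True)')
else:
    print(f'suggested batch size: {new_size}')
```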
@@ -29,6 +29,7 @@
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem

 # check if ipywidgets is installed before importing tqdm.auto
@@ -42,6 +43,10 @@
 def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
     lr_finder = lr_find(trainer, model)
+
+    if lr_finder is None:
+        return
+
     lr = lr_finder.suggestion()

     # TODO: log lr.results to self.logger
@@ -131,6 +136,10 @@ def lr_find(
         trainer.fit(model)

     """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
+        return
+
     save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

     __lr_finder_dump_params(trainer, model)
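`lr_find` likewise returns `None` under `fast_dev_run=True`, which is exactly why `_run_lr_finder_internally` above gained its `if lr_finder is None: return` guard. Callers outside the trainer need the same check before reading a suggestion. A minimal sketch, assuming the module-level `lr_find(trainer, model)` and `lr_finder.suggestion()` shown in this diff; where the suggested value ends up is illustrative:

```python
# Illustrative guard mirroring _run_lr_finder_internally (not code from this PR).
lr_finder = lr_find(trainer, model)

if lr_finder is None:
    # fast_dev_run=True: the finder was skipped and emitted a UserWarning instead.
    print('learning rate finder skipped (fast_dev_run=True)')
else:
    suggested_lr = lr_finder.suggestion()
    # Assumption: the model stores its learning rate in `learning_rate`;
    # adapt this to wherever your model actually keeps it.
    model.learning_rate = suggested_lr
```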
@@ -0,0 +1,21 @@
+import pytest
+from pytorch_lightning import Trainer
+from tests.base import EvalModelTemplate
+
+
+@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
+def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
+    """ Test that tuner algorithms are skipped if fast dev run is enabled """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        auto_scale_batch_size=True if tuner_alg == 'batch size scaler' else False,
+        auto_lr_find=True if tuner_alg == 'learning rate finder' else False,
+        fast_dev_run=True
+    )
+    expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
+    with pytest.warns(UserWarning, match=expected_message):
+        trainer.tune(model)