From 26cf61dfd3a8b06e5a53f8fb524c426d7e9b24fa Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 6 Oct 2020 16:37:50 +0200
Subject: [PATCH 01/10] skip on fast dev

---
 pytorch_lightning/tuner/lr_finder.py |  3 +++
 pytorch_lightning/tuner/tuning.py    |  9 +++++++++
 tests/trainer/test_lr_finder.py      | 16 ++++++++++++++++
 tests/trainer/test_trainer_tricks.py | 16 ++++++++++++++++
 4 files changed, 44 insertions(+)

diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index a3ba2550186a7..3fc62a89351cb 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -27,6 +27,7 @@
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
+from pytorch_lightning.utilities import rank_zero_warn
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -38,6 +39,8 @@ def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
     lr_finder = lr_find(trainer, model)
     lr = lr_finder.suggestion()
 
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 8c55ffac92c6a..4ea44562b18af 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -16,6 +16,7 @@
 from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.utilities import rank_zero_warn
 from typing import Optional, List, Union
 from torch.utils.data import DataLoader
 
@@ -37,6 +38,9 @@ def scale_batch_size(self,
                          max_trials: int = 25,
                          batch_arg_name: str = 'batch_size',
                          **fit_kwargs):
+        if self.trainer.fast_dev_run:
+            rank_zero_warn('Skipping batch size scaler `fast_dev_run=True`', UserWarning)
+            return
         return scale_batch_size(
             self.trainer, model, mode, steps_per_trial, init_val, max_trials, batch_arg_name, **fit_kwargs
         )
@@ -53,6 +57,11 @@ def lr_find(
             early_stop_threshold: float = 4.0,
             datamodule: Optional[LightningDataModule] = None
     ):
+        import pdb
+        pdb.set_trace()
+        if self.trainer.fast_dev_run:
+            rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
+            return
         return lr_find(
             self.trainer,
             model,
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index 67c673df1318d..daef7c245eeb7 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -241,3 +241,19 @@ def test_suggestion_with_non_finite_values(tmpdir):
 
     assert before_lr == after_lr, \
         'Learning rate was altered because of non-finite loss values'
+
+
+def test_skip_on_fast_dev_run_lr_find(tmpdir):
+    """ Test that learning rate finder is skipped if fast dev run is enabled """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        auto_lr_find=True,
+        fast_dev_run=True
+    )
+    expected_message = 'Skipping learning rate finder since `fast_dev_run=True`'
+    with pytest.warns(UserWarning, match=expected_message):
+        trainer.tune(model)
diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index a9297576c6f14..1837519e98261 100755
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -328,3 +328,19 @@ def test_auto_scale_batch_size_with_amp(tmpdir):
     assert trainer.amp_backend == AMPType.NATIVE
     assert trainer.scaler is not None
     assert batch_size_after != batch_size_before
+
+
+def test_skip_on_fast_dev_run_batch_scaler(tmpdir):
+    """ Test that batch scaler is skipped if fast dev run is enabled """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        auto_scale_batch_size=True,
+        fast_dev_run=True
+    )
+    expected_message = 'Skipping batch size scaler `fast_dev_run=True`'
+    with pytest.warns(UserWarning, match=expected_message):
+        trainer.tune(model)

From 51b978adc907145c0876733f6c449b5d5eca4ad5 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 6 Oct 2020 16:45:06 +0200
Subject: [PATCH 02/10] fix error

---
 pytorch_lightning/tuner/tuning.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 72c766ce7e5d1..4c685380b644a 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -93,8 +93,6 @@ def lr_find(
             early_stop_threshold: float = 4.0,
             datamodule: Optional[LightningDataModule] = None
     ):
-        import pdb
-        pdb.set_trace()
         if self.trainer.fast_dev_run:
             rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
             return

From 3bdc99b48eb6f6cb10e3d1846f164e3a1dfc049b Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 6 Oct 2020 16:47:18 +0200
Subject: [PATCH 03/10] changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 064bd756a2891..96bf5f22d411c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -55,6 +55,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Swap `torch.load` for `fsspec` load in cloud_io loading ([#3692](https://github.com/PyTorchLightning/pytorch-lightning/pull/3692))
 
+- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
+
 ### Deprecated
 
 - Rename Trainer arguments `row_log_interval` >> `log_every_n_steps` and `log_save_interval` >> `flush_logs_every_n_steps` ([#3748](https://github.com/PyTorchLightning/pytorch-lightning/pull/3748))

From 8d7718498114f816879c881e5f83bfa377cde91c Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 6 Oct 2020 15:25:23 -0400
Subject: [PATCH 04/10] fix recursive issue

---
 tests/trainer/flags/test_fast_dev_run.py | 19 +++++++++++++++++++
 tests/trainer/test_trainer_tricks.py     | 16 ----------------
 2 files changed, 19 insertions(+), 16 deletions(-)
 create mode 100644 tests/trainer/flags/test_fast_dev_run.py

diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py
new file mode 100644
index 0000000000000..dd19491e8bde7
--- /dev/null
+++ b/tests/trainer/flags/test_fast_dev_run.py
@@ -0,0 +1,19 @@
+import pytest
+from pytorch_lightning import Trainer
+from tests.base import EvalModelTemplate
+
+
+def test_skip_on_fast_dev_run_batch_scaler(tmpdir):
+    """ Test that batch scaler is skipped if fast dev run is enabled """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        auto_scale_batch_size=True,
+        fast_dev_run=True
+    )
+    expected_message = 'Skipping batch size scaler `fast_dev_run=True`'
+    with pytest.warns(UserWarning, match=expected_message):
+        trainer.tune(model)
diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index 1837519e98261..a9297576c6f14 100755
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -328,19 +328,3 @@ def test_auto_scale_batch_size_with_amp(tmpdir):
     assert trainer.amp_backend == AMPType.NATIVE
     assert trainer.scaler is not None
     assert batch_size_after != batch_size_before
-
-
-def test_skip_on_fast_dev_run_batch_scaler(tmpdir):
-    """ Test that batch scaler is skipped if fast dev run is enabled """
-
-    hparams = EvalModelTemplate.get_default_hparams()
-    model = EvalModelTemplate(**hparams)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=2,
-        auto_scale_batch_size=True,
-        fast_dev_run=True
-    )
-    expected_message = 'Skipping batch size scaler `fast_dev_run=True`'
-    with pytest.warns(UserWarning, match=expected_message):
-        trainer.tune(model)

From 09c89451ded35bbff85d9f3dd958c7f5bf3f8b2d Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 7 Oct 2020 07:53:15 +0200
Subject: [PATCH 05/10] combine tests

---
 pytorch_lightning/tuner/tuning.py        |  2 +-
 tests/trainer/flags/test_fast_dev_run.py | 10 ++++++----
 tests/trainer/test_lr_finder.py          | 16 ----------------
 3 files changed, 7 insertions(+), 21 deletions(-)

diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 4c685380b644a..f6a1a6afa1bdf 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -75,7 +75,7 @@ def scale_batch_size(self,
 
         """
         if self.trainer.fast_dev_run:
-            rank_zero_warn('Skipping batch size scaler `fast_dev_run=True`', UserWarning)
+            rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
             return
         return scale_batch_size(
             self.trainer, model, mode, steps_per_trial, init_val, max_trials, batch_arg_name, **fit_kwargs
diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py
index dd19491e8bde7..727614ecf1ef6 100644
--- a/tests/trainer/flags/test_fast_dev_run.py
+++ b/tests/trainer/flags/test_fast_dev_run.py
@@ -2,8 +2,8 @@
 from pytorch_lightning import Trainer
 from tests.base import EvalModelTemplate
 
-
-def test_skip_on_fast_dev_run_batch_scaler(tmpdir):
+@pytest.mark.parametrize('tuner_alg', ['scale_batch_size', 'lr_find'])
+def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
     """ Test that batch scaler is skipped if fast dev run is enabled """
 
     hparams = EvalModelTemplate.get_default_hparams()
@@ -11,9 +11,11 @@ def test_skip_on_fast_dev_run_batch_scaler(tmpdir):
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=2,
-        auto_scale_batch_size=True,
+        auto_scale_batch_size=True if tuner_alg=='scale_batch_size' else False,
+        auto_lr_find=True if tuner_alg=='lr_find' else False,
         fast_dev_run=True
     )
-    expected_message = 'Skipping batch size scaler `fast_dev_run=True`'
+    alg = 'batch size scaler' if tuner_alg=='scale_batch_size' else 'learning rate finder'
+    expected_message = f'Skipping {alg} since `fast_dev_run=True`'
     with pytest.warns(UserWarning, match=expected_message):
         trainer.tune(model)
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index 272c722d355bf..67eb480a71c61 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -244,19 +244,3 @@ def test_suggestion_with_non_finite_values(tmpdir):
 
     assert before_lr == after_lr, \
         'Learning rate was altered because of non-finite loss values'
-
-
-def test_skip_on_fast_dev_run_lr_find(tmpdir):
-    """ Test that learning rate finder is skipped if fast dev run is enabled """
-
-    hparams = EvalModelTemplate.get_default_hparams()
-    model = EvalModelTemplate(**hparams)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=2,
-        auto_lr_find=True,
-        fast_dev_run=True
-    )
-    expected_message = 'Skipping learning rate finder since `fast_dev_run=True`'
-    with pytest.warns(UserWarning, match=expected_message):
-        trainer.tune(model)

From 07d2f8e6b74f3dc22394137000155d6f5f8a9ad5 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 7 Oct 2020 07:55:55 +0200
Subject: [PATCH 06/10] pep8

---
 tests/trainer/flags/test_fast_dev_run.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py
index 727614ecf1ef6..cbe4d4012227a 100644
--- a/tests/trainer/flags/test_fast_dev_run.py
+++ b/tests/trainer/flags/test_fast_dev_run.py
@@ -2,20 +2,20 @@
 from pytorch_lightning import Trainer
 from tests.base import EvalModelTemplate
 
-@pytest.mark.parametrize('tuner_alg', ['scale_batch_size', 'lr_find'])
+
+@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
 def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
-    """ Test that batch scaler is skipped if fast dev run is enabled """
+    """ Test that tuner algorithms are skipped if fast dev run is enabled """
 
     hparams = EvalModelTemplate.get_default_hparams()
     model = EvalModelTemplate(**hparams)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=2,
-        auto_scale_batch_size=True if tuner_alg == 'scale_batch_size' else False,
-        auto_lr_find=True if tuner_alg == 'lr_find' else False,
+        auto_scale_batch_size=True if tuner_alg == 'batch size scaler' else False,
+        auto_lr_find=True if tuner_alg == 'learning rate finder' else False,
         fast_dev_run=True
     )
-    alg = 'batch size scaler' if tuner_alg=='scale_batch_size' else 'learning rate finder'
-    expected_message = f'Skipping {alg} since `fast_dev_run=True`'
+    expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
     with pytest.warns(UserWarning, match=expected_message):
         trainer.tune(model)

From d0bfab253ef6d69042f146ec1836d1acd7ba3758 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 9 Oct 2020 10:52:21 +0200
Subject: [PATCH 07/10] move logic to base funcs

---
 pytorch_lightning/tuner/batch_size_scaling.py | 4 ++++
 pytorch_lightning/tuner/lr_finder.py          | 6 ++++--
 pytorch_lightning/tuner/tuning.py             | 6 ------
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 8b2e05c66b753..046a25dd5427d 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -67,6 +67,10 @@ def scale_batch_size(trainer,
         **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
             or datamodule.
     """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping batch size scaler `fast_dev_run=True`', UserWarning)
+        return
+
     if not lightning_hasattr(model, batch_arg_name):
         raise MisconfigurationException(
             f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 5554272d1c223..2a47cdd329806 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -41,8 +41,6 @@ def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
-    if trainer.fast_dev_run:
-        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
     lr_finder = lr_find(trainer, model)
     lr = lr_finder.suggestion()
 
@@ -133,6 +131,10 @@ def lr_find(
             trainer.fit(model)
 
     """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
+        return
+
     save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp.ckpt')
 
     __lr_finder_dump_params(trainer, model)
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index 4c685380b644a..71d8b24e44afc 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -74,9 +74,6 @@ def scale_batch_size(self,
                 or datamodule.
 
         """
-        if self.trainer.fast_dev_run:
-            rank_zero_warn('Skipping batch size scaler `fast_dev_run=True`', UserWarning)
-            return
         return scale_batch_size(
             self.trainer, model, mode, steps_per_trial, init_val, max_trials, batch_arg_name, **fit_kwargs
         )
@@ -93,9 +90,6 @@ def lr_find(
             early_stop_threshold: float = 4.0,
             datamodule: Optional[LightningDataModule] = None
     ):
-        if self.trainer.fast_dev_run:
-            rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
-            return
         return lr_find(
             self.trainer,
             model,

From 4e1574541b5c048362e7de7a597974f28422512a Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 9 Oct 2020 10:58:12 +0200
Subject: [PATCH 08/10] fix mistake

---
 pytorch_lightning/tuner/batch_size_scaling.py | 2 +-
 pytorch_lightning/tuner/lr_finder.py          | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 41618fa393ec1..51be258ff267f 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -68,7 +68,7 @@ def scale_batch_size(trainer,
             or datamodule.
     """
     if trainer.fast_dev_run:
-        rank_zero_warn('Skipping batch size scaler `fast_dev_run=True`', UserWarning)
+        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
         return
 
     if not lightning_hasattr(model, batch_arg_name):
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 2a47cdd329806..3c8a6532d03af 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -41,6 +41,9 @@ def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
+    if trainer.fast_dev_run:
+        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
+        return
     lr_finder = lr_find(trainer, model)
     lr = lr_finder.suggestion()

From dd94000b00732ac7a4255407d9b26035250668c9 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 9 Oct 2020 13:02:40 +0200
Subject: [PATCH 09/10] Update pytorch_lightning/tuner/lr_finder.py

Co-authored-by: Rohit Gupta
---
 pytorch_lightning/tuner/lr_finder.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 3c8a6532d03af..d36c5aad5bfd8 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -41,10 +41,11 @@ def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
-    if trainer.fast_dev_run:
-        rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
-        return
     lr_finder = lr_find(trainer, model)
+
+    if lr_finder is None:
+        return
+
     lr = lr_finder.suggestion()
 
     # TODO: log lr.results to self.logger

From 8deadadfa6f65110e70cdee722b2575db3aa8837 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Tue, 3 Nov 2020 22:49:09 +0530
Subject: [PATCH 10/10] pep

---
 pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index e32e9188d675a..b6d8c8178093b 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -43,7 +43,7 @@ def _run_lr_finder_internally(trainer, model: LightningModule):
     """ Call lr finder internally during Trainer.fit() """
     lr_finder = lr_find(trainer, model)
-    
+
     if lr_finder is None:
         return
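
Taken together, the series makes `trainer.tune()` a no-op under `fast_dev_run=True`: `scale_batch_size` and `lr_find` emit a `UserWarning` and return `None` instead of running. A minimal sketch of that end state, mirroring the parametrized test in tests/trainer/flags/test_fast_dev_run.py (it assumes the same 1.0-era `Trainer` API and the `EvalModelTemplate` test helper used above):

    import pytest
    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate

    # Build the same toy model the tests above use.
    model = EvalModelTemplate(**EvalModelTemplate.get_default_hparams())
    trainer = Trainer(default_root_dir='.', auto_lr_find=True, fast_dev_run=True)

    # The learning rate finder is skipped with a warning, so no lr suggestion is applied.
    with pytest.warns(UserWarning, match='Skipping learning rate finder since `fast_dev_run=True`'):
        trainer.tune(model)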