From 23f5615b01391075e25a9f555194ad7daf95c61a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Sep 2021 02:58:38 +0200 Subject: [PATCH 1/2] replace dev debugger in early stopping --- pytorch_lightning/callbacks/early_stopping.py | 4 ---- pytorch_lightning/utilities/debugging.py | 15 --------------- tests/callbacks/test_early_stopping.py | 6 ++++-- tests/trainer/flags/test_fast_dev_run.py | 5 +++-- 4 files changed, 7 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4623b6077dbb9..ecb46eab446d4 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -206,10 +206,6 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: return current = logs.get(self.monitor) - - # when in dev debugging - trainer.dev_debugger.track_early_stopping_history(self, current) - should_stop, reason = self._evaluate_stopping_criteria(current) # stop every ddp process if any world process decides to stop diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index f49463e09a76b..e8942b730d21a 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -43,7 +43,6 @@ class InternalDebugger: def __init__(self, trainer: "pl.Trainer") -> None: self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1" self.trainer = trainer - self.early_stopping_history: List[Dict[str, Any]] = [] self.checkpoint_callback_history: List[Dict[str, Any]] = [] self.events: List[Dict[str, Any]] = [] self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = [] @@ -126,20 +125,6 @@ def track_lr_schedulers_update( } self.saved_lr_scheduler_updates.append(loss_dict) - @enabled_only - def track_early_stopping_history( - self, callback: "pl.callbacks.early_stopping.EarlyStopping", current: torch.Tensor - ) -> None: - debug_dict = { - "epoch": self.trainer.current_epoch, - "global_step": self.trainer.global_step, - "rank": self.trainer.global_rank, - "current": current, - "best": callback.best_score, - "patience": callback.wait_count, - } - self.early_stopping_history.append(debug_dict) - @enabled_only def track_checkpointing_history(self, filepath: str) -> None: cb = self.trainer.checkpoint_callback diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index ccc2ca24bf669..e9961c96efac7 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -16,6 +16,7 @@ import pickle from typing import List, Optional from unittest import mock +from unittest.mock import Mock import cloudpickle import numpy as np @@ -98,12 +99,12 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): new_trainer.fit(model) -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_early_stopping_no_extraneous_invocations(tmpdir): """Test to ensure that callback methods aren't being invoked outside of the callback handler.""" model = ClassificationModel() dm = ClassifDataModule() early_stop_callback = EarlyStopping(monitor="train_loss") + early_stop_callback._run_early_stopping_check = Mock() expected_count = 4 trainer = Trainer( default_root_dir=tmpdir, @@ -111,12 +112,13 @@ def test_early_stopping_no_extraneous_invocations(tmpdir): limit_train_batches=4, limit_val_batches=4, max_epochs=expected_count, + checkpoint_callback=False, ) trainer.fit(model, datamodule=dm) assert trainer.early_stopping_callback == early_stop_callback assert trainer.early_stopping_callbacks == [early_stop_callback] - assert len(trainer.dev_debugger.early_stopping_history) == expected_count + assert early_stop_callback._run_early_stopping_check.call_count == expected_count @pytest.mark.parametrize( diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index cff0c8a43727d..5cd752ef75379 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -1,5 +1,6 @@ import os from unittest import mock +from unittest.mock import Mock import pytest import torch @@ -29,7 +30,6 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): @pytest.mark.parametrize("fast_dev_run", [1, 4]) -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run): """ Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run @@ -68,6 +68,7 @@ def test_step(self, batch, batch_idx): checkpoint_callback = ModelCheckpoint() early_stopping_callback = EarlyStopping() + early_stopping_callback._evaluate_stopping_criteria = Mock() trainer_config = dict( default_root_dir=tmpdir, fast_dev_run=fast_dev_run, @@ -102,7 +103,7 @@ def _make_fast_dev_run_assertions(trainer, model): # early stopping should not have been called with fast_dev_run assert trainer.early_stopping_callback == early_stopping_callback - assert len(trainer.dev_debugger.early_stopping_history) == 0 + early_stopping_callback._evaluate_stopping_criteria.assert_not_called() train_val_step_model = FastDevRunModel() trainer = Trainer(**trainer_config) From b130961a4317b582a5542238feea1ff78c9bc8ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 5 Sep 2021 13:04:44 +0200 Subject: [PATCH 2/2] remove unused imports --- tests/callbacks/test_early_stopping.py | 1 - tests/trainer/flags/test_fast_dev_run.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index e9961c96efac7..049e409842564 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import os import pickle from typing import List, Optional from unittest import mock diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index 5cd752ef75379..a6e54d2fd1738 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -1,5 +1,4 @@ import os -from unittest import mock from unittest.mock import Mock import pytest