Skip to content

Commit f9132e8

Browse files
authored
remove early stopping tracking from internal debugger (#9327)
* replace dev debugger in early stopping * remove unused imports
1 parent dc3391b commit f9132e8

File tree

4 files changed

+7
-25
lines changed

4 files changed

+7
-25
lines changed

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,6 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
206206
return
207207

208208
current = logs.get(self.monitor)
209-
210-
# when in dev debugging
211-
trainer.dev_debugger.track_early_stopping_history(self, current)
212-
213209
should_stop, reason = self._evaluate_stopping_criteria(current)
214210

215211
# stop every ddp process if any world process decides to stop

pytorch_lightning/utilities/debugging.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ class InternalDebugger:
4343
def __init__(self, trainer: "pl.Trainer") -> None:
4444
self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
4545
self.trainer = trainer
46-
self.early_stopping_history: List[Dict[str, Any]] = []
4746
self.checkpoint_callback_history: List[Dict[str, Any]] = []
4847
self.events: List[Dict[str, Any]] = []
4948
self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
@@ -126,20 +125,6 @@ def track_lr_schedulers_update(
126125
}
127126
self.saved_lr_scheduler_updates.append(loss_dict)
128127

129-
@enabled_only
130-
def track_early_stopping_history(
131-
self, callback: "pl.callbacks.early_stopping.EarlyStopping", current: torch.Tensor
132-
) -> None:
133-
debug_dict = {
134-
"epoch": self.trainer.current_epoch,
135-
"global_step": self.trainer.global_step,
136-
"rank": self.trainer.global_rank,
137-
"current": current,
138-
"best": callback.best_score,
139-
"patience": callback.wait_count,
140-
}
141-
self.early_stopping_history.append(debug_dict)
142-
143128
@enabled_only
144129
def track_checkpointing_history(self, filepath: str) -> None:
145130
cb = self.trainer.checkpoint_callback

tests/callbacks/test_early_stopping.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15-
import os
1615
import pickle
1716
from typing import List, Optional
1817
from unittest import mock
18+
from unittest.mock import Mock
1919

2020
import cloudpickle
2121
import numpy as np
@@ -98,25 +98,26 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
9898
new_trainer.fit(model)
9999

100100

101-
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
102101
def test_early_stopping_no_extraneous_invocations(tmpdir):
103102
"""Test to ensure that callback methods aren't being invoked outside of the callback handler."""
104103
model = ClassificationModel()
105104
dm = ClassifDataModule()
106105
early_stop_callback = EarlyStopping(monitor="train_loss")
106+
early_stop_callback._run_early_stopping_check = Mock()
107107
expected_count = 4
108108
trainer = Trainer(
109109
default_root_dir=tmpdir,
110110
callbacks=[early_stop_callback],
111111
limit_train_batches=4,
112112
limit_val_batches=4,
113113
max_epochs=expected_count,
114+
checkpoint_callback=False,
114115
)
115116
trainer.fit(model, datamodule=dm)
116117

117118
assert trainer.early_stopping_callback == early_stop_callback
118119
assert trainer.early_stopping_callbacks == [early_stop_callback]
119-
assert len(trainer.dev_debugger.early_stopping_history) == expected_count
120+
assert early_stop_callback._run_early_stopping_check.call_count == expected_count
120121

121122

122123
@pytest.mark.parametrize(

tests/trainer/flags/test_fast_dev_run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from unittest import mock
2+
from unittest.mock import Mock
33

44
import pytest
55
import torch
@@ -29,7 +29,6 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
2929

3030

3131
@pytest.mark.parametrize("fast_dev_run", [1, 4])
32-
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
3332
def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
3433
"""
3534
Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run
@@ -68,6 +67,7 @@ def test_step(self, batch, batch_idx):
6867

6968
checkpoint_callback = ModelCheckpoint()
7069
early_stopping_callback = EarlyStopping()
70+
early_stopping_callback._evaluate_stopping_criteria = Mock()
7171
trainer_config = dict(
7272
default_root_dir=tmpdir,
7373
fast_dev_run=fast_dev_run,
@@ -102,7 +102,7 @@ def _make_fast_dev_run_assertions(trainer, model):
102102

103103
# early stopping should not have been called with fast_dev_run
104104
assert trainer.early_stopping_callback == early_stopping_callback
105-
assert len(trainer.dev_debugger.early_stopping_history) == 0
105+
early_stopping_callback._evaluate_stopping_criteria.assert_not_called()
106106

107107
train_val_step_model = FastDevRunModel()
108108
trainer = Trainer(**trainer_config)

0 commit comments

Comments
 (0)