Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,6 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
return

current = logs.get(self.monitor)

# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)

should_stop, reason = self._evaluate_stopping_criteria(current)

# stop every ddp process if any world process decides to stop
Expand Down
15 changes: 0 additions & 15 deletions pytorch_lightning/utilities/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class InternalDebugger:
def __init__(self, trainer: "pl.Trainer") -> None:
self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
self.trainer = trainer
self.early_stopping_history: List[Dict[str, Any]] = []
self.checkpoint_callback_history: List[Dict[str, Any]] = []
self.events: List[Dict[str, Any]] = []
self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
Expand Down Expand Up @@ -126,20 +125,6 @@ def track_lr_schedulers_update(
}
self.saved_lr_scheduler_updates.append(loss_dict)

@enabled_only
def track_early_stopping_history(
self, callback: "pl.callbacks.early_stopping.EarlyStopping", current: torch.Tensor
) -> None:
debug_dict = {
"epoch": self.trainer.current_epoch,
"global_step": self.trainer.global_step,
"rank": self.trainer.global_rank,
"current": current,
"best": callback.best_score,
"patience": callback.wait_count,
}
self.early_stopping_history.append(debug_dict)

@enabled_only
def track_checkpointing_history(self, filepath: str) -> None:
cb = self.trainer.checkpoint_callback
Expand Down
7 changes: 4 additions & 3 deletions tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import pickle
from typing import List, Optional
from unittest import mock
from unittest.mock import Mock

import cloudpickle
import numpy as np
Expand Down Expand Up @@ -98,25 +98,26 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
new_trainer.fit(model)


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_early_stopping_no_extraneous_invocations(tmpdir):
"""Test to ensure that callback methods aren't being invoked outside of the callback handler."""
model = ClassificationModel()
dm = ClassifDataModule()
early_stop_callback = EarlyStopping(monitor="train_loss")
early_stop_callback._run_early_stopping_check = Mock()
expected_count = 4
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stop_callback],
limit_train_batches=4,
limit_val_batches=4,
max_epochs=expected_count,
checkpoint_callback=False,
)
trainer.fit(model, datamodule=dm)

assert trainer.early_stopping_callback == early_stop_callback
assert trainer.early_stopping_callbacks == [early_stop_callback]
assert len(trainer.dev_debugger.early_stopping_history) == expected_count
assert early_stop_callback._run_early_stopping_check.call_count == expected_count


@pytest.mark.parametrize(
Expand Down
6 changes: 3 additions & 3 deletions tests/trainer/flags/test_fast_dev_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
Expand Down Expand Up @@ -29,7 +29,6 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):


@pytest.mark.parametrize("fast_dev_run", [1, 4])
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
"""
Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run
Expand Down Expand Up @@ -68,6 +67,7 @@ def test_step(self, batch, batch_idx):

checkpoint_callback = ModelCheckpoint()
early_stopping_callback = EarlyStopping()
early_stopping_callback._evaluate_stopping_criteria = Mock()
trainer_config = dict(
default_root_dir=tmpdir,
fast_dev_run=fast_dev_run,
Expand Down Expand Up @@ -102,7 +102,7 @@ def _make_fast_dev_run_assertions(trainer, model):

# early stopping should not have been called with fast_dev_run
assert trainer.early_stopping_callback == early_stopping_callback
assert len(trainer.dev_debugger.early_stopping_history) == 0
early_stopping_callback._evaluate_stopping_criteria.assert_not_called()

train_val_step_model = FastDevRunModel()
trainer = Trainer(**trainer_config)
Expand Down