Skip to content

Commit e401a08

Browse files
carmoccas-rogrohitgr7
authored andcommitted
Prune deprecated EarlyStopping(mode='auto') (Lightning-AI#6167)
Co-authored-by: Roger Shieh <[email protected]> Co-authored-by: Rohit Gupta <[email protected]>
1 parent 47f0192 commit e401a08

File tree

4 files changed

+13
-34
lines changed

4 files changed

+13
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3535
- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
3636

3737

38+
- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))
39+
40+
3841
### Fixed
3942

4043
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import torch
2424

2525
from pytorch_lightning.callbacks.base import Callback
26-
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
26+
from pytorch_lightning.utilities import rank_zero_warn
2727
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2828

2929

@@ -40,23 +40,18 @@ class EarlyStopping(Callback):
4040
patience: number of validation epochs with no improvement
4141
after which training will be stopped. Default: ``3``.
4242
verbose: verbosity mode. Default: ``False``.
43-
mode: one of {auto, min, max}. In `min` mode,
43+
mode: one of ``'min'``, ``'max'``. In ``'min'`` mode,
4444
training will stop when the quantity
45-
monitored has stopped decreasing; in `max`
45+
monitored has stopped decreasing and in ``'max'``
4646
mode it will stop when the quantity
47-
monitored has stopped increasing; in `auto`
48-
mode, the direction is automatically inferred
49-
from the name of the monitored quantity.
50-
51-
.. warning::
52-
Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3.
47+
monitored has stopped increasing.
5348
5449
strict: whether to crash the training if `monitor` is
5550
not found in the validation metrics. Default: ``True``.
5651
5752
Raises:
5853
MisconfigurationException:
59-
If ``mode`` is none of ``"min"``, ``"max"``, and ``"auto"``.
54+
If ``mode`` is none of ``"min"`` or ``"max"``.
6055
RuntimeError:
6156
If the metric ``monitor`` is not available.
6257
@@ -78,7 +73,7 @@ def __init__(
7873
min_delta: float = 0.0,
7974
patience: int = 3,
8075
verbose: bool = False,
81-
mode: str = 'auto',
76+
mode: str = 'min',
8277
strict: bool = True,
8378
):
8479
super().__init__()
@@ -92,31 +87,13 @@ def __init__(
9287
self.mode = mode
9388
self.warned_result_obj = False
9489

95-
self.__init_monitor_mode()
90+
if self.mode not in self.mode_dict:
91+
raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
9692

9793
self.min_delta *= 1 if self.monitor_op == torch.gt else -1
9894
torch_inf = torch.tensor(np.Inf)
9995
self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
10096

101-
def __init_monitor_mode(self):
102-
if self.mode not in self.mode_dict and self.mode != 'auto':
103-
raise MisconfigurationException(f"`mode` can be auto, {', '.join(self.mode_dict.keys())}, got {self.mode}")
104-
105-
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
106-
if self.mode == 'auto':
107-
rank_zero_warn(
108-
"mode='auto' is deprecated in v1.1 and will be removed in v1.3."
109-
" Default value for mode with be 'min' in v1.3.", DeprecationWarning
110-
)
111-
112-
if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
113-
self.mode = 'max'
114-
else:
115-
self.mode = 'min'
116-
117-
if self.verbose > 0:
118-
rank_zero_info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')
119-
12097
def _validate_condition_metric(self, logs):
12198
monitor_val = logs.get(self.monitor)
12299

tests/callbacks/test_early_stopping.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,13 @@ def validation_epoch_end(self, outputs):
334334
# Compute min_epochs latest step
335335
by_min_epochs = min_epochs * limit_train_batches
336336

337-
# Make sure the trainer stops for the max of all minimun requirements
337+
# Make sure the trainer stops for the max of all minimum requirements
338338
assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \
339339
(trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
340340

341341
_logger.disabled = False
342342

343343

344344
def test_early_stopping_mode_options():
345-
with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
345+
with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
346346
EarlyStopping(mode="unknown_option")

tests/deprecated_api/test_remove_1-3.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020

2121
def test_v1_3_0_deprecated_arguments(tmpdir):
22-
2322
with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"):
2423

2524
class DeprecatedHparamsModel(LightningModule):

0 commit comments

Comments
 (0)