2323import torch
2424
2525from pytorch_lightning .callbacks .base import Callback
26- from pytorch_lightning .utilities import rank_zero_info , rank_zero_warn
26+ from pytorch_lightning .utilities import rank_zero_warn
2727from pytorch_lightning .utilities .exceptions import MisconfigurationException
2828
2929
@@ -40,23 +40,18 @@ class EarlyStopping(Callback):
4040 patience: number of validation epochs with no improvement
4141 after which training will be stopped. Default: ``3``.
4242 verbose: verbosity mode. Default: ``False``.
43- mode: one of {auto, min, max} . In `min` mode,
43+ mode: one of ``' min'``, ``' max'`` . In ``' min'` ` mode,
4444 training will stop when the quantity
45- monitored has stopped decreasing; in `max`
45+ monitored has stopped decreasing and in ``' max'` `
4646 mode it will stop when the quantity
47- monitored has stopped increasing; in `auto`
48- mode, the direction is automatically inferred
49- from the name of the monitored quantity.
50-
51- .. warning::
52- Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3.
47+ monitored has stopped increasing.
5348
5449 strict: whether to crash the training if `monitor` is
5550 not found in the validation metrics. Default: ``True``.
5651
5752 Raises:
5853 MisconfigurationException:
59- If ``mode`` is none of ``"min"``, ``"max"``, and ``"auto "``.
54+ If ``mode`` is none of ``"min"`` or ``"max "``.
6055 RuntimeError:
6156 If the metric ``monitor`` is not available.
6257
@@ -78,7 +73,7 @@ def __init__(
7873 min_delta : float = 0.0 ,
7974 patience : int = 3 ,
8075 verbose : bool = False ,
81- mode : str = 'auto ' ,
76+ mode : str = 'min ' ,
8277 strict : bool = True ,
8378 ):
8479 super ().__init__ ()
@@ -92,31 +87,13 @@ def __init__(
9287 self .mode = mode
9388 self .warned_result_obj = False
9489
95- self .__init_monitor_mode ()
90+ if self .mode not in self .mode_dict :
91+ raise MisconfigurationException (f"`mode` can be { ', ' .join (self .mode_dict .keys ())} , got { self .mode } " )
9692
9793 self .min_delta *= 1 if self .monitor_op == torch .gt else - 1
9894 torch_inf = torch .tensor (np .Inf )
9995 self .best_score = torch_inf if self .monitor_op == torch .lt else - torch_inf
10096
101- def __init_monitor_mode (self ):
102- if self .mode not in self .mode_dict and self .mode != 'auto' :
103- raise MisconfigurationException (f"`mode` can be auto, { ', ' .join (self .mode_dict .keys ())} , got { self .mode } " )
104-
105- # TODO: Update with MisconfigurationException when auto mode is removed in v1.3
106- if self .mode == 'auto' :
107- rank_zero_warn (
108- "mode='auto' is deprecated in v1.1 and will be removed in v1.3."
109- " Default value for mode with be 'min' in v1.3." , DeprecationWarning
110- )
111-
112- if "acc" in self .monitor or self .monitor .startswith ("fmeasure" ):
113- self .mode = 'max'
114- else :
115- self .mode = 'min'
116-
117- if self .verbose > 0 :
118- rank_zero_info (f'EarlyStopping mode set to { self .mode } for monitoring { self .monitor } .' )
119-
12097 def _validate_condition_metric (self , logs ):
12198 monitor_val = logs .get (self .monitor )
12299
0 commit comments