Skip to content

Commit f357417

Browse files
authored
Change trainer.should_stop to not stop in between an epoch and run until min_steps/min_epochs only (#13890)
1 parent 0e30e4a commit f357417

File tree

9 files changed

+176
-27
lines changed

9 files changed

+176
-27
lines changed

docs/source-pytorch/common/trainer.rst

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,3 +1745,63 @@ execution within that function, and the status of the Trainer.
17451745
trainer.state.status
17461746
# stage in ("train", "sanity_check", "validate", "test", "predict", "tune")
17471747
trainer.state.stage
1748+
1749+
should_stop
1750+
***********
1751+
1752+
If you want to terminate the training during ``.fit``, you can set ``trainer.should_stop=True`` to terminate the training
1753+
as soon as possible. Note that it will respect the arguments ``min_steps`` and ``min_epochs`` to check whether to stop. If these
1754+
arguments are set and the ``current_epoch`` or ``global_step`` don't meet these minimum conditions, training will continue until
1755+
both conditions are met. If any of these arguments is not set, it won't be considered for the final decision.
1756+
1757+
1758+
.. code-block:: python
1759+
1760+
# setting `trainer.should_stop` at any point of training will terminate it
1761+
class LitModel(LightningModule):
1762+
def training_step(self, *args, **kwargs):
1763+
self.trainer.should_stop = True
1764+
1765+
1766+
trainer = Trainer()
1767+
model = LitModel()
1768+
trainer.fit(model)
1769+
1770+
.. code-block:: python
1771+
1772+
# setting `trainer.should_stop` will stop training only after at least 5 epochs have run
1773+
class LitModel(LightningModule):
1774+
def training_step(self, *args, **kwargs):
1775+
if self.current_epoch == 2:
1776+
self.trainer.should_stop = True
1777+
1778+
1779+
trainer = Trainer(min_epochs=5, max_epochs=100)
1780+
model = LitModel()
1781+
trainer.fit(model)
1782+
1783+
.. code-block:: python
1784+
1785+
# setting `trainer.should_stop` will stop training only after at least 5 steps have run
1786+
class LitModel(LightningModule):
1787+
def training_step(self, *args, **kwargs):
1788+
if self.global_step == 2:
1789+
self.trainer.should_stop = True
1790+
1791+
1792+
trainer = Trainer(min_steps=5, max_epochs=100)
1793+
model = LitModel()
1794+
trainer.fit(model)
1795+
1796+
.. code-block:: python
1797+
1798+
# setting `trainer.should_stop` will stop the training, but only once both min_steps and min_epochs are satisfied
1799+
class LitModel(LightningModule):
1800+
def training_step(self, *args, **kwargs):
1801+
if self.global_step == 7:
1802+
self.trainer.should_stop = True
1803+
1804+
1805+
trainer = Trainer(min_steps=5, min_epochs=5, max_epochs=100)
1806+
model = LitModel()
1807+
trainer.fit(model)

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4545
- Included `torch.cuda` rng state to the aggregate `_collect_rng_states()` and `_set_rng_states()` ([#14384](https://github.com/Lightning-AI/lightning/pull/14384))
4646

4747

48+
- Changed `trainer.should_stop` to not stop in between an epoch and run until `min_steps/min_epochs` only ([#13890](https://github.com/Lightning-AI/lightning/pull/13890))
49+
50+
4851
- When using multiple loggers, by default checkpoints and profiler output now get saved to the log dir of the first logger in the list ([14325](https://github.com/Lightning-AI/lightning/pull/14325))
4952

5053

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,21 @@ def _is_validation_done(self) -> bool:
102102
@property
103103
def done(self) -> bool:
104104
"""Evaluates when to leave the loop."""
105-
return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop
105+
if self._is_training_done and self._is_validation_done:
106+
return True
107+
108+
if self.trainer.should_stop:
109+
# early stopping
110+
min_epochs = self.trainer.fit_loop.min_epochs
111+
should_stop_early = self.trainer.fit_loop._should_stop_early
112+
if not should_stop_early:
113+
self._warning_cache.info(
114+
f"Trainer was signaled to stop but the required `min_epochs={min_epochs!r}` or"
115+
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
116+
)
117+
return should_stop_early
118+
119+
return False
106120

107121
def connect( # type: ignore[override]
108122
self,

src/pytorch_lightning/loops/fit_loop.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ def _results(self) -> _ResultCollection:
146146
return self.epoch_loop.val_loop._results
147147
raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")
148148

149+
@property
150+
def _should_stop_early(self) -> bool:
151+
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
152+
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
153+
return met_min_epochs and met_min_steps
154+
149155
@property
150156
def done(self) -> bool:
151157
"""Evaluates when to leave the loop."""
@@ -169,20 +175,10 @@ def done(self) -> bool:
169175
rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
170176
return True
171177

172-
if self.trainer.should_stop:
173-
# early stopping
174-
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
175-
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
176-
if met_min_epochs and met_min_steps:
177-
self.trainer.should_stop = True
178-
rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
179-
return True
180-
else:
181-
rank_zero_info(
182-
f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or"
183-
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
184-
)
185-
self.trainer.should_stop = False
178+
if self.trainer.should_stop and self._should_stop_early:
179+
rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
180+
return True
181+
186182
return False
187183

188184
@property

src/pytorch_lightning/utilities/warnings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning as NewLightningDeprecationWarning
2020
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation as new_rank_zero_deprecation
21+
from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info
2122
from pytorch_lightning.utilities.rank_zero import rank_zero_warn as new_rank_zero_warn
2223

2324
# enable our warnings
@@ -39,6 +40,11 @@ def deprecation(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
3940
self.add(message)
4041
new_rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)
4142

43+
def info(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
44+
if message not in self:
45+
self.add(message)
46+
new_rank_zero_info(message, stacklevel=stacklevel, **kwargs)
47+
4248

4349
def rank_zero_warn(*args: Any, **kwargs: Any) -> Any:
4450
new_rank_zero_deprecation(

tests/tests_pytorch/callbacks/test_early_stopping.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -265,25 +265,28 @@ def validation_epoch_end(self, outputs):
265265
assert early_stopping.stopped_epoch == expected_stop_epoch
266266

267267

268-
@pytest.mark.parametrize("limit_train_batches", (3, 5))
269268
@pytest.mark.parametrize(
270-
["min_epochs", "min_steps"],
269+
"limit_train_batches,min_epochs,min_steps,stop_step",
271270
[
272271
# IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
273272
# triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
274-
(0, 10),
273+
(3, 0, 10, 10),
274+
(5, 0, 10, 10),
275275
# IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
276276
# being triggered, THEN the trainer should continue until reaching
277277
# `trainer.global_step` == `min_epochs * len(train_dataloader)`
278-
(2, 0),
278+
(3, 2, 0, 6),
279+
(5, 2, 0, 10),
279280
# IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
280281
# `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
281282
# `min_steps` would be reached
282-
(1, 10),
283-
(3, 10),
283+
(3, 1, 10, 10),
284+
(5, 1, 10, 10),
285+
(3, 3, 10, 10),
286+
(5, 3, 10, 15),
284287
],
285288
)
286-
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
289+
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps, stop_step):
287290
if min_steps:
288291
assert limit_train_batches < min_steps
289292

@@ -317,8 +320,7 @@ def training_step(self, batch, batch_idx):
317320
# epochs continue until min steps are reached
318321
assert trainer.current_epoch == expected_epochs
319322
# steps continue until min steps are reached AND the epoch is exhausted
320-
# stopping mid-epoch is not supported
321-
assert trainer.global_step == limit_train_batches * expected_epochs
323+
assert trainer.global_step == stop_step
322324

323325

324326
def test_early_stopping_mode_options():

tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import logging
1415
from unittest.mock import patch
1516

1617
import pytest
@@ -184,3 +185,36 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir):
184185
) as advance_mocked:
185186
trainer.fit(model, ckpt_path=ckpt_path)
186187
assert advance_mocked.call_count == 1
188+
189+
190+
@pytest.mark.parametrize(
191+
"min_epochs, min_steps, current_epoch, global_step, early_stop, epoch_loop_done, raise_info_msg",
192+
[
193+
(None, None, 1, 4, True, True, False),
194+
(None, None, 1, 10, True, True, False),
195+
(4, None, 1, 4, False, False, True),
196+
(4, 2, 1, 4, False, False, True),
197+
(4, None, 1, 10, False, True, False),
198+
(4, 3, 1, 3, False, False, True),
199+
(4, 10, 1, 10, False, True, False),
200+
(None, 4, 1, 4, True, True, False),
201+
],
202+
)
203+
def test_should_stop_early_stopping_conditions_not_met(
204+
caplog, min_epochs, min_steps, current_epoch, global_step, early_stop, epoch_loop_done, raise_info_msg
205+
):
206+
"""Test that checks that info message is logged when users sets `should_stop` but min conditions are not
207+
met."""
208+
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
209+
trainer.num_training_batches = 10
210+
trainer.should_stop = True
211+
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
212+
trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
213+
trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1
214+
215+
message = f"min_epochs={min_epochs}` or `min_steps={min_steps}` has not been met. Training will continue"
216+
with caplog.at_level(logging.INFO, logger="pytorch_lightning.loops"):
217+
assert trainer.fit_loop.epoch_loop.done is epoch_loop_done
218+
219+
assert (message in caplog.text) is raise_info_msg
220+
assert trainer.fit_loop._should_stop_early is early_stop

tests/tests_pytorch/loops/test_training_loop.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,6 @@ def test_fit_loop_done_log_messages(caplog):
180180

181181
fit_loop.epoch_loop.min_steps = 100
182182
assert not fit_loop.done
183-
assert "was signaled to stop but" in caplog.text
184183

185184

186185
def test_warning_valid_train_step_end(tmpdir):
@@ -198,3 +197,35 @@ def training_step_end(self, outputs):
198197
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
199198

200199
trainer.fit(model)
200+
201+
202+
@pytest.mark.parametrize(
203+
"min_epochs, min_steps, current_epoch, early_stop, fit_loop_done, raise_debug_msg",
204+
[
205+
(4, None, 100, True, True, False),
206+
(4, None, 3, False, False, False),
207+
(4, 10, 3, False, False, False),
208+
(None, 10, 4, True, True, True),
209+
(4, None, 4, True, True, True),
210+
(4, 10, 4, True, True, True),
211+
],
212+
)
213+
def test_should_stop_early_stopping_conditions_met(
214+
caplog, min_epochs, min_steps, current_epoch, early_stop, fit_loop_done, raise_debug_msg
215+
):
216+
"""Test that checks that debug message is logged when users sets `should_stop` and min conditions are met."""
217+
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
218+
trainer.num_training_batches = 10
219+
trainer.should_stop = True
220+
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
221+
current_epoch * trainer.num_training_batches
222+
)
223+
trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10
224+
trainer.fit_loop.epoch_progress.current.processed = current_epoch
225+
226+
message = "`Trainer.fit` stopped: `trainer.should_stop` was set."
227+
with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
228+
assert trainer.fit_loop.done is fit_loop_done
229+
230+
assert (message in caplog.text) is raise_debug_msg
231+
assert trainer.fit_loop._should_stop_early is early_stop

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,10 @@ def training_step(self, batch, batch_idx):
622622
output["loss"] = output["loss"] * 0.0 # force minimal loss to trigger early stopping
623623
self.log("loss", output["loss"])
624624
self.training_step_invoked += 1
625-
assert not self.trainer.should_stop
625+
if self.current_epoch < 2:
626+
assert not self.trainer.should_stop
627+
else:
628+
assert self.trainer.should_stop
626629
return output
627630

628631
model = TestModel()
@@ -641,7 +644,7 @@ def training_step(self, batch, batch_idx):
641644

642645
message = f"min_epochs={min_epochs}` or `min_steps=None` has not been met. Training will continue"
643646
num_messages = sum(1 for record in caplog.records if message in record.message)
644-
assert num_messages == min_epochs - 2
647+
assert num_messages == 1
645648
assert model.training_step_invoked == min_epochs * 2
646649

647650

0 commit comments

Comments
 (0)