10 changes: 10 additions & 0 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import math
 from collections import defaultdict
 from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union
@@ -33,6 +34,8 @@
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.warnings import WarningCache
 
+log = logging.getLogger(__name__)
+
 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
 
 
@@ -100,6 +103,13 @@ def _is_validation_done(self) -> bool:
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
+        if self.trainer.should_stop and self.min_steps:
+            self.trainer.should_stop = self.global_step >= self.min_steps
+            if not self.trainer.should_stop:
+                log.info(
+                    f"Trainer was signaled to stop but required minimum steps ({self.min_steps}) has not been met."
+                    " Training will continue..."
+                )
         return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop
 
     def connect(  # type: ignore[override]
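Note: the hunk above makes the epoch loop re-evaluate an external stop signal against `min_steps` before honoring it. A minimal, self-contained sketch of that gating logic, using plain-Python stand-ins for the loop state (`resolve_should_stop` and its arguments are illustrative names, not Lightning API):

from typing import Optional


def resolve_should_stop(should_stop: bool, global_step: int, min_steps: Optional[int]) -> bool:
    """Re-evaluate an external stop signal against a minimum-step requirement."""
    if should_stop and min_steps:
        # Honor the signal only once the required number of steps has run.
        return global_step >= min_steps
    return should_stop


# The signal is deferred until min_steps is met, then honored.
assert resolve_should_stop(True, global_step=3, min_steps=10) is False
assert resolve_should_stop(True, global_step=10, min_steps=10) is True
assert resolve_should_stop(True, global_step=3, min_steps=None) is True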
21 changes: 6 additions & 15 deletions pytorch_lightning/loops/fit_loop.py
@@ -167,28 +167,19 @@ def _results(self) -> _ResultCollection:
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
         # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
         # we use it here because the checkpoint data won't have `completed` increased yet
         stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
 
-        should_stop = False
-        if self.trainer.should_stop:
-            # early stopping
-            met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
-            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-            if met_min_epochs and met_min_steps:
-                should_stop = True
-            else:
+        if self.trainer.should_stop and self.min_epochs:
+            self.trainer.should_stop = self.epoch_progress.current.processed >= self.min_epochs
+            if not self.trainer.should_stop:
                 log.info(
-                    "Trainer was signaled to stop but required minimum epochs"
-                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
-                    " not been met. Training will continue..."
+                    f"Trainer was signaled to stop but required minimum epochs ({self.min_epochs}) has not been met."
+                    " Training will continue..."
                 )
-        self.trainer.should_stop = should_stop
 
-        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
+        return stop_steps or self.trainer.should_stop or stop_epochs or self.trainer.num_training_batches == 0
 
     @property
     def skip(self) -> bool:
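Note: taken together, the two loop changes mean a stop signal raised before the minimums are met is deferred rather than honored immediately. A hedged usage sketch of that behavior (not part of this PR; the model, dataset, and Trainer settings are illustrative stand-ins, mirroring the pattern used in the test below):

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class DemoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # Signal a stop on every step; with this change, `FitLoop.done` keeps
        # resetting the flag until `min_epochs` epochs have been processed.
        self.trainer.should_stop = True
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = pl.Trainer(min_epochs=5, max_epochs=20, enable_progress_bar=False)
    trainer.fit(DemoModel(), DataLoader(RandomDataset(), batch_size=8))
    # Expected: fit does not end before 5 epochs have run, logging the
    # "Training will continue..." message each time the stop is deferred.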
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -610,7 +610,7 @@ def training_step(self, batch, batch_idx):
     with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"):
         trainer.fit(model)
 
-    message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue"
+    message = f"minimum epochs ({min_epochs}) has not been met. Training will continue"
     num_messages = sum(1 for record in caplog.records if message in record.message)
     assert num_messages == min_epochs - 2
    assert model.training_step_invoked == min_epochs * 2
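Note: the updated assertion counts how many captured records contain the new, shorter message. A standalone sketch of the caplog counting pattern the test relies on (pytest's built-in fixture; the logger name mirrors the one used above):

import logging


def test_counts_matching_records(caplog):
    logger_name = "pytorch_lightning.trainer.trainer"
    with caplog.at_level(logging.INFO, logger=logger_name):
        for _ in range(3):
            logging.getLogger(logger_name).info("minimum epochs (5) has not been met. Training will continue...")
    message = "minimum epochs (5) has not been met. Training will continue"
    # Count records whose formatted message contains the expected substring.
    assert sum(1 for record in caplog.records if message in record.message) == 3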