Closed
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -211,11 +211,10 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
        trainer.dev_debugger.track_early_stopping_history(self, current)

        should_stop, reason = self._evaluate_stopping_criteria(current)

        # stop every ddp process if any world process decides to stop
        should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
Contributor
Should we move the accelerator reduction within `should_stop` from the loops?

Contributor Author
@carmocca Aug 27, 2021
Possibly, it would clean up the callbacks using it.

# loops/base.py
def stop(self, should_stop: bool = True) -> bool:
    should_stop = self.trainer.training_type_plugin.reduce_boolean_decision(should_stop)
    self.should_stop = should_stop
    ...
    return should_stop

Thoughts? @justusschock @awaelchli

-       trainer.should_stop = trainer.should_stop or should_stop
        if should_stop:
+           trainer.stop()
            self.stopped_epoch = trainer.current_epoch
        if reason and self.verbose:
            self._log_info(trainer, reason)
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/timer.py
@@ -167,7 +167,7 @@ def on_load_checkpoint(
    def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
        should_stop = self.time_elapsed() >= self._duration
        should_stop = trainer.accelerator.broadcast(should_stop)
-       trainer.should_stop = trainer.should_stop or should_stop
        if should_stop and self._verbose:
+           trainer.stop()
            elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
            rank_zero_info(f"Time limit reached. Elapsed time is {elapsed}. Signaling Trainer to stop.")
18 changes: 15 additions & 3 deletions pytorch_lightning/loops/base.py
@@ -49,6 +49,7 @@ class Loop(ABC):

    def __init__(self) -> None:
        self.restarting = False
+       self.should_stop = False
        self._trainer: Optional["pl.Trainer"] = None

    @property
@@ -105,18 +106,29 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:

        self.on_run_start(*args, **kwargs)

-       while not self.done:
+       while not self.done and not self.should_stop:
            try:
                self.on_advance_start(*args, **kwargs)
-               self.advance(*args, **kwargs)
-               self.on_advance_end()
+               if not self.should_stop:
+                   self.advance(*args, **kwargs)
+               if not self.should_stop:
+                   self.on_advance_end()
                self.restarting = False
            except StopIteration:
                break

        output = self.on_run_end()

+       self.should_stop = False
        return output

+   def stop(self) -> None:
+       """Manually stop a loop and its linked loops."""
+       self.should_stop = True
+       for v in self.__dict__.values():
+           if isinstance(v, Loop):
+               v.stop()

    @abstractmethod
    def reset(self) -> None:
        """Resets the internal state of the loop at the beginning of each call to :attr:`run`."""
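To make the new behaviour concrete, here is a minimal, self-contained sketch (illustrative only, not part of this diff; the toy class names are assumptions) of how the `stop()` added above propagates through nested loops via the `__dict__` scan:

from abc import ABC


class Loop(ABC):
    """Toy stand-in for pytorch_lightning.loops.base.Loop, reduced to the stop logic."""

    def __init__(self) -> None:
        self.should_stop = False

    def stop(self) -> None:
        # flag this loop and every Loop stored as an attribute, recursively
        self.should_stop = True
        for v in self.__dict__.values():
            if isinstance(v, Loop):
                v.stop()


class TrainingEpochLoop(Loop):
    pass


class FitLoop(Loop):
    def __init__(self) -> None:
        super().__init__()
        self.epoch_loop = TrainingEpochLoop()  # linked child loop


fit_loop = FitLoop()
fit_loop.stop()
assert fit_loop.should_stop and fit_loop.epoch_loop.should_stop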
8 changes: 2 additions & 6 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -68,11 +68,10 @@ def batch_idx(self) -> int:
    @property
    def done(self) -> bool:
        """Returns whether the training should be stopped.
-       The criteria are that the number of steps reached the max steps,
-       the last batch is reached or the trainer signals to stop (e.g. by early stopping).
+       The criteria are that the number of steps reached the max steps, or the last batch is reached
        """
        max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
-       return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch)
+       return max_steps_reached or self._num_training_batches_reached(self.is_last_batch)

    def connect(
        self,
@@ -368,9 +367,6 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
        if is_last_batch and is_infinite_dataset:
            return True

-       if self.trainer.should_stop:
Contributor
as discussed, we might actually want to get rid of this completely.
it's not well justified why we would want to run validation at the point of stopping.

Contributor
Sounds good to me!

-           return True

        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
        is_val_check_batch = is_last_batch
        if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
36 changes: 14 additions & 22 deletions pytorch_lightning/loops/fit_loop.py
@@ -126,31 +126,23 @@ def _results(self) -> ResultCollection:

    @property
    def done(self) -> bool:
-       """Evaluates when to leave the loop.
-
-       Returns True if trainer.should_stop was set (e.g. by early stopping)
-       or if the maximum number of steps or epochs is reached.
-       """
+       """Evaluates when to leave the loop."""
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = self.max_steps is not None and self.global_step >= self.max_steps
        stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs

-       should_stop = False
-       if self.trainer.should_stop:
-           # early stopping
-           met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
-           met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-           if met_min_epochs and met_min_steps:
-               should_stop = True
-           else:
-               log.info(
-                   "Trainer was signaled to stop but required minimum epochs"
-                   f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
-                   " not been met. Training will continue..."
-               )
-       self.trainer.should_stop = should_stop
-
-       return stop_steps or should_stop or stop_epochs
+       return stop_steps or stop_epochs

+   def stop(self) -> None:
+       """Checks whether a minimum number of epochs and steps have been met before stopping."""
+       met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
+       met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
+       if met_min_epochs and met_min_steps:
+           return super().stop()
+       log.info(
+           "Trainer was signaled to stop but required minimum epochs"
+           f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
+           " not been met. Training will continue..."
+       )

    @property
    def skip(self) -> bool:
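A hedged usage sketch of how the `min_epochs`/`min_steps` guard above would defer a stop request (assumes this PR's `trainer.stop()` path behind `EarlyStopping`, and a LightningModule that logs "val_loss"):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# With this PR, EarlyStopping calls trainer.stop(); FitLoop.stop() only honours
# the request once min_epochs/min_steps have been met, otherwise it logs and continues.
trainer = Trainer(
    min_epochs=5,  # stop requests before epoch 5 are deferred and logged
    max_epochs=50,
    callbacks=[EarlyStopping(monitor="val_loss", patience=3)],
)
# trainer.fit(model)  # training would continue until at least epoch 5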
18 changes: 18 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -628,6 +628,24 @@ def _ckpt_path(self) -> Optional[str]:
        if self.state.fn == TrainerFn.PREDICTING:
            return self.predicted_ckpt_path

+   @property
+   def should_stop(self) -> bool:
+       return self._active_loop.should_stop
+
+   @should_stop.setter
+   def should_stop(self, should_stop: bool) -> None:
+       rank_zero_deprecation(
+           "Setting `trainer.should_stop` is deprecated in v1.5 and will be removed in v1.7. Please use"
+           " `trainer.stop()` or `trainer.a_loop.stop()` instead"
+       )
+       if should_stop:
+           self.stop()
+       # in case `False` was passed
+       self._active_loop.should_stop = should_stop
+
+   def stop(self) -> None:
+       self._active_loop.stop()

"""
Logging properties
"""
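For code that currently mutates the flag, a migration sketch based on the deprecation message above (the callback name, threshold, and hook choice are illustrative assumptions, not part of this diff):

from pytorch_lightning.callbacks import Callback


class BudgetStopper(Callback):
    """Hypothetical callback migrating off the deprecated `trainer.should_stop` setter."""

    def __init__(self, max_batches: int = 1000) -> None:
        self.max_batches = max_batches

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        if trainer.global_step >= self.max_batches:
            # before this PR: trainer.should_stop = True
            trainer.stop()  # stops the active loop and its linked loops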
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -467,7 +467,6 @@ def __init__(
    def _setup_on_init(self, num_sanity_val_steps: int) -> None:
        self._log_device_info()

-       self.should_stop = False
        self.state = TrainerState()
        self.num_training_batches = 0
        self.train_dataloader = None
3 changes: 3 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
@@ -95,3 +95,6 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
def test_v1_7_0_test_tube_logger(_, tmpdir):
    with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):
        _ = TestTubeLogger(tmpdir)


# FIXME: add deprecation test for should_stop setter
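A possible shape for that missing test, following the pattern of the existing `test_v1_7_0_*` tests above (the test name, the inline callback, and the `BoringModel` import are illustrative assumptions; some imports may already exist in this file):

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from tests.helpers import BoringModel


def test_v1_7_0_trainer_should_stop_setter(tmpdir):
    class SetFlag(Callback):
        def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
            trainer.should_stop = True  # deprecated path under test

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[SetFlag()])
    with pytest.deprecated_call(match="Setting `trainer.should_stop` is deprecated in v1.5"):
        trainer.fit(BoringModel())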