From 1f0c08aba06176372d29d997d04644567705514c Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 28 Jul 2021 18:56:58 +0200
Subject: [PATCH 1/3] Add `Loop.stop()`
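
`Loop.run()` now checks a per-loop `should_stop` flag on every iteration and
resets it on exit, while `Loop.stop()` raises the flag on the loop and,
recursively, on every `Loop` instance found in its attributes. Callbacks no
longer assign `trainer.should_stop`; they call `stop()` on the active loop.

A minimal sketch of the intended usage from a callback (the callback class
and the step threshold are made up for illustration):

    import pytorch_lightning as pl

    class StopAfterTenSteps(pl.Callback):
        def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
            if trainer.global_step >= 10:
                # raises `should_stop` on the fit loop and all of its
                # nested loops instead of `trainer.should_stop = True`
                trainer.fit_loop.stop()

Note that `FitLoop.stop()` overrides the base behavior: it only propagates
the stop request once `min_epochs`/`min_steps` have been met.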
---
 pytorch_lightning/callbacks/early_stopping.py |  3 +-
 .../callbacks/gpu_stats_monitor.py            |  2 +-
 pytorch_lightning/callbacks/lr_monitor.py     |  2 +-
 pytorch_lightning/callbacks/timer.py          |  2 +-
 pytorch_lightning/loops/base.py               | 18 ++++++++--
 .../loops/epoch/training_epoch_loop.py        |  9 ++---
 pytorch_lightning/loops/fit_loop.py           | 36 ++++++++-----------
 .../logger_connector/logger_connector.py      |  4 +--
 pytorch_lightning/trainer/properties.py       | 12 +++++++
 pytorch_lightning/trainer/trainer.py          |  1 -
 10 files changed, 52 insertions(+), 37 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 77683ad2819f3..2fb438631e3bd 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -211,11 +211,10 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         trainer.dev_debugger.track_early_stopping_history(self, current)
 
         should_stop, reason = self._evaluate_stopping_criteria(current)
-
         # stop every ddp process if any world process decides to stop
         should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
-        trainer.should_stop = trainer.should_stop or should_stop
         if should_stop:
+            trainer._active_loop.stop()
             self.stopped_epoch = trainer.current_epoch
         if reason and self.verbose:
             self._log_info(trainer, reason)
diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py
index 3a8d110d59376..ab7995ec219ea 100644
--- a/pytorch_lightning/callbacks/gpu_stats_monitor.py
+++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -249,4 +249,4 @@ def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]:
 
     @staticmethod
     def _should_log(trainer) -> bool:
-        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer._active_loop.should_stop
diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py
index d7f350fecfeae..384ad76ba6365 100644
--- a/pytorch_lightning/callbacks/lr_monitor.py
+++ b/pytorch_lightning/callbacks/lr_monitor.py
@@ -265,4 +265,4 @@ def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> Lis
 
     @staticmethod
     def _should_log(trainer) -> bool:
-        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer._active_loop.should_stop
diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py
index 23894a1179c1f..092499118776a 100644
--- a/pytorch_lightning/callbacks/timer.py
+++ b/pytorch_lightning/callbacks/timer.py
@@ -167,7 +167,7 @@ def on_load_checkpoint(
     def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
         should_stop = self.time_elapsed() >= self._duration
         should_stop = trainer.accelerator.broadcast(should_stop)
-        trainer.should_stop = trainer.should_stop or should_stop
         if should_stop and self._verbose:
+            trainer._active_loop.stop()
             elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
             rank_zero_info(f"Time limit reached. Elapsed time is {elapsed}. Signaling Trainer to stop.")
diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index ee5c3a1b708f1..76c2b36d73aad 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -49,6 +49,7 @@ class Loop(ABC):
 
     def __init__(self) -> None:
         self.restarting = False
+        self.should_stop = False
         self._trainer: Optional["pl.Trainer"] = None
 
     @property
@@ -105,18 +106,29 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
 
         self.on_run_start(*args, **kwargs)
 
-        while not self.done:
+        while not self.done and not self.should_stop:
             try:
                 self.on_advance_start(*args, **kwargs)
-                self.advance(*args, **kwargs)
-                self.on_advance_end()
+                if not self.should_stop:
+                    self.advance(*args, **kwargs)
+                if not self.should_stop:
+                    self.on_advance_end()
                 self.restarting = False
             except StopIteration:
                 break
 
         output = self.on_run_end()
+
+        self.should_stop = False
         return output
 
+    def stop(self) -> None:
+        """Manually stop a loop and its linked loops."""
+        self.should_stop = True
+        for v in self.__dict__.values():
+            if isinstance(v, Loop):
+                v.stop()
+
     @abstractmethod
     def reset(self) -> None:
         """Resets the internal state of the loop at the beginning of each call to :attr:`run`."""
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 43d51fe0027c6..b5d0c0489edcb 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -68,11 +68,10 @@ def batch_idx(self) -> int:
     @property
     def done(self) -> bool:
         """Returns whether the training should be stopped.
 
-        The criteria are that the number of steps reached the max steps,
-        the last batch is reached or the trainer signals to stop (e.g. by early stopping).
+        The criteria are that the number of steps reached the max steps, or the last batch is reached.
         """
         max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
-        return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch)
+        return max_steps_reached or self._num_training_batches_reached(self.is_last_batch)
 
     def connect(
@@ -368,7 +367,9 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if is_last_batch and is_infinite_dataset:
             return True
 
-        if self.trainer.should_stop:
+        # FIXME: this is here to run validation when trainer is signaled to stop
+        # but with the changes here, this won't happen.
+        if self.should_stop:
             return True
 
         # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 4a09c0ca1faeb..cf7896fef44d7 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -126,31 +126,23 @@ def _results(self) -> ResultCollection:
 
     @property
     def done(self) -> bool:
-        """Evaluates when to leave the loop.
-
-        Returns True if trainer.should_stop was set (e.g. by early stopping)
-        or if the maximum number of steps or epochs is reached.
-        """
+        """Evaluates when to leave the loop."""
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = self.max_steps is not None and self.global_step >= self.max_steps
         stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs
-
-        should_stop = False
-        if self.trainer.should_stop:
-            # early stopping
-            met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
-            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-            if met_min_epochs and met_min_steps:
-                should_stop = True
-            else:
-                log.info(
-                    "Trainer was signaled to stop but required minimum epochs"
-                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
-                    " not been met. Training will continue..."
-                )
-        self.trainer.should_stop = should_stop
-
-        return stop_steps or should_stop or stop_epochs
+        return stop_steps or stop_epochs
+
+    def stop(self) -> None:
+        """Checks whether a minimum number of epochs and steps have been met before stopping."""
+        met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
+        met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
+        if met_min_epochs and met_min_steps:
+            return super().stop()
+        log.info(
+            "Trainer was signaled to stop but required minimum epochs"
+            f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
+            " not been met. Training will continue..."
+        )
 
     @property
     def skip(self) -> bool:
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index a965699510689..de815014db43c 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -57,12 +57,12 @@ def on_trainer_init(
     @property
     def should_flush_logs(self) -> bool:
         should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
-        return should_flush or self.trainer.should_stop
+        return should_flush or self.trainer._active_loop.should_stop
 
     @property
     def should_update_logs(self) -> bool:
         should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
-        return should_log_every_n_steps or self.trainer.should_stop
+        return should_log_every_n_steps or self.trainer._active_loop.should_stop
 
     def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
         if logger is True:
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 5e16ba57464f5..5d9962066a68a 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -628,6 +628,18 @@ def _ckpt_path(self) -> Optional[str]:
         if self.state.fn == TrainerFn.PREDICTING:
             return self.predicted_ckpt_path
 
+    @property
+    def should_stop(self) -> bool:
+        # FIXME: deprecate, ask users to access it themselves
+        return self._active_loop.should_stop
+
+    @should_stop.setter
+    def should_stop(self, should_stop: bool) -> None:
+        # FIXME: deprecate this setter, ask users to call `.stop()` manually
+        if should_stop:
+            self._active_loop.stop()
+        self._active_loop.should_stop = should_stop
+
     """
     Logging properties
     """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f56a84b1d294d..dad27b79399ca 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -467,7 +467,6 @@ def __init__(
 
     def _setup_on_init(self, num_sanity_val_steps: int) -> None:
         self._log_device_info()
-        self.should_stop = False
         self.state = TrainerState()
         self.num_training_batches = 0
         self.train_dataloader = None

- """ + """Evaluates when to leave the loop.""" # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop stop_steps = self.max_steps is not None and self.global_step >= self.max_steps stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs - - should_stop = False - if self.trainer.should_stop: - # early stopping - met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True - met_min_steps = self.global_step >= self.min_steps if self.min_steps else True - if met_min_epochs and met_min_steps: - should_stop = True - else: - log.info( - "Trainer was signaled to stop but required minimum epochs" - f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has" - " not been met. Training will continue..." - ) - self.trainer.should_stop = should_stop - - return stop_steps or should_stop or stop_epochs + return stop_steps or stop_epochs + + def stop(self) -> None: + """Checks whether a minimum number of epochs and steps have been met before stopping.""" + met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True + met_min_steps = self.global_step >= self.min_steps if self.min_steps else True + if met_min_epochs and met_min_steps: + return super().stop() + log.info( + "Trainer was signaled to stop but required minimum epochs" + f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has" + " not been met. Training will continue..." + ) @property def skip(self) -> bool: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a965699510689..de815014db43c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -57,12 +57,12 @@ def on_trainer_init( @property def should_flush_logs(self) -> bool: should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 - return should_flush or self.trainer.should_stop + return should_flush or self.trainer._active_loop.should_stop @property def should_update_logs(self) -> bool: should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 - return should_log_every_n_steps or self.trainer.should_stop + return should_log_every_n_steps or self.trainer._active_loop.should_stop def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None: if logger is True: diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 5e16ba57464f5..5d9962066a68a 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -628,6 +628,18 @@ def _ckpt_path(self) -> Optional[str]: if self.state.fn == TrainerFn.PREDICTING: return self.predicted_ckpt_path + @property + def should_stop(self) -> bool: + # FIXME: deprecate, ask users to access it themselves + return self._active_loop.should_stop + + @should_stop.setter + def should_stop(self, should_stop: bool) -> None: + # FIXME: deprecate this setter, ask users to call `.stop()` manually + if should_stop: + self._active_loop.stop() + self._active_loop.should_stop = should_stop + """ Logging properties """ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f56a84b1d294d..dad27b79399ca 100644 --- a/pytorch_lightning/trainer/trainer.py +++ 
---
 pytorch_lightning/callbacks/early_stopping.py |  2 +-
 pytorch_lightning/callbacks/timer.py          |  2 +-
 .../loops/epoch/training_epoch_loop.py        |  5 -----
 pytorch_lightning/trainer/properties.py       | 16 +++++++++++++---
 tests/deprecated_api/test_remove_1-7.py       |  3 +++
 5 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 2fb438631e3bd..7d1476f1c51b1 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -214,7 +214,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         # stop every ddp process if any world process decides to stop
         should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
         if should_stop:
-            trainer._active_loop.stop()
+            trainer.stop()
             self.stopped_epoch = trainer.current_epoch
         if reason and self.verbose:
             self._log_info(trainer, reason)
diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py
index 092499118776a..d5db0b71edf0b 100644
--- a/pytorch_lightning/callbacks/timer.py
+++ b/pytorch_lightning/callbacks/timer.py
@@ -168,6 +168,6 @@ def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
         should_stop = self.time_elapsed() >= self._duration
         should_stop = trainer.accelerator.broadcast(should_stop)
         if should_stop and self._verbose:
-            trainer._active_loop.stop()
+            trainer.stop()
             elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
             rank_zero_info(f"Time limit reached. Elapsed time is {elapsed}. Signaling Trainer to stop.")
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index b5d0c0489edcb..29fe03f427f95 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -367,11 +367,6 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         if is_last_batch and is_infinite_dataset:
             return True
 
-        # FIXME: this is here to run validation when trainer is signaled to stop
-        # but with the changes here, this won't happen.
-        if self.should_stop:
-            return True
-
         # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = is_last_batch
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index 5d9962066a68a..d627023ee71f3 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -630,16 +630,26 @@ def _ckpt_path(self) -> Optional[str]:
 
     @property
     def should_stop(self) -> bool:
-        # FIXME: deprecate, ask users to access it themselves
+        rank_zero_deprecation(
+            "Accessing `trainer.should_stop` is deprecated. You can find this attribute in the loop instance"
+            " you want to check. For example, `trainer.fit_loop.should_stop`"
+        )
         return self._active_loop.should_stop
 
     @should_stop.setter
     def should_stop(self, should_stop: bool) -> None:
-        # FIXME: deprecate this setter, ask users to call `.stop()` manually
+        rank_zero_deprecation(
+            "Setting `trainer.should_stop` is deprecated in v1.5 and will be removed in v1.7. Please use"
+            " `trainer.stop()` or `trainer.a_loop.stop()` instead"
+        )
         if should_stop:
-            self._active_loop.stop()
+            self.stop()
+        # in case `False` was passed
         self._active_loop.should_stop = should_stop
 
+    def stop(self) -> None:
+        self._active_loop.stop()
+
     """
     Logging properties
     """
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index ae8f9e1dcc53d..3e970b2f58cfb 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -95,3 +95,6 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
 def test_v1_7_0_test_tube_logger(_, tmpdir):
     with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"):
         _ = TestTubeLogger(tmpdir)
+
+
+# FIXME: add deprecations for should_stop

From 88628f5fede07cd145f512872cb5eb1f1eafc038 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 27 Aug 2021 17:56:07 +0200
Subject: [PATCH 3/3] Un-deprecate should_stop getter
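
Reading `trainer.should_stop` is used pervasively for logging decisions (see
`_should_log` and the logger connector), so the getter stays as a plain,
warning-free proxy for `trainer._active_loop.should_stop`; only the setter
keeps the deprecation.

A short sketch of the resulting read/write asymmetry (the callback is made
up for illustration):

    import pytorch_lightning as pl

    class StopReporter(pl.Callback):
        def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
            # reading is fine: proxies trainer._active_loop.should_stop
            if trainer.should_stop:
                print(f"stop requested at step {trainer.global_step}")
            # writing still warns; call trainer.stop() instead of
            # trainer.should_stop = True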
---
 pytorch_lightning/callbacks/gpu_stats_monitor.py            | 2 +-
 pytorch_lightning/callbacks/lr_monitor.py                   | 2 +-
 .../trainer/connectors/logger_connector/logger_connector.py | 4 ++--
 pytorch_lightning/trainer/properties.py                     | 4 ----
 tests/deprecated_api/test_remove_1-7.py                     | 2 +-
 5 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py
index ab7995ec219ea..3a8d110d59376 100644
--- a/pytorch_lightning/callbacks/gpu_stats_monitor.py
+++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -249,4 +249,4 @@ def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]:
 
     @staticmethod
     def _should_log(trainer) -> bool:
-        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer._active_loop.should_stop
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py
index 384ad76ba6365..d7f350fecfeae 100644
--- a/pytorch_lightning/callbacks/lr_monitor.py
+++ b/pytorch_lightning/callbacks/lr_monitor.py
@@ -265,4 +265,4 @@ def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> Lis
 
     @staticmethod
     def _should_log(trainer) -> bool:
-        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer._active_loop.should_stop
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index de815014db43c..a965699510689 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -57,12 +57,12 @@ def on_trainer_init(
     @property
     def should_flush_logs(self) -> bool:
         should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
-        return should_flush or self.trainer._active_loop.should_stop
+        return should_flush or self.trainer.should_stop
 
     @property
     def should_update_logs(self) -> bool:
         should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
-        return should_log_every_n_steps or self.trainer._active_loop.should_stop
+        return should_log_every_n_steps or self.trainer.should_stop
 
     def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
         if logger is True:
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index d627023ee71f3..a40a95ff1f0a8 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -630,10 +630,6 @@ def _ckpt_path(self) -> Optional[str]:
 
     @property
     def should_stop(self) -> bool:
-        rank_zero_deprecation(
-            "Accessing `trainer.should_stop` is deprecated. You can find this attribute in the loop instance"
-            " you want to check. For example, `trainer.fit_loop.should_stop`"
-        )
         return self._active_loop.should_stop
 
     @should_stop.setter
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 3e970b2f58cfb..2d3563644c40f 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -97,4 +97,4 @@ def test_v1_7_0_test_tube_logger(_, tmpdir):
         _ = TestTubeLogger(tmpdir)
 
 
-# FIXME: add deprecations for should_stop
+# FIXME: add deprecation test for should_stop setter