From 20c54630320aa4b8e8e972501feea84054e01a6f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 12 Jan 2022 16:13:11 +0100 Subject: [PATCH 1/4] Set `Loop.restarting` recursively --- pytorch_lightning/loops/base.py | 17 ++++++++++++++--- .../loops/epoch/training_epoch_loop.py | 3 +++ tests/loops/test_loops.py | 17 +++++++++++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 904e05ac6a804..fe2702a270c74 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -49,7 +49,7 @@ class Loop(ABC, Generic[T]): """ def __init__(self) -> None: - self.restarting = False + self._restarting = False self._trainer: Optional["pl.Trainer"] = None @property @@ -70,6 +70,17 @@ def trainer(self, trainer: "pl.Trainer") -> None: if isinstance(v, Loop): v.trainer = trainer + @property + def restarting(self) -> bool: + return self._restarting + + @restarting.setter + def restarting(self, restarting: bool) -> None: + self._restarting = restarting + for loop in vars(self).values(): + if isinstance(loop, Loop): + loop.restarting = restarting + @property @abstractmethod def done(self) -> bool: @@ -190,7 +201,7 @@ def run(self, *args, **kwargs): self.on_advance_start(*args, **kwargs) self.advance(*args, **kwargs) self.on_advance_end() - self.restarting = False + self._restarting = False except StopIteration: break @@ -301,6 +312,7 @@ def load_state_dict( for k, v in self.__dict__.items(): if isinstance(v, Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".") + self.restarting = True def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None: for k, v in self.__dict__.items(): @@ -335,4 +347,3 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional v.reset(metrics=False) self.on_load_checkpoint(state_dict[prefix + "state_dict"]) - self.restarting = True diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 69432ee07dd0b..345f8652c6082 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -145,6 +145,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch): # skip training and run validation in `on_advance_end` return + else: + # we are going to train first so the val loop does not need to restart + self.val_loop.restarting = False assert self._dataloader_iter is not None batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 89b4f4fb30edf..b6f0c05d3fd72 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -86,6 +86,23 @@ def test_connect_loops_recursive(): assert child1.trainer is trainer +def test_restarting_loops_recursive(): + class MyLoop(NestedLoop): + def __init__(self, loop=None): + super().__init__() + self.child = loop + + loop = MyLoop(MyLoop(MyLoop())) + + assert not loop.restarting + assert not loop.child.restarting + assert not loop.child.child.restarting + loop.restarting = True + assert loop.restarting + assert loop.child.restarting + assert loop.child.child.restarting + + def test_connect_subloops(tmpdir): """Test connecting individual subloops by calling `trainer.x.y.connect()`""" model = BoringModel() From 6d324cb207e506bc53f225b9b7bb70f955c17a35 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 12 Jan 2022 16:18:43 +0100 Subject: [PATCH 2/4] Docs --- pytorch_lightning/loops/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index fe2702a270c74..1b4008bc9c7a0 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -72,10 +72,12 @@ def trainer(self, trainer: "pl.Trainer") -> None: @property def restarting(self) -> bool: + """Whether the state of this loop was reloaded and it needs to restart.""" return self._restarting @restarting.setter def restarting(self, restarting: bool) -> None: + """Connects this loop's restarting value and its children.""" self._restarting = restarting for loop in vars(self).values(): if isinstance(loop, Loop): From 1090f585dfc5c2705dfe38f0f9425d8224b9edec Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 12 Jan 2022 16:20:23 +0100 Subject: [PATCH 3/4] CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c8d607648034..8e22b81cf3fa1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault Tolerant relies on `signal.SIGTERM` to gracefully exit instead of `signal.SIGUSR1` ([#10605](https://github.com/PyTorchLightning/pytorch-lightning/pull/10605)) +- `Loop.restarting=...` now sets the value recursively for all subloops ([#11442](https://github.com/PyTorchLightning/pytorch-lightning/pull/11442)) + + - Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541)) From 83cc822e4759d3c6f3a99d71304abf8f4370c31d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 13 Jan 2022 15:45:34 +0100 Subject: [PATCH 4/4] Update pytorch_lightning/loops/epoch/training_epoch_loop.py Co-authored-by: Aki Nitta --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 345f8652c6082..46f2d3752c829 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -145,9 +145,8 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch): # skip training and run validation in `on_advance_end` return - else: - # we are going to train first so the val loop does not need to restart - self.val_loop.restarting = False + # we are going to train first so the val loop does not need to restart + self.val_loop.restarting = False assert self._dataloader_iter is not None batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)