diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5c8d607648034..8e22b81cf3fa1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fault Tolerant relies on `signal.SIGTERM` to gracefully exit instead of `signal.SIGUSR1` ([#10605](https://github.com/PyTorchLightning/pytorch-lightning/pull/10605))
 
 
+- `Loop.restarting=...` now sets the value recursively for all subloops ([#11442](https://github.com/PyTorchLightning/pytorch-lightning/pull/11442))
+
+
 - Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541))
 
 
diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index 49f2ee0077575..7876e4d44eb26 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -49,7 +49,7 @@ class Loop(ABC, Generic[T]):
     """
 
     def __init__(self) -> None:
-        self.restarting = False
+        self._restarting = False
         self._trainer: Optional["pl.Trainer"] = None
 
     @property
@@ -70,6 +70,19 @@ def trainer(self, trainer: "pl.Trainer") -> None:
             if isinstance(v, Loop):
                 v.trainer = trainer
 
+    @property
+    def restarting(self) -> bool:
+        """Whether the state of this loop was reloaded and it needs to restart."""
+        return self._restarting
+
+    @restarting.setter
+    def restarting(self, restarting: bool) -> None:
+        """Connects this loop's restarting value and its children."""
+        self._restarting = restarting
+        for loop in vars(self).values():
+            if isinstance(loop, Loop):
+                loop.restarting = restarting
+
     @property
     @abstractmethod
     def done(self) -> bool:
@@ -190,7 +203,7 @@ def run(self, *args, **kwargs):
                 self.on_advance_start(*args, **kwargs)
                 self.advance(*args, **kwargs)
                 self.on_advance_end()
-                self.restarting = False
+                self._restarting = False
             except StopIteration:
                 break
 
@@ -301,6 +314,7 @@ def load_state_dict(
         for k, v in self.__dict__.items():
             if isinstance(v, Loop):
                 v.load_state_dict(state_dict.copy(), prefix + k + ".")
+        self.restarting = True
 
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
         for k, v in self.__dict__.items():
@@ -336,4 +350,3 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional
         if prefix + "state_dict" in state_dict:
             # compatibility with old checkpoints
             self.on_load_checkpoint(state_dict[prefix + "state_dict"])
-        self.restarting = True
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 69432ee07dd0b..46f2d3752c829 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -145,6 +145,8 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
             # skip training and run validation in `on_advance_end`
             return
+        # we are going to train first so the val loop does not need to restart
+        self.val_loop.restarting = False
 
         assert self._dataloader_iter is not None
         batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 89b4f4fb30edf..b6f0c05d3fd72 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -86,6 +86,23 @@ def test_connect_loops_recursive():
     assert child1.trainer is trainer
 
 
+def test_restarting_loops_recursive():
+    class MyLoop(NestedLoop):
+        def __init__(self, loop=None):
+            super().__init__()
+            self.child = loop
+
+    loop = MyLoop(MyLoop(MyLoop()))
+
+    assert not loop.restarting
+    assert not loop.child.restarting
+    assert not loop.child.child.restarting
+    loop.restarting = True
+    assert loop.restarting
+    assert loop.child.restarting
+    assert loop.child.child.restarting
+
+
 def test_connect_subloops(tmpdir):
     """Test connecting individual subloops by calling `trainer.x.y.connect()`"""
     model = BoringModel()
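
For reference, here is a minimal standalone sketch of the recursive-setter pattern this diff introduces. `SimpleLoop` is a hypothetical stand-in for `pytorch_lightning.loops.base.Loop` (not part of the diff); it only illustrates how a property setter can fan a flag out to nested loop attributes.

from typing import Optional


class SimpleLoop:
    # Hypothetical stand-in for `Loop`, illustrating the recursive
    # `restarting` property added in `pytorch_lightning/loops/base.py` above.

    def __init__(self, child: Optional["SimpleLoop"] = None) -> None:
        self._restarting = False
        self.child = child

    @property
    def restarting(self) -> bool:
        return self._restarting

    @restarting.setter
    def restarting(self, restarting: bool) -> None:
        # Propagate the flag to every attribute that is itself a loop,
        # mirroring the setter the diff adds to `Loop`.
        self._restarting = restarting
        for attr in vars(self).values():
            if isinstance(attr, SimpleLoop):
                attr.restarting = restarting


# Setting the flag on the outermost loop now reaches all nested loops,
# which is the behavior `test_restarting_loops_recursive` above verifies.
outer = SimpleLoop(SimpleLoop(SimpleLoop()))
outer.restarting = True
assert outer.child.restarting and outer.child.child.restarting

Because `load_state_dict` in the diff now sets `self.restarting = True` through this setter (rather than each subloop flipping its own flag in `_load_from_state_dict`), reloading a parent loop marks the entire loop tree as restarting in one assignment, and `TrainingEpochLoop.advance` can then selectively clear the flag on `val_loop` when training runs first.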