Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fault Tolerant relies on `signal.SIGTERM` to gracefully exit instead of `signal.SIGUSR1` ([#10605](https://github.com/PyTorchLightning/pytorch-lightning/pull/10605))


- `Loop.restarting=...` now sets the value recursively for all subloops ([#11442](https://github.com/PyTorchLightning/pytorch-lightning/pull/11442))


- Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541))


Expand Down
19 changes: 16 additions & 3 deletions pytorch_lightning/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Loop(ABC, Generic[T]):
"""

def __init__(self) -> None:
self.restarting = False
self._restarting = False
self._trainer: Optional["pl.Trainer"] = None

@property
Expand All @@ -70,6 +70,19 @@ def trainer(self, trainer: "pl.Trainer") -> None:
if isinstance(v, Loop):
v.trainer = trainer

@property
def restarting(self) -> bool:
    """Whether the state of this loop was reloaded and it needs to restart."""
    return self._restarting

@restarting.setter
def restarting(self, restarting: bool) -> None:
    """Set the restarting flag on this loop and propagate it to all child loops.

    Any attribute of this loop that is itself a ``Loop`` receives the same value,
    so toggling ``restarting`` on a parent cascades through the whole loop tree.
    """
    self._restarting = restarting
    # Propagate recursively: assigning to a child's ``restarting`` re-enters
    # this setter on the child, which then forwards to its own sub-loops.
    children = (attr for attr in self.__dict__.values() if isinstance(attr, Loop))
    for child in children:
        child.restarting = restarting

@property
@abstractmethod
def done(self) -> bool:
Expand Down Expand Up @@ -190,7 +203,7 @@ def run(self, *args, **kwargs):
self.on_advance_start(*args, **kwargs)
self.advance(*args, **kwargs)
self.on_advance_end()
self.restarting = False
self._restarting = False
except StopIteration:
break

Expand Down Expand Up @@ -301,6 +314,7 @@ def load_state_dict(
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".")
self.restarting = True

def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
for k, v in self.__dict__.items():
Expand Down Expand Up @@ -336,4 +350,3 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional

if prefix + "state_dict" in state_dict: # compatibility with old checkpoints
self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
# skip training and run validation in `on_advance_end`
return
# we are going to train first so the val loop does not need to restart
self.val_loop.restarting = False

assert self._dataloader_iter is not None
batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)
Expand Down
17 changes: 17 additions & 0 deletions tests/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,23 @@ def test_connect_loops_recursive():
assert child1.trainer is trainer


def test_restarting_loops_recursive():
    """Setting ``restarting`` on a parent loop must cascade to every nested sub-loop."""

    class _ChainedLoop(NestedLoop):
        def __init__(self, loop=None):
            super().__init__()
            self.child = loop

    # Build a three-level chain: outer -> middle -> innermost.
    innermost = _ChainedLoop()
    middle = _ChainedLoop(innermost)
    outer = _ChainedLoop(middle)
    chain = (outer, middle, innermost)

    # Freshly constructed loops are not restarting at any depth.
    assert all(not loop.restarting for loop in chain)

    # Flipping the flag on the root propagates down the whole chain.
    outer.restarting = True
    assert all(loop.restarting for loop in chain)


def test_connect_subloops(tmpdir):
"""Test connecting individual subloops by calling `trainer.x.y.connect()`"""
model = BoringModel()
Expand Down