diff --git a/CHANGELOG.md b/CHANGELOG.md index 758c573a1a5e6..93afce3a41b18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) +- Added `KFoldLoop` example ([#8715](https://github.com/PyTorchLightning/pytorch-lightning/pull/8715)) + + ### Changed - Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477)) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py new file mode 100644 index 0000000000000..52e9439193454 --- /dev/null +++ b/pl_examples/loops_customisation/k_fold.py @@ -0,0 +1,216 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +WARNING: Loop customization is in `pre-alpha release` and the API is likely to change quite a lot ! +Please, open issues with your own particular requests, so the Lightning Team can progressively converge to a great API. +""" + +from typing import Any, Dict, List, Optional, Type + +import numpy as np +from sklearn.model_selection import KFold +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset, Subset + +from pytorch_lightning import _logger as log +from pytorch_lightning import LightningDataModule, seed_everything +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.loops.external_loop import ExternalLoop +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.boring_model import BoringModel, RandomDataset + +seed_everything(42) + + +class BaseDataModule(LightningDataModule): + def __init__(self): + super().__init__() + self.non_picklable = None + self.checkpoint_state: Optional[str] = None + + self._train_dataset: Optional[Dataset] = None + self._val_dataset: Optional[Dataset] = None + self._test_dataset: Optional[Dataset] = None + self._predict_dataset: Optional[Dataset] = None + + self._processed_train_dataset: Optional[Dataset] = None + self._processed_val_dataset: Optional[Dataset] = None + self._processed_test_dataset: Optional[Dataset] = None + self._processed_predict_dataset: Optional[Dataset] = None + + @property + def train_dataset(self) -> Optional[Dataset]: + return self._train_dataset + + @property + def val_dataset(self) -> Optional[Dataset]: + return self._val_dataset + + @property + def test_dataset(self) -> Optional[Dataset]: + return self._test_dataset + + @property + def predict_dataset(self) -> Optional[Dataset]: + return self._predict_dataset + + @property + def processed_train_dataset(self) -> Optional[Dataset]: + return self._processed_train_dataset or self.train_dataset + + @property + def processed_val_dataset(self) -> Optional[Dataset]: + return self._processed_val_dataset 
or self.val_dataset
+
+    @property
+    def processed_test_dataset(self) -> Optional[Dataset]:
+        return self._processed_test_dataset or self.test_dataset
+
+    @property
+    def processed_predict_dataset(self) -> Optional[Dataset]:
+        return self._processed_predict_dataset or self.predict_dataset
+
+    @processed_train_dataset.setter
+    def processed_train_dataset(self, processed_train_dataset) -> None:
+        self._processed_train_dataset = processed_train_dataset
+
+    @processed_val_dataset.setter
+    def processed_val_dataset(self, processed_val_dataset) -> None:
+        self._processed_val_dataset = processed_val_dataset
+
+    @processed_test_dataset.setter
+    def processed_test_dataset(self, processed_test_dataset) -> None:
+        self._processed_test_dataset = processed_test_dataset
+
+    @processed_predict_dataset.setter
+    def processed_predict_dataset(self, processed_predict_dataset) -> None:
+        self._processed_predict_dataset = processed_predict_dataset
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.processed_train_dataset)
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.processed_val_dataset)
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.processed_test_dataset)
+
+    def predict_dataloader(self) -> DataLoader:
+        return DataLoader(self.processed_predict_dataset)
+
+
+class BoringDataModule(BaseDataModule):
+    def prepare_data(self) -> None:
+        self.random_full = RandomDataset(32, 64 * 4)
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        if stage == "fit" or stage is None:
+            self._train_dataset = Subset(self.random_full, indices=range(64))
+            self.dims = self._train_dataset[0].shape
+
+        if stage in ("fit", "validate") or stage is None:
+            self._val_dataset = Subset(self.random_full, indices=range(64, 64 * 2))
+
+        if stage == "test" or stage is None:
+            self._test_dataset = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
+            self.dims = getattr(self, "dims", self._test_dataset[0].shape)
+
+        if stage == "predict" or stage is None:
+            self._predict_dataset = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
+            self.dims = getattr(self, "dims", self._predict_dataset[0].shape)
+
+
+class KFoldLoop(ExternalLoop):
+    def __init__(
+        self,
+        num_folds: int,
+        best_model_paths: Optional[List[str]] = None,
+        restarting: bool = False,
+    ):
+        super().__init__()
+        self.num_folds = num_folds
+        # avoid a shared mutable default argument
+        self.best_model_paths = best_model_paths if best_model_paths is not None else []
+        self.restarting = restarting
+
+    @staticmethod
+    def loop_base_callback() -> Type[Callback]:
+        class BaseKFoldCallback(Callback):
+            @rank_zero_only
+            def on_fold_start(self, trainer, pl_module, counter):
+                """Override with your own logic."""
+
+        return BaseKFoldCallback
+
+    @property
+    def done(self) -> bool:
+        return self.current_fold >= self.num_folds
+
+    def reset(self) -> None:
+        if not self.restarting:
+            self.current_fold = 0
+
+    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+        # temporary hack: make sure the datamodule has created its datasets
+        self.trainer.datamodule.setup("fit")
+
+    def on_advance_start(self) -> None:
+        # re-create a fresh trainer for each fold so every fold starts from a clean, reproducible state
+        self.create_trainer(max_epochs=np.random.randint(10))
+        # reload the datasets for the current fold
+        dm = self.trainer.datamodule
+        dm.processed_train_dataset = self.process_dataset("train", dm.train_dataset)
+        dm.processed_val_dataset = self.process_dataset("val", dm.val_dataset)
+        # call the user hook
+        self.trainer.call_hook("on_fold_start", self.current_fold)
+        # reset the model parameters
+        self.trainer.lightning_module.reset_parameters()
+
+    def advance(self) -> Any:
+        # the dataloaders are automatically reloaded from the datamodule
+        return self.trainer.fit(self.trainer.lightning_module, datamodule=self.trainer.datamodule)
+
+    def on_advance_end(self) -> None:
+        self.current_fold += 1
+        # store the best model path for this fold
+        self.best_model_paths.append(self.trainer.checkpoint_callback.best_model_path)
+
+    # utility for creating the train/val splits of a fold
+    def process_dataset(self, stage: str, dataset: Dataset) -> Subset:
+        kfold = KFold(self.num_folds, random_state=42, shuffle=True)
+        train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold]
+        indices = train_indices if stage == "train" else validation_indices
+        return Subset(dataset, indices.tolist())
+
+    def on_save_checkpoint(self) -> Dict:
+        return {"current_fold": self.current_fold}
+
+    def on_load_checkpoint(self, state_dict) -> None:
+        self.current_fold = state_dict["current_fold"]
+
+
+class KFoldCallback(KFoldLoop.loop_base_callback()):
+    """This callback demonstrates how to implement your own callback API."""
+
+    @rank_zero_only
+    def on_fold_start(self, trainer, pl_module, counter) -> None:
+        log.info(f"Starting to train on fold {counter}")
+
+
+loop = KFoldLoop(5)
+model = BoringModel()
+datamodule = BoringDataModule()
+loop.connect_trainer(max_epochs=10, callbacks=KFoldCallback())
+loop.run(model, datamodule=datamodule)
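Editorial note on the example above: the fold selection inside `process_dataset` amounts to taking the `current_fold`-th of the `(train_indices, val_indices)` pairs that `sklearn.model_selection.KFold` produces over the dataset indices and wrapping each side in a `Subset`. A minimal standalone sketch of that mechanism (illustrative only, not part of the patch; it uses a plain `TensorDataset` instead of the example's `RandomDataset`):

```python
import torch
from sklearn.model_selection import KFold
from torch.utils.data import Subset, TensorDataset

dataset = TensorDataset(torch.randn(64, 32))
num_folds, current_fold = 5, 0

# KFold yields `num_folds` (train_indices, val_indices) pairs over the dataset indices;
# the loop materializes only the pair belonging to the fold it is currently running.
splits = list(KFold(n_splits=num_folds, shuffle=True, random_state=42).split(range(len(dataset))))
train_indices, val_indices = splits[current_fold]

train_subset = Subset(dataset, train_indices.tolist())
val_subset = Subset(dataset, val_indices.tolist())
print(len(train_subset), len(val_subset))  # 51 13 for 64 samples split into 5 folds
```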
diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index ee5c3a1b708f1..d4e29bce9a38f 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -23,6 +23,9 @@
 from pytorch_lightning.trainer.progress import BaseProgress, Progress
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.warnings import WarningCache
+
+warning_cache = WarningCache()
 
 
 class Loop(ABC):
diff --git a/pytorch_lightning/loops/external_loop.py b/pytorch_lightning/loops/external_loop.py
new file mode 100644
index 0000000000000..28782a463a680
--- /dev/null
+++ b/pytorch_lightning/loops/external_loop.py
@@ -0,0 +1,104 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+from typing import Any, Callable, Dict, Optional
+
+import pytorch_lightning as pl
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.warnings import WarningCache
+
+warning_cache = WarningCache()
+
+
+class ExternalLoop(Loop):
+    """This Loop is meant to wrap Trainer calls."""
+
+    def __init__(self):
+        super().__init__()
+        warning_cache.warn("The ExternalLoop API is a `pre-alpha release` and breaking API changes are expected.")
+        self.create_trainer = self._wrap_trainer_wrapper(self.create_trainer)
+        self._has_setup = False
+        self._restore_external_loop = True
+
+    def _wrap_trainer_wrapper(self, create_trainer: Callable) -> Callable:
+        @functools.wraps(create_trainer)
+        def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
+            trainer = create_trainer(*args, trainer_kwargs=self.trainer_kwargs, **kwargs)
+            if not isinstance(trainer, pl.Trainer):
+                raise MisconfigurationException("The `create_trainer` hook should return a Trainer")
+            self.trainer = trainer
+            self.trainer.external_loop = self
+
+            self.trainer.accelerator.connect(self.__lightning_module)
+
+            # link the data to the trainer
+            self.trainer.data_connector.attach_data(
+                self.trainer.lightning_module,
+                train_dataloaders=self.__train_dataloader,
+                val_dataloaders=self.__val_dataloaders,
+                test_dataloaders=self.__test_dataloaders,
+                predict_dataloaders=self.__predict_dataloaders,
+                datamodule=self.__datamodule,
+            )
+
+            # attach the model to the training type plugin
+            self.trainer.data_connector.prepare_data()
+
+            self.trainer.checkpoint_connector.resume_start()
+            self.trainer.checkpoint_connector.restore_loops(restore_external_loop=self._restore_external_loop)
+            return trainer
+
+        return wrapped_func
+
+    def connect_trainer(self, **trainer_kwargs: Dict[str, Any]) -> None:
+        self.trainer_kwargs = trainer_kwargs
+
+    def create_trainer(self, *args, trainer_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> "pl.Trainer":
+        # copy to avoid mutating a shared (or default) dict in-place
+        trainer_kwargs = dict(trainer_kwargs or {})
+        trainer_kwargs.update(kwargs)
+        return pl.Trainer(*args, **trainer_kwargs)
+
+    def run(
+        self,
+        model: "pl.LightningModule",
+        train_dataloader=None,
+        val_dataloaders=None,
+        test_dataloaders=None,
+        predict_dataloaders=None,
+        datamodule=None,
+    ):
+        if model is None:
+            raise MisconfigurationException("`model` must be provided to `loop.run()`")
+
+        # if a datamodule comes in as the second arg, then fix it for the user
+        if isinstance(train_dataloader, pl.LightningDataModule):
+            datamodule = train_dataloader
+            train_dataloader = None
+
+        if train_dataloader is not None and datamodule:
+            raise MisconfigurationException("You cannot pass both `loop.run(train_dataloader=..., datamodule=...)`")
+
+        self.__lightning_module = model
+        self.__train_dataloader = train_dataloader
+        self.__val_dataloaders = val_dataloaders
+        self.__test_dataloaders = test_dataloaders
+        self.__predict_dataloaders = predict_dataloaders
+        self.__datamodule = datamodule
+
+        if self._trainer is None:
+            self.create_trainer()
+            self._restore_external_loop = False
+
+        return super().run()
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 5aac1acb6c572..3b2c982129ea6 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -237,6 +237,12 @@ def on_keyboard_interrupt(self):
         for callback in self.callbacks:
             callback.on_keyboard_interrupt(self, self.lightning_module)
 
+    def user_defined_hook(self, hook_name: str, *args, **kwargs):
+        """Called
when a user calls call_hook directly with its own hook name.""" + for callback in self.callbacks: + if hasattr(callback, hook_name): + getattr(callback, hook_name)(self, self.lightning_module, *args, **kwargs) + @staticmethod def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool: parameters = list(signature(fn).parameters) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c096f0a609378..df6b682e2da3d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -186,7 +186,7 @@ def restore_callbacks(self) -> None: ) self.trainer.on_load_checkpoint(self._loaded_checkpoint) - def restore_loops(self) -> None: + def restore_loops(self, restore_external_loop: bool = False) -> None: """ Restores the loop progress from the pre-loaded checkpoint. Calls hooks on the loops to give it a chance to restore its state from the checkpoint. @@ -226,6 +226,11 @@ def restore_loops(self) -> None: self.trainer.test_loop.load_state_dict(state_dict["test_loop"]) self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"]) + if restore_external_loop: + external_loop = getattr(self.trainer, "external_loop", None) + if external_loop: + self.trainer.external_loop.load_state_dict(state_dict["external_loop"]) + def restore_optimizers_and_schedulers(self) -> None: """Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint.""" if ( @@ -471,9 +476,13 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: return state_dict def _get_loops_state_dict(self) -> Dict[str, Any]: - return { + state_dict = { "fit_loop": self.trainer.fit_loop.state_dict(), "validate_loop": self.trainer.validate_loop.state_dict(), "test_loop": self.trainer.test_loop.state_dict(), "predict_loop": self.trainer.predict_loop.state_dict(), } + external_loop = getattr(self.trainer, "external_loop", None) + if external_loop: + state_dict.update({"external_loop": external_loop.state_dict()}) + return state_dict diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fb1f93fdd93ae..a6d9f348abc63 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1243,6 +1243,8 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: if hasattr(self, hook_name): trainer_hook = getattr(self, hook_name) trainer_hook(*args, **kwargs) + else: + self.user_defined_hook(hook_name, *args, **kwargs) # next call hook in lightningModule output = None diff --git a/pytorch_lightning/utilities/boring_model.py b/pytorch_lightning/utilities/boring_model.py new file mode 100644 index 0000000000000..6a77d7b0ba4a5 --- /dev/null +++ b/pytorch_lightning/utilities/boring_model.py @@ -0,0 +1,219 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
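Editorial aside on the `user_defined_hook` fallback added to `callback_hook.py` and wired into `Trainer.call_hook` above: any hook name the `Trainer` does not define itself is now forwarded to every callback that implements a method of that name, which is how the example's `call_hook("on_fold_start", ...)` reaches `KFoldCallback`. A minimal standalone sketch of that dispatch pattern (stand-in classes, not Lightning's actual ones):

```python
class MiniTrainer:
    """Stand-in illustrating the call_hook -> user_defined_hook fallback (illustrative only)."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def call_hook(self, hook_name, *args, **kwargs):
        # the real Trainer first tries its own hooks; unknown names fall through to the callbacks
        if hasattr(self, hook_name):
            return getattr(self, hook_name)(*args, **kwargs)
        for callback in self.callbacks:
            if hasattr(callback, hook_name):
                getattr(callback, hook_name)(self, *args, **kwargs)


class FoldLogger:
    def on_fold_start(self, trainer, counter):
        print(f"Starting fold {counter}")


MiniTrainer([FoldLogger()]).call_hook("on_fold_start", 0)  # prints "Starting fold 0"
```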
+from typing import Optional + +import torch +from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset + +from pytorch_lightning import LightningDataModule, LightningModule + + +class RandomDictDataset(Dataset): + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + a = self.data[index] + b = a + 2 + return {"a": a, "b": b} + + def __len__(self): + return self.len + + +class RandomDataset(Dataset): + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class RandomIterableDataset(IterableDataset): + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(self.count): + yield torch.randn(self.size) + + +class RandomIterableDatasetWithLen(IterableDataset): + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(len(self)): + yield torch.randn(self.size) + + def __len__(self): + return self.count + + +class BoringModel(LightningModule): + def __init__(self): + """ + Testing PL Module + + Use as follows: + - subclass + - modify the behavior for what you want + + class TestModel(BaseTestModel): + def training_step(...): + # do your own thing + + or: + + model = BaseTestModel() + model.training_epoch_end = None + + """ + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def loss(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def step(self, x): + x = self(x) + out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) + return out + + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_step_end(self, training_step_outputs): + return training_step_outputs + + def training_epoch_end(self, outputs) -> None: + torch.stack([x["loss"] for x in outputs]).mean() + + def validation_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + torch.stack([x["x"] for x in outputs]).mean() + + def test_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def test_epoch_end(self, outputs) -> None: + torch.stack([x["y"] for x in outputs]).mean() + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def reset_parameters(self): + self.layer.reset_parameters() + + +class BoringDataModule(LightningDataModule): + def __init__(self, data_dir: str = "./"): + super().__init__() + self.data_dir = data_dir + self.non_picklable = None + self.checkpoint_state: Optional[str] = None + + def prepare_data(self): + self.random_full = 
RandomDataset(32, 64 * 4) + + def setup(self, stage: Optional[str] = None): + if stage == "fit" or stage is None: + self.random_train = Subset(self.random_full, indices=range(64)) + self.dims = self.random_train[0].shape + + if stage in ("fit", "validate") or stage is None: + self.random_val = Subset(self.random_full, indices=range(64, 64 * 2)) + + if stage == "test" or stage is None: + self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3)) + self.dims = getattr(self, "dims", self.random_test[0].shape) + + if stage == "predict" or stage is None: + self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) + self.dims = getattr(self, "dims", self.random_predict[0].shape) + + @property + def train_dataset(self): + return self.random_train + + @property + def val_dataset(self): + return self.random_val + + @property + def test_dataset(self): + return self.random_test + + @property + def predict_dataset(self): + return self.random_predict + + def train_dataloader(self): + return DataLoader(self.train_dataset) + + def val_dataloader(self): + return DataLoader(self.val_dataset) + + def test_dataloader(self): + return DataLoader(self.test_dataset) + + def predict_dataloader(self): + return DataLoader(self.predict_dataset) + + +class BoringLightningDataModule(LightningDataModule): + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 45c6453688939..c82a9d604437d 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import os
+from contextlib import suppress
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, Dict, Iterator
@@ -22,8 +23,12 @@
 import torch
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.loops import Loop, TrainingBatchLoop
+from pytorch_lightning.loops.external_loop import ExternalLoop
 from pytorch_lightning.trainer.progress import BaseProgress
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -510,3 +515,68 @@ def configure_optimizers_multiple(self):
     assert state_dict != checkpoint["loops"]["fit_loop"]
     assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
     assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch
+
+
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7")
+def test_external_loop(tmpdir):
+    class TestException(Exception):
+        pass
+
+    class CustomLoop(ExternalLoop):
+        def __init__(self):
+            super().__init__()
+            self.counter: int = 0
+            self.stop_counter: int = 5
+            self.has_restarted: bool = False
+
+        def reset(self):
+            if self.restarting:
+                self.has_restarted = True
+
+        @property
+        def done(self):
+            return self.counter >= self.stop_counter
+
+        def advance(self, *args: Any, **kwargs: Any) -> None:
+            self.trainer.call_hook("custom_hook")
+            self.counter += 1
+            self.trainer.fit(self.trainer.lightning_module, train_dataloader=self.trainer.train_dataloader)
+
+            if self.counter == 3:
+                raise TestException
+
+        def on_save_checkpoint(self) -> Dict:
+            return {"counter": self.counter}
+
+        def on_load_checkpoint(self, state_dict: Dict) -> None:
+            self.counter = state_dict["counter"]
+
+    class CustomCallback(Callback):
+
+        has_called = False
+
+        def custom_hook(self, trainer, pl_module):
+            self.has_called = True
+
+    loop = CustomLoop()
+    model = BoringModel()
+    cb = CustomCallback()
+    cb2 = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    trainer_kwargs = dict(
+        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=5, limit_val_batches=0, callbacks=[cb, cb2]
+    )
+    loop.connect_trainer(**trainer_kwargs)
+
+    with suppress(TestException):
+        loop.run(model)
+
+    loop = CustomLoop()
+    model = BoringModel()
+    cb = CustomCallback()
+    trainer_kwargs["callbacks"] = [cb, cb2]
+    trainer_kwargs["resume_from_checkpoint"] = cb2.last_model_path
+    loop.connect_trainer(**trainer_kwargs)
+    loop.run(model)
+    assert cb.has_called
+    assert loop.counter == 5
+    assert loop.has_restarted
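One natural follow-up to the `k_fold.py` example (not part of this PR) is to consume the per-fold checkpoints that `KFoldLoop.best_model_paths` collects, e.g. to build a simple cross-validation ensemble. A hedged sketch, assuming the run above completed and a `ModelCheckpoint` callback (Lightning's default or an explicit one) actually wrote one best checkpoint per fold:

```python
import torch

# load each fold's best weights; BoringModel takes no constructor arguments,
# so load_from_checkpoint needs nothing beyond the path
fold_models = [BoringModel.load_from_checkpoint(path) for path in loop.best_model_paths]

batch = torch.randn(4, 32)
with torch.no_grad():
    # simple ensemble: average the per-fold predictions
    prediction = torch.stack([model(batch) for model in fold_models]).mean(dim=0)
print(prediction.shape)  # torch.Size([4, 2])
```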