From c7a1fe5b76ea7968bbb4163b7c45d8eaed397df8 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 4 Aug 2021 13:40:19 +0200
Subject: [PATCH 01/17] update

---
 pl_examples/loops_customisation/k_fold.py   | 140 +++++++++++++++
 pytorch_lightning/loops/base.py             |  51 +++++-
 pytorch_lightning/trainer/data_loading.py   |  61 +++++++
 pytorch_lightning/trainer/trainer.py        |  89 +++++++++-
 pytorch_lightning/utilities/boring_model.py | 186 ++++++++++++++++++++
 5 files changed, 525 insertions(+), 2 deletions(-)
 create mode 100644 pl_examples/loops_customisation/k_fold.py
 create mode 100644 pytorch_lightning/utilities/boring_model.py

diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py
new file mode 100644
index 0000000000000..8506e7316d87c
--- /dev/null
+++ b/pl_examples/loops_customisation/k_fold.py
@@ -0,0 +1,140 @@
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+from sklearn.model_selection import KFold
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import DataLoader
+
+from pytorch_lightning import _logger as log
+from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.loops.base import ExternalLoop
+from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.boring_model import BoringModel, RandomDataset
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+seed_everything(42)
+
+
+class SplitDataset(Dataset):
+    """SplitDataset is used to create a Dataset subset using indices.
+    Args:
+        dataset: A dataset to be split
+        indices: List of indices to expose from the dataset
+        use_duplicated_indices: Whether to allow duplicated indices.
+ Example:: + split_ds = SplitDataset(dataset, indices=[10, 14, 25]) + split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True) + """ + + _INTERNAL_KEYS = ("dataset", "indices", "data") + + def __init__(self, dataset: Any, indices: List[int] = None, use_duplicated_indices: bool = False) -> None: + if indices is None: + indices = [] + if not isinstance(indices, list): + raise MisconfigurationException("indices should be a list") + + if use_duplicated_indices: + indices = list(indices) + else: + indices = list(np.unique(indices)) + + if np.max(indices) >= len(dataset) or np.min(indices) < 0: + raise MisconfigurationException(f"`indices` should be within [0, {len(dataset) -1}].") + + self.dataset = dataset + self.indices = indices + + def __getattr__(self, key: str): + if key not in self._INTERNAL_KEYS: + return self.dataset.__getattribute__(key) + raise AttributeError + + def __setattr__(self, name: str, value: Any) -> None: + if name in self._INTERNAL_KEYS: + self.__dict__[name] = value + else: + setattr(self.dataset, name, value) + + def __getitem__(self, index: int) -> Any: + return self.dataset[self.indices[index]] + + def __len__(self) -> int: + return len(self.indices) - 1 + + +class BoringDataModule(LightningDataModule): + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +class KFoldLoop(ExternalLoop): + def __init__(self, num_folds: int, num_epochs: int = 10) -> None: + super().__init__() + self.num_folds = num_folds + self.num_epochs = num_epochs + + def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: + return super().run(*args, **kwargs) + + @property + def done(self) -> bool: + return self.current_fold >= self.num_folds + + def reset(self) -> None: + if not self.restarting: + self.current_fold = 0 + self.set_max_epochs(self.num_epochs) + + def generate_fold(self, dataloader_kwargs: Dict[str, Any], stage: str): + dataset = dataloader_kwargs["dataset"] + kfold = KFold(self.num_folds, random_state=42, shuffle=True) + train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold] + if stage == "train": + dataloader_kwargs["dataset"] = SplitDataset(dataset, train_indices.tolist()) + else: + dataloader_kwargs["dataset"] = SplitDataset(dataset, validation_indices.tolist()) + dataloader_kwargs["sampler"].data_source = dataloader_kwargs["dataset"] + return dataloader_kwargs + + def on_advance_start(self): + self.reload_train_dataloader(self.generate_fold) + self.reload_val_dataloaders(self.generate_fold) + self.trainer.call_hook("on_fold_start", self.current_fold) + + def advance(self): + return self.trainer.fit(self.lightning_module, train_dataloader=self.train_dataloader) + + def on_advance_end(self) -> None: + self.current_fold += 1 + self.increment_max_epochs(self.num_epochs) + + def on_save_checkpoint(self) -> Dict: + return {"current_fold": self.current_fold} + + def on_load_checkpoint(self, state_dict) -> None: + self.current_fold = state_dict["current_fold"] + + +class KFoldCallback(Callback): + @rank_zero_only + def on_fold_start(self, trainer, pl_module, counter): + log.info(f"Starting to train on fold {counter}") + + +loop = KFoldLoop(5) +model = BoringModel() +datamodule = BoringDataModule() +trainer = Trainer(callbacks=KFoldCallback()) +trainer.run_loop(model, 
datamodule=datamodule, loop=loop) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index ee5c3a1b708f1..c1ff34110de23 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -13,9 +13,11 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from functools import partial +from typing import Any, Callable, Dict, List, Optional from deprecate import void +from torch.utils.data.dataloader import DataLoader from torchmetrics import Metric import pytorch_lightning as pl @@ -238,3 +240,50 @@ def _load_from_state_dict( self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True + + +class ExternalLoop(Loop): + """This Loop is meant wrap trainer calls""" + + def set_max_epochs(self, max_epochs: int): + self.trainer.fit_loop.max_epochs = max_epochs + + def increment_max_epochs(self, max_epochs: int): + self.trainer.fit_loop.max_epochs += max_epochs + + def set_max_steps(self, max_steps: int): + self.trainer.fit_loop.max_steps = max_steps + + def increment_max_steps(self, max_steps: int): + self.trainer.fit_loop.max_steps += max_steps + + def reload_train_dataloader(self, user_function: Optional[Callable] = None) -> DataLoader: + self.trainer.train_dataloader = None + self.trainer.reset_train_dataloader(self.trainer.lightning_module) + if user_function: + user_function = partial(user_function, stage="train") + loaders = self.trainer.train_dataloader.loaders + loaders = loaders if isinstance(loaders, DataLoader) else loaders.loaders + self.trainer.train_dataloader.loaders = self.trainer.apply_user_function(loaders, user_function) + return self.trainer.train_dataloader + + def reload_val_dataloaders(self, user_function: Optional[Callable] = None) -> List[DataLoader]: + self.trainer.reset_val_dataloader(self.trainer.lightning_module) + if user_function: + user_function = partial(user_function, stage="val") + self.trainer.val_dataloaders = [ + self.trainer.apply_user_function(dl, user_function) for dl in self.trainer.val_dataloaders + ] + return self.trainer.val_dataloaders + + @property + def lightning_module(self): + return self.trainer.lightning_module + + @property + def train_dataloader(self) -> DataLoader: + return self.trainer.train_dataloader + + @property + def val_dataloaders(self) -> List[DataLoader]: + return self.trainer.val_dataloaders diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 9f8afbe451306..973cbcc32437d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -252,6 +252,67 @@ def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[Runnin dataloader = dl_cls(**dl_kwargs) return dataloader + def apply_user_function(self, dataloader: DataLoader, user_function: Callable) -> DataLoader: + if not isinstance(dataloader, DataLoader): + raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") + + # get the dataloader instance attributes + attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")} + # not part of `vars` + attrs["multiprocessing_context"] = dataloader.multiprocessing_context + + # get the dataloader instance `__init__` parameters + params = dict(inspect.signature(dataloader.__init__).parameters) + + # keep only the params whose default is different to the current attr value + non_defaults = {name for name, p in params.items() if name in attrs and p.default != 
attrs[name]} + # add `dataset` as it might have been replaced with `*args` + non_defaults.add("dataset") + + # kwargs to re-construct the dataloader + dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults} + dl_kwargs.update(self._resolve_batch_sampler(dataloader, dataloader.sampler, mode=RunningStage.translate)) + + required_args = { + p.name + for p in params.values() + if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + and p.default is p.empty + and p.name not in dl_kwargs + } + # the dataloader has required args which we could not extract from the existing attributes + if required_args: + required_args = sorted(required_args) + dataloader_cls_name = dataloader.__class__.__name__ + raise MisconfigurationException( + f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. " + "This would fail as some of the `__init__` arguments are not available as instance attributes. " + f"The missing attributes are {required_args}. " + f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or " + "manually add the `DistributedSampler` as: " + f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`." + ) + + has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values()) + if not has_variadic_kwargs: + # the dataloader signature does not allow keyword arguments that need to be passed + missing_kwargs = dl_kwargs.keys() - params.keys() + if missing_kwargs: + missing_kwargs = sorted(missing_kwargs) + dataloader_cls_name = dataloader.__class__.__name__ + raise MisconfigurationException( + f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. " + "This would fail as it doesn't expose all its attributes in the `__init__` signature. " + f"The missing arguments are {missing_kwargs}. " + f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or " + "manually add the `DistributedSampler` as: " + f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`." 
+ ) + + dl_cls = type(dataloader) + dataloader = dl_cls(**user_function(dl_kwargs)) + return dataloader + def _get_distributed_sampler( self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None ) -> DistributedSampler: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 590f04957bb75..fc184e06ebbaf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,6 +29,7 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop +from pytorch_lightning.loops.base import ExternalLoop, Loop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop @@ -834,6 +835,61 @@ def tune( return result + def run_loop( + self, + model: "pl.LightningModule", + train_dataloader: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + test_dataloaders: Optional[EVAL_DATALOADERS] = None, + predict_dataloaders: Optional[EVAL_DATALOADERS] = None, + datamodule: Optional[LightningDataModule] = None, + loop: Union[ExternalLoop, Loop] = None, + ): + + # -------------------- + # SETUP HOOK + # -------------------- + # FIXME: hack to not break + self.state.fn = TrainerFn.FITTING + self.state.status = TrainerStatus.RUNNING + self.training = True + + Trainer._log_api_event("run_loop") + + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + if train_dataloader is not None and datamodule: + raise MisconfigurationException("You cannot pass both `trainer.run_loop(dataloaders=..., datamodule=...)`") + + if loop is None or not isinstance(loop, Loop): + raise MisconfigurationException( + "You should provide an `ExternalLoop` or `Loop` object as `trainer.run_loop(loop=...)`" + ) + + model = model or self.lightning_module + if model is None: + raise MisconfigurationException( + "`model` must be provided to `trainer.predict()` when it hasn't been passed in a previous run" + ) + + # links data to the trainer + self.data_connector.attach_data( + model, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloaders, + test_dataloaders=test_dataloaders, + predict_dataloaders=predict_dataloaders, + datamodule=datamodule, + ) + + self._prepare_run(model) + + loop.trainer = self + + return loop.run() + def _restore_modules_and_callbacks(self) -> None: # restore modules after setup if self.state.fn == TrainerFn.FITTING: @@ -851,7 +907,7 @@ def _load_checkpoint_weights(self): rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") self.checkpoint_connector.restore_model_weights(self._ckpt_path) - def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + def _prepare_run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) @@ -911,6 +967,37 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions. 
""" + def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + self._prepare_run(model) + + # ---------------------------- + # INSPECT THE CORE LOOPS + # ---------------------------- + fr""" + Lightning internal flow looks like this: + {Trainer.fit} or {Trainer.test} or {Trainer.predict} || + | || + create accelerator || + | || + {self._dispatch} || + | || LIGHTNING + {self.accelerator.start_training} || + or {self.accelerator.start_evaluating} || + or {self.accelerator.start_predicting} || FLOW + | || + {self.run_stage} || + | || DIRECTION + {self._run_train} || + or {self._run_evaluate} || + or {self._run_predict} || + | || + results \/ + This is used to guide readers to the core loops: train, test, predict. + {self._run_predict} is the simplest to understand, use `Go to Definition` to read it :) + Search for `start_training` or `start_evaluating` or `start_predicting` in + `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions. + """ + # ---------------------------- # TRAIN # ---------------------------- diff --git a/pytorch_lightning/utilities/boring_model.py b/pytorch_lightning/utilities/boring_model.py new file mode 100644 index 0000000000000..d20cb1287e326 --- /dev/null +++ b/pytorch_lightning/utilities/boring_model.py @@ -0,0 +1,186 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Optional + +import torch +from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset + +from pytorch_lightning import LightningDataModule, LightningModule + + +class RandomDictDataset(Dataset): + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + a = self.data[index] + b = a + 2 + return {"a": a, "b": b} + + def __len__(self): + return self.len + + +class RandomDataset(Dataset): + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class RandomIterableDataset(IterableDataset): + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(self.count): + yield torch.randn(self.size) + + +class RandomIterableDatasetWithLen(IterableDataset): + def __init__(self, size: int, count: int): + self.count = count + self.size = size + + def __iter__(self): + for _ in range(len(self)): + yield torch.randn(self.size) + + def __len__(self): + return self.count + + +class BoringModel(LightningModule): + def __init__(self): + """ + Testing PL Module + + Use as follows: + - subclass + - modify the behavior for what you want + + class TestModel(BaseTestModel): + def training_step(...): + # do your own thing + + or: + + model = BaseTestModel() + model.training_epoch_end = None + + """ + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def loss(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def step(self, x): + x = self(x) + out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) + return out + + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_step_end(self, training_step_outputs): + return training_step_outputs + + def training_epoch_end(self, outputs) -> None: + torch.stack([x["loss"] for x in outputs]).mean() + + def validation_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + torch.stack([x["x"] for x in outputs]).mean() + + def test_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def test_epoch_end(self, outputs) -> None: + torch.stack([x["y"] for x in outputs]).mean() + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +class BoringDataModule(LightningDataModule): + def __init__(self, data_dir: str = "./"): + super().__init__() + self.data_dir = data_dir + self.non_picklable = None + self.checkpoint_state: Optional[str] = None + + def prepare_data(self): + self.random_full = RandomDataset(32, 64 * 4) + + def setup(self, stage: Optional[str] = None): + 
if stage == "fit" or stage is None: + self.random_train = Subset(self.random_full, indices=range(64)) + self.dims = self.random_train[0].shape + + if stage in ("fit", "validate") or stage is None: + self.random_val = Subset(self.random_full, indices=range(64, 64 * 2)) + + if stage == "test" or stage is None: + self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3)) + self.dims = getattr(self, "dims", self.random_test[0].shape) + + if stage == "predict" or stage is None: + self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) + self.dims = getattr(self, "dims", self.random_predict[0].shape) + + def train_dataloader(self): + return DataLoader(self.random_train) + + def val_dataloader(self): + return DataLoader(self.random_val) + + def test_dataloader(self): + return DataLoader(self.random_test) + + def predict_dataloader(self): + return DataLoader(self.random_predict) From 210bd51d8a16c1b0347939b1eed7ae7197cd52a8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 4 Aug 2021 16:23:49 +0200 Subject: [PATCH 02/17] update --- pl_examples/loops_customisation/k_fold.py | 18 ++++++++++++++---- pytorch_lightning/trainer/callback_hook.py | 6 ++++++ pytorch_lightning/trainer/trainer.py | 2 ++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index 8506e7316d87c..bb4bb3739d3b7 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Type import numpy as np from sklearn.model_selection import KFold @@ -85,8 +85,15 @@ def __init__(self, num_folds: int, num_epochs: int = 10) -> None: self.num_folds = num_folds self.num_epochs = num_epochs - def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: - return super().run(*args, **kwargs) + @staticmethod + def loop_base_callback() -> Type[Callback]: + class BaseKFoldCallback(Callback): + @rank_zero_only + def on_fold_start(self, trainer, pl_module, counter): + """Override with your own logic""" + log.info(f"Starting to train on fold {counter}") + + return BaseKFoldCallback @property def done(self) -> bool: @@ -127,7 +134,10 @@ def on_load_checkpoint(self, state_dict) -> None: self.current_fold = state_dict["current_fold"] -class KFoldCallback(Callback): +class KFoldCallback(KFoldLoop.loop_base_callback()): + + """This callback demonstrates how to implement your own logic.""" + @rank_zero_only def on_fold_start(self, trainer, pl_module, counter): log.info(f"Starting to train on fold {counter}") diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 842a10aa69ef1..6ac5e1b87f605 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -237,6 +237,12 @@ def on_keyboard_interrupt(self): for callback in self.callbacks: callback.on_keyboard_interrupt(self, self.lightning_module) + def user_defined_hook(self, hook_name: str, *args, **kwargs): + """Called when a user calls call_hook directly with its own hook name.""" + for callback in self.callbacks: + if hasattr(callback, hook_name): + getattr(callback, hook_name)(self, self.lightning_module, *args, **kwargs) + @staticmethod def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool: parameters = list(signature(fn).parameters) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 
fc184e06ebbaf..8ab4d901f523e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1316,6 +1316,8 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: if hasattr(self, hook_name): trainer_hook = getattr(self, hook_name) trainer_hook(*args, **kwargs) + else: + self.user_defined_hook(hook_name, *args, **kwargs) # next call hook in lightningModule output = None From 91a9dff6ac8b97e3540b06b54deea44f0b78a13b Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 4 Aug 2021 16:26:41 +0200 Subject: [PATCH 03/17] update --- pl_examples/loops_customisation/k_fold.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index bb4bb3739d3b7..f409174c4d0f7 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -91,7 +91,6 @@ class BaseKFoldCallback(Callback): @rank_zero_only def on_fold_start(self, trainer, pl_module, counter): """Override with your own logic""" - log.info(f"Starting to train on fold {counter}") return BaseKFoldCallback @@ -136,7 +135,7 @@ def on_load_checkpoint(self, state_dict) -> None: class KFoldCallback(KFoldLoop.loop_base_callback()): - """This callback demonstrates how to implement your own logic.""" + """This callback demonstrates how to implement your create callbacks.""" @rank_zero_only def on_fold_start(self, trainer, pl_module, counter): From 2fb84961b6b6da58519283abf727a45070a8d15a Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 4 Aug 2021 17:55:34 +0200 Subject: [PATCH 04/17] update --- pl_examples/loops_customisation/k_fold.py | 14 ++++++++++---- pytorch_lightning/utilities/boring_model.py | 3 +++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index f409174c4d0f7..e7d2a0a0dc1df 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field from typing import Any, Dict, List, Type import numpy as np @@ -79,11 +80,13 @@ def predict_dataloader(self): return DataLoader(RandomDataset(32, 64)) +@dataclass class KFoldLoop(ExternalLoop): - def __init__(self, num_folds: int, num_epochs: int = 10) -> None: - super().__init__() - self.num_folds = num_folds - self.num_epochs = num_epochs + + num_folds: int + num_epochs: int = 10 + best_model_paths: List[str] = field(default_factory=lambda: []) + restarting: bool = False @staticmethod def loop_base_callback() -> Type[Callback]: @@ -118,6 +121,7 @@ def on_advance_start(self): self.reload_train_dataloader(self.generate_fold) self.reload_val_dataloaders(self.generate_fold) self.trainer.call_hook("on_fold_start", self.current_fold) + self.lightning_module.reset_parameters() def advance(self): return self.trainer.fit(self.lightning_module, train_dataloader=self.train_dataloader) @@ -125,6 +129,8 @@ def advance(self): def on_advance_end(self) -> None: self.current_fold += 1 self.increment_max_epochs(self.num_epochs) + # stored best weight path for this fold + self.best_model_paths.append(self.trainer.checkpoint_callback.best_model_path) def on_save_checkpoint(self) -> Dict: return {"current_fold": self.current_fold} diff --git a/pytorch_lightning/utilities/boring_model.py b/pytorch_lightning/utilities/boring_model.py index d20cb1287e326..6f9e0f4eb8984 100644 --- a/pytorch_lightning/utilities/boring_model.py +++ 
b/pytorch_lightning/utilities/boring_model.py @@ -146,6 +146,9 @@ def test_dataloader(self): def predict_dataloader(self): return DataLoader(RandomDataset(32, 64)) + def reset_parameters(self): + self.layer.reset_parameters() + class BoringDataModule(LightningDataModule): def __init__(self, data_dir: str = "./"): From 626525eb99cc7ba93ea3527bb33b52ad86403a97 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 4 Aug 2021 20:42:54 +0200 Subject: [PATCH 05/17] simplify code --- pl_examples/loops_customisation/k_fold.py | 22 +++-------- pytorch_lightning/trainer/trainer.py | 44 +++++---------------- pytorch_lightning/utilities/boring_model.py | 14 +++++++ 3 files changed, 29 insertions(+), 51 deletions(-) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index e7d2a0a0dc1df..d96b1def90da1 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -4,15 +4,13 @@ import numpy as np from sklearn.model_selection import KFold from torch.utils.data import Dataset -from torch.utils.data.dataloader import DataLoader from pytorch_lightning import _logger as log from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.loops.base import ExternalLoop from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.boring_model import BoringModel, RandomDataset +from pytorch_lightning.utilities.boring_model import BoringDataModule, BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException seed_everything(42) @@ -66,20 +64,6 @@ def __len__(self) -> int: return len(self.indices) - 1 -class BoringDataModule(LightningDataModule): - def train_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - def val_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - def test_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - def predict_dataloader(self): - return DataLoader(RandomDataset(32, 64)) - - @dataclass class KFoldLoop(ExternalLoop): @@ -117,6 +101,10 @@ def generate_fold(self, dataloader_kwargs: Dict[str, Any], stage: str): dataloader_kwargs["sampler"].data_source = dataloader_kwargs["dataset"] return dataloader_kwargs + def on_run_start(self, *args: Any, **kwargs: Any) -> None: + # temporary hack + self.trainer.datamodule.setup("fit") + def on_advance_start(self): self.reload_train_dataloader(self.generate_fold) self.reload_val_dataloaders(self.generate_fold) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index aa3115496bf75..6c32f5198df3b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -884,10 +884,14 @@ def run_loop( datamodule=datamodule, ) - self._prepare_run(model) - loop.trainer = self + # attach model to the training type plugin + self.accelerator.connect(model) + + self.data_connector.prepare_data() + self.callback_connector._attach_model_callbacks() + return loop.run() def _restore_modules_and_callbacks(self) -> None: @@ -907,7 +911,7 @@ def _load_checkpoint_weights(self): rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") self.checkpoint_connector.restore_model_weights(self._ckpt_path) - def _prepare_run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + def _run(self, model: "pl.LightningModule") -> 
Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) @@ -969,37 +973,6 @@ def _prepare_run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_ `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions. """ - def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: - self._prepare_run(model) - - # ---------------------------- - # INSPECT THE CORE LOOPS - # ---------------------------- - fr""" - Lightning internal flow looks like this: - {Trainer.fit} or {Trainer.test} or {Trainer.predict} || - | || - create accelerator || - | || - {self._dispatch} || - | || LIGHTNING - {self.accelerator.start_training} || - or {self.accelerator.start_evaluating} || - or {self.accelerator.start_predicting} || FLOW - | || - {self.run_stage} || - | || DIRECTION - {self._run_train} || - or {self._run_evaluate} || - or {self._run_predict} || - | || - results \/ - This is used to guide readers to the core loops: train, test, predict. - {self._run_predict} is the simplest to understand, use `Go to Definition` to read it :) - Search for `start_training` or `start_evaluating` or `start_predicting` in - `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions. - """ - # ---------------------------- # TRAIN # ---------------------------- @@ -1141,6 +1114,9 @@ def _run_train(self) -> None: # reload data when needed model = self.lightning_module + # hook + self.data_connector.prepare_data() + self.reset_train_val_dataloaders(model) try: diff --git a/pytorch_lightning/utilities/boring_model.py b/pytorch_lightning/utilities/boring_model.py index 6f9e0f4eb8984..32cc6dc665da2 100644 --- a/pytorch_lightning/utilities/boring_model.py +++ b/pytorch_lightning/utilities/boring_model.py @@ -187,3 +187,17 @@ def test_dataloader(self): def predict_dataloader(self): return DataLoader(self.random_predict) + + +class BoringLightningDataModule(LightningDataModule): + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) From 480744e3b90f57ece4bed71cf3e106b512b7b82c Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 4 Aug 2021 20:48:10 +0200 Subject: [PATCH 06/17] break if new trainer is attached --- pytorch_lightning/loops/base.py | 20 ++++++++++++++++++++ pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index c1ff34110de23..2c85b0853b7ab 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -245,6 +245,26 @@ def _load_from_state_dict( class ExternalLoop(Loop): """This Loop is meant wrap trainer calls""" + @property + def trainer(self) -> Optional["pl.Trainer"]: + return self._trainer + + @trainer.setter + def trainer(self, trainer: "pl.Trainer"): + """Connects this loop's trainer and its children""" + if not isinstance(trainer, pl.Trainer): + raise MisconfigurationException( + f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." 
+ ) + if hasattr(self, "_trainer") and isinstance(self._trainer, pl.Trainer): + raise MisconfigurationException( + f"Loop {self.__class__.__name__} should be attached to only 1 `Trainer` instance." + ) + self._trainer = trainer + for v in self.__dict__.values(): + if isinstance(v, Loop): + v.trainer = trainer + def set_max_epochs(self, max_epochs: int): self.trainer.fit_loop.max_epochs = max_epochs diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6c32f5198df3b..aa2828966fbbe 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -884,13 +884,13 @@ def run_loop( datamodule=datamodule, ) + # attach trainer loop.trainer = self # attach model to the training type plugin self.accelerator.connect(model) self.data_connector.prepare_data() - self.callback_connector._attach_model_callbacks() return loop.run() From 7cca34de61b9396fd9cba6f83056b828a53a192d Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 4 Aug 2021 20:56:00 +0200 Subject: [PATCH 07/17] update --- tests/loops/test_loops.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 45c6453688939..d79ced7ea73a1 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -22,7 +22,9 @@ import torch from pytorch_lightning import Trainer +from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.loops import Loop, TrainingBatchLoop +from pytorch_lightning.loops.base import ExternalLoop from pytorch_lightning.trainer.progress import BaseProgress from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -510,3 +512,37 @@ def configure_optimizers_multiple(self): assert state_dict != checkpoint["loops"]["fit_loop"] assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1 assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch + + +def test_external_loop(tmpdir): + @dataclass + class CustomLoop(ExternalLoop): + + counter: int = 0 + stop_counter: int = 5 + + def reset(self): + pass + + @property + def done(self): + return self.counter >= self.stop_counter + + def advance(self, *args: Any, **kwargs: Any) -> None: + self.trainer.call_hook("custom_hook") + self.counter += 1 + + class CustomCallback(Callback): + + has_called = False + + def custom_hook(self, trainer, pl_module): + self.has_called = True + + loop = CustomLoop() + model = BoringModel() + cb = CustomCallback() + trainer = Trainer(default_root_dir=tmpdir, callbacks=cb) + trainer.run_loop(model, loop=loop) + cb.has_called = True + loop.counter = 5 From 40089aec9bae3e844976d29ca856b057f1d1818a Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 4 Aug 2021 20:57:06 +0200 Subject: [PATCH 08/17] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c1fab42597a0..723af3f406261 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
 * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
+- Added `KFoldLoop` example ([#8715](https://github.com/PyTorchLightning/pytorch-lightning/pull/8715))
+
+
 ### Changed

 - Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))

From d6d30fd79a2b2d15e3c610fb178e7209dc896d7f Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 4 Aug 2021 21:01:28 +0200
Subject: [PATCH 09/17] add warning

---
 pl_examples/loops_customisation/k_fold.py | 19 +++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py
index d96b1def90da1..312d4059459f3 100644
--- a/pl_examples/loops_customisation/k_fold.py
+++ b/pl_examples/loops_customisation/k_fold.py
@@ -1,3 +1,22 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+WARNING: Loop customization is in `pre-alpha release` and the API is likely to change quite a lot!
+Please open issues with your own particular requests, so the Lightning Team can progressively converge to a great API.
+""" + from dataclasses import dataclass, field from typing import Any, Dict, List, Type From dd560811dc313aeeb72cd85e7f2809272b6fa34f Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 4 Aug 2021 21:03:57 +0200 Subject: [PATCH 10/17] update --- pl_examples/loops_customisation/k_fold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index 312d4059459f3..f1ff3336f630a 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -148,7 +148,7 @@ def on_load_checkpoint(self, state_dict) -> None: class KFoldCallback(KFoldLoop.loop_base_callback()): - """This callback demonstrates how to implement your create callbacks.""" + """This callback demonstrates how to implement your own callback API.""" @rank_zero_only def on_fold_start(self, trainer, pl_module, counter): From d209d53615a37fd351129c4b5f03053fa63398ff Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 5 Aug 2021 10:25:21 +0200 Subject: [PATCH 11/17] resolve bug --- pytorch_lightning/trainer/trainer.py | 29 +++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index aa2828966fbbe..9c4d40a80390a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -77,6 +77,7 @@ ) from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.enums import DistributedType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.model_helpers import is_overridden @@ -884,15 +885,18 @@ def run_loop( datamodule=datamodule, ) - # attach trainer + # connect loop and trainer loop.trainer = self + self.loop = loop # attach model to the training type plugin self.accelerator.connect(model) - self.data_connector.prepare_data() - return loop.run() + results = loop.run() + + del self.loop + return results def _restore_modules_and_callbacks(self) -> None: # restore modules after setup @@ -1004,8 +1008,10 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, if self.state.fn == TrainerFn.FITTING: self.call_hook("on_fit_end") - # teardown - self._call_teardown_hook() + # teardown if necessary (similar calls for spawn plugins are excluded as they have + # been included at the end of `new_process` functions) + if self._distrib_type not in DistributedType.interactive_compatible_types(): + self._call_teardown_hook() if self.state.status != TrainerStatus.INTERRUPTED: self.state.status = TrainerStatus.FINISHED @@ -1114,9 +1120,6 @@ def _run_train(self) -> None: # reload data when needed model = self.lightning_module - # hook - self.data_connector.prepare_data() - self.reset_train_val_dataloaders(model) try: @@ -1274,7 +1277,7 @@ def _call_teardown_hook(self) -> None: if self.datamodule is not None: self.datamodule.teardown(stage=fn) - self.profiler.teardown(stage=fn) + self.teardown(stage=fn) self.lightning_module.teardown(stage=fn) @@ -1283,6 +1286,14 @@ def _call_teardown_hook(self) -> None: # these could have become stale if metrics are defined in `setup` self.lightning_module._metric_attributes = None + # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers. 
+ # It might be related to xla tensors blocked when moving the cpu kill loggers. + if self.logger is not None: + self.logger.finalize("success") + + # summarize profile results + self.profiler.describe() + def call_hook(self, hook_name: str, *args, **kwargs) -> Any: if self.lightning_module: prev_fx_name = self.lightning_module._current_fx_name From 615ab30023a7562cc08a6f65a8afe55dcf049160 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 5 Aug 2021 11:36:23 +0200 Subject: [PATCH 12/17] update --- pl_examples/loops_customisation/k_fold.py | 186 +++++++++++------- pytorch_lightning/loops/base.py | 60 +----- .../connectors/checkpoint_connector.py | 13 +- pytorch_lightning/trainer/trainer.py | 19 +- pytorch_lightning/utilities/boring_model.py | 24 ++- tests/loops/test_loops.py | 49 ++++- 6 files changed, 211 insertions(+), 140 deletions(-) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index f1ff3336f630a..ccb5ad8585744 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -17,79 +17,132 @@ Please, open issues with your own particular requests, so the Lightning Team can progressively converge to a great API. """ -from dataclasses import dataclass, field -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Optional, Type -import numpy as np from sklearn.model_selection import KFold -from torch.utils.data import Dataset +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset, Subset from pytorch_lightning import _logger as log -from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning import LightningDataModule, seed_everything, Trainer from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.loops.base import ExternalLoop from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.boring_model import BoringDataModule, BoringModel -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.boring_model import BoringModel, RandomDataset seed_everything(42) -class SplitDataset(Dataset): - """SplitDataset is used to create Dataset Subset using indices. - Args: - dataset: A dataset to be splitted - indices: List of indices to expose from the dataset - use_duplicated_indices: Whether to allow duplicated indices. 
- Example:: - split_ds = SplitDataset(dataset, indices=[10, 14, 25]) - split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True) - """ +class BaseDataModule(LightningDataModule): + def __init__(self): + super().__init__() + self.non_picklable = None + self.checkpoint_state: Optional[str] = None - _INTERNAL_KEYS = ("dataset", "indices", "data") + self._train_dataset: Optional[Dataset] = None + self._val_dataset: Optional[Dataset] = None + self._test_dataset: Optional[Dataset] = None + self._predict_dataset: Optional[Dataset] = None - def __init__(self, dataset: Any, indices: List[int] = None, use_duplicated_indices: bool = False) -> None: - if indices is None: - indices = [] - if not isinstance(indices, list): - raise MisconfigurationException("indices should be a list") + self._processed_train_dataset: Optional[Dataset] = None + self._processed_val_dataset: Optional[Dataset] = None + self._processed_test_dataset: Optional[Dataset] = None + self._processed_predict_dataset: Optional[Dataset] = None - if use_duplicated_indices: - indices = list(indices) - else: - indices = list(np.unique(indices)) + @property + def train_dataset(self): + return self._train_dataset - if np.max(indices) >= len(dataset) or np.min(indices) < 0: - raise MisconfigurationException(f"`indices` should be within [0, {len(dataset) -1}].") + @property + def val_dataset(self): + return self._val_dataset - self.dataset = dataset - self.indices = indices + @property + def test_dataset(self): + return self._test_dataset - def __getattr__(self, key: str): - if key not in self._INTERNAL_KEYS: - return self.dataset.__getattribute__(key) - raise AttributeError + @property + def predict_dataset(self): + return self._predict_dataset - def __setattr__(self, name: str, value: Any) -> None: - if name in self._INTERNAL_KEYS: - self.__dict__[name] = value - else: - setattr(self.dataset, name, value) + @property + def processed_train_dataset(self): + return self._processed_train_dataset or self.train_dataset - def __getitem__(self, index: int) -> Any: - return self.dataset[self.indices[index]] + @property + def processed_val_dataset(self): + return self._processed_val_dataset or self.val_dataset - def __len__(self) -> int: - return len(self.indices) - 1 + @property + def processed_test_dataset(self): + return self._processed_test_dataset or self.test_dataset + @property + def processed_predict_dataset(self): + return self._processed_predict_dataset or self.predict_dataset -@dataclass -class KFoldLoop(ExternalLoop): + @processed_train_dataset.setter + def processed_train_dataset(self, processed_train_dataset): + self._processed_train_dataset = processed_train_dataset + + @processed_val_dataset.setter + def processed_val_dataset(self, processed_val_dataset): + self._processed_val_dataset = processed_val_dataset + + @processed_val_dataset.setter + def processed_val_dataset(self, processed_val_dataset): + self._processed_val_dataset = processed_val_dataset + + @processed_test_dataset.setter + def processed_test_dataset(self, processed_test_dataset): + self._processed_test_dataset = processed_test_dataset + + def train_dataloader(self): + return DataLoader(self.processed_train_dataset) + + def val_dataloader(self): + return DataLoader(self.processed_val_dataset) + + def test_dataloader(self): + return DataLoader(self.processed_test_dataset) - num_folds: int - num_epochs: int = 10 - best_model_paths: List[str] = field(default_factory=lambda: []) - restarting: bool = False + def predict_dataloader(self): + 
return DataLoader(self.processed_predict_dataset) + + +class BoringDataModule(BaseDataModule): + def prepare_data(self): + self.random_full = RandomDataset(32, 64 * 4) + + def setup(self, stage: Optional[str] = None): + if stage == "fit" or stage is None: + self._train_dataset = Subset(self.random_full, indices=range(64)) + self.dims = self._train_dataset[0].shape + + if stage in ("fit", "validate") or stage is None: + self._val_dataset = Subset(self.random_full, indices=range(64, 64 * 2)) + + if stage == "test" or stage is None: + self._test_dataset = Subset(self.random_full, indices=range(64 * 2, 64 * 3)) + self.dims = getattr(self, "dims", self._test_dataset[0].shape) + + if stage == "predict" or stage is None: + self._predict_dataset = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) + self.dims = getattr(self, "dims", self._predict_dataset[0].shape) + + +class KFoldLoop(ExternalLoop): + def __init__( + self, + num_folds: int, + num_epochs: int = 10, + best_model_paths: List[str] = [], + restarting: bool = False, + ): + self.num_folds = num_folds + self.num_epochs = num_epochs + self.best_model_paths = best_model_paths + self.restarting = restarting @staticmethod def loop_base_callback() -> Type[Callback]: @@ -109,29 +162,30 @@ def reset(self) -> None: self.current_fold = 0 self.set_max_epochs(self.num_epochs) - def generate_fold(self, dataloader_kwargs: Dict[str, Any], stage: str): - dataset = dataloader_kwargs["dataset"] - kfold = KFold(self.num_folds, random_state=42, shuffle=True) - train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold] - if stage == "train": - dataloader_kwargs["dataset"] = SplitDataset(dataset, train_indices.tolist()) - else: - dataloader_kwargs["dataset"] = SplitDataset(dataset, validation_indices.tolist()) - dataloader_kwargs["sampler"].data_source = dataloader_kwargs["dataset"] - return dataloader_kwargs - def on_run_start(self, *args: Any, **kwargs: Any) -> None: # temporary hack self.trainer.datamodule.setup("fit") + def process_dataset(self, stage: str, dataset: Dataset): + kfold = KFold(self.num_folds, random_state=42, shuffle=True) + train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold] + indices = train_indices if stage == "train" else validation_indices + return Subset(dataset, indices.tolist()) + def on_advance_start(self): - self.reload_train_dataloader(self.generate_fold) - self.reload_val_dataloaders(self.generate_fold) + self.trainer.datamodule.processed_train_dataset = self.process_dataset( + "train", self.trainer.datamodule.train_dataset + ) + self.trainer.datamodule.processed_val_dataset = self.process_dataset("val", self.trainer.datamodule.val_dataset) self.trainer.call_hook("on_fold_start", self.current_fold) - self.lightning_module.reset_parameters() + self.trainer.lightning_module.reset_parameters() def advance(self): - return self.trainer.fit(self.lightning_module, train_dataloader=self.train_dataloader) + return self.trainer.fit( + self.trainer.lightning_module, + train_dataloader=self.trainer.train_dataloader, + val_dataloaders=self.trainer.val_dataloaders, + ) def on_advance_end(self) -> None: self.current_fold += 1 @@ -159,4 +213,4 @@ def on_fold_start(self, trainer, pl_module, counter): model = BoringModel() datamodule = BoringDataModule() trainer = Trainer(callbacks=KFoldCallback()) -trainer.run_loop(model, datamodule=datamodule, loop=loop) +trainer.run_loop(model, datamodule=datamodule, external_loop=loop) diff --git 
a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 2c85b0853b7ab..a9e21bb3c205b 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -13,11 +13,9 @@ # limitations under the License. from abc import ABC, abstractmethod -from functools import partial -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, Optional from deprecate import void -from torch.utils.data.dataloader import DataLoader from torchmetrics import Metric import pytorch_lightning as pl @@ -25,6 +23,9 @@ from pytorch_lightning.trainer.progress import BaseProgress, Progress from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() class Loop(ABC): @@ -245,25 +246,9 @@ def _load_from_state_dict( class ExternalLoop(Loop): """This Loop is meant wrap trainer calls""" - @property - def trainer(self) -> Optional["pl.Trainer"]: - return self._trainer - - @trainer.setter - def trainer(self, trainer: "pl.Trainer"): - """Connects this loop's trainer and its children""" - if not isinstance(trainer, pl.Trainer): - raise MisconfigurationException( - f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." - ) - if hasattr(self, "_trainer") and isinstance(self._trainer, pl.Trainer): - raise MisconfigurationException( - f"Loop {self.__class__.__name__} should be attached to only 1 `Trainer` instance." - ) - self._trainer = trainer - for v in self.__dict__.values(): - if isinstance(v, Loop): - v.trainer = trainer + def __init__(self): + super().__init__() + warning_cache.warn("The ExternalLoop API is a `pre-alpha release` and breaking API changes are expected.") def set_max_epochs(self, max_epochs: int): self.trainer.fit_loop.max_epochs = max_epochs @@ -276,34 +261,3 @@ def set_max_steps(self, max_steps: int): def increment_max_steps(self, max_steps: int): self.trainer.fit_loop.max_steps += max_steps - - def reload_train_dataloader(self, user_function: Optional[Callable] = None) -> DataLoader: - self.trainer.train_dataloader = None - self.trainer.reset_train_dataloader(self.trainer.lightning_module) - if user_function: - user_function = partial(user_function, stage="train") - loaders = self.trainer.train_dataloader.loaders - loaders = loaders if isinstance(loaders, DataLoader) else loaders.loaders - self.trainer.train_dataloader.loaders = self.trainer.apply_user_function(loaders, user_function) - return self.trainer.train_dataloader - - def reload_val_dataloaders(self, user_function: Optional[Callable] = None) -> List[DataLoader]: - self.trainer.reset_val_dataloader(self.trainer.lightning_module) - if user_function: - user_function = partial(user_function, stage="val") - self.trainer.val_dataloaders = [ - self.trainer.apply_user_function(dl, user_function) for dl in self.trainer.val_dataloaders - ] - return self.trainer.val_dataloaders - - @property - def lightning_module(self): - return self.trainer.lightning_module - - @property - def train_dataloader(self) -> DataLoader: - return self.trainer.train_dataloader - - @property - def val_dataloaders(self) -> List[DataLoader]: - return self.trainer.val_dataloaders diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c096f0a609378..df6b682e2da3d 100644 --- 
a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -186,7 +186,7 @@ def restore_callbacks(self) -> None: ) self.trainer.on_load_checkpoint(self._loaded_checkpoint) - def restore_loops(self) -> None: + def restore_loops(self, restore_external_loop: bool = False) -> None: """ Restores the loop progress from the pre-loaded checkpoint. Calls hooks on the loops to give it a chance to restore its state from the checkpoint. @@ -226,6 +226,11 @@ def restore_loops(self) -> None: self.trainer.test_loop.load_state_dict(state_dict["test_loop"]) self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"]) + if restore_external_loop: + external_loop = getattr(self.trainer, "external_loop", None) + if external_loop: + self.trainer.external_loop.load_state_dict(state_dict["external_loop"]) + def restore_optimizers_and_schedulers(self) -> None: """Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint.""" if ( @@ -471,9 +476,13 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: return state_dict def _get_loops_state_dict(self) -> Dict[str, Any]: - return { + state_dict = { "fit_loop": self.trainer.fit_loop.state_dict(), "validate_loop": self.trainer.validate_loop.state_dict(), "test_loop": self.trainer.test_loop.state_dict(), "predict_loop": self.trainer.predict_loop.state_dict(), } + external_loop = getattr(self.trainer, "external_loop", None) + if external_loop: + state_dict.update({"external_loop": external_loop.state_dict()}) + return state_dict diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9c4d40a80390a..26d71258002c3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,7 +29,7 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop -from pytorch_lightning.loops.base import ExternalLoop, Loop +from pytorch_lightning.loops.base import ExternalLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop @@ -844,7 +844,7 @@ def run_loop( test_dataloaders: Optional[EVAL_DATALOADERS] = None, predict_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, - loop: Union[ExternalLoop, Loop] = None, + external_loop: ExternalLoop = None, ): # -------------------- @@ -864,9 +864,9 @@ def run_loop( if train_dataloader is not None and datamodule: raise MisconfigurationException("You cannot pass both `trainer.run_loop(dataloaders=..., datamodule=...)`") - if loop is None or not isinstance(loop, Loop): + if external_loop is None or not isinstance(external_loop, ExternalLoop): raise MisconfigurationException( - "You should provide an `ExternalLoop` or `Loop` object as `trainer.run_loop(loop=...)`" + "You should provide an `ExternalLoop` object as `trainer.run_loop(loop=...)`" ) model = model or self.lightning_module @@ -886,16 +886,19 @@ def run_loop( ) # connect loop and trainer - loop.trainer = self - self.loop = loop + external_loop.trainer = self + self.external_loop = external_loop # attach model to the training type plugin self.accelerator.connect(model) self.data_connector.prepare_data() - results = loop.run() + 
self.checkpoint_connector.resume_start() + self.checkpoint_connector.restore_loops(restore_external_loop=True) - del self.loop + results = external_loop.run() + + del self.external_loop return results def _restore_modules_and_callbacks(self) -> None: diff --git a/pytorch_lightning/utilities/boring_model.py b/pytorch_lightning/utilities/boring_model.py index 32cc6dc665da2..6a77d7b0ba4a5 100644 --- a/pytorch_lightning/utilities/boring_model.py +++ b/pytorch_lightning/utilities/boring_model.py @@ -176,17 +176,33 @@ def setup(self, stage: Optional[str] = None): self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) self.dims = getattr(self, "dims", self.random_predict[0].shape) + @property + def train_dataset(self): + return self.random_train + + @property + def val_dataset(self): + return self.random_val + + @property + def test_dataset(self): + return self.random_test + + @property + def predict_dataset(self): + return self.random_predict + def train_dataloader(self): - return DataLoader(self.random_train) + return DataLoader(self.train_dataset) def val_dataloader(self): - return DataLoader(self.random_val) + return DataLoader(self.val_dataset) def test_dataloader(self): - return DataLoader(self.random_test) + return DataLoader(self.test_dataset) def predict_dataloader(self): - return DataLoader(self.random_predict) + return DataLoader(self.predict_dataset) class BoringLightningDataModule(LightningDataModule): diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index d79ced7ea73a1..ff6f36800106f 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from contextlib import suppress from copy import deepcopy from dataclasses import dataclass from typing import Any, Dict, Iterator @@ -22,10 +23,12 @@ import torch from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.loops import Loop, TrainingBatchLoop from pytorch_lightning.loops.base import ExternalLoop from pytorch_lightning.trainer.progress import BaseProgress +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -514,15 +517,22 @@ def configure_optimizers_multiple(self): assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7") def test_external_loop(tmpdir): - @dataclass - class CustomLoop(ExternalLoop): + class TestException(Exception): + pass - counter: int = 0 - stop_counter: int = 5 + class CustomLoop(ExternalLoop): + def __init__(self): + super().__init__() + self.counter: int = 0 + self.stop_counter: int = 5 + self.has_restarted: bool = False def reset(self): - pass + if self.restarting: + self.has_restarted = True @property def done(self): @@ -531,6 +541,17 @@ def done(self): def advance(self, *args: Any, **kwargs: Any) -> None: self.trainer.call_hook("custom_hook") self.counter += 1 + self.reload_train_dataloader() + self.trainer.fit(self.lightning_module, train_dataloader=self.train_dataloader) + + if self.counter == 3: + raise TestException + + def on_save_checkpoint(self) -> Dict: + return {"counter": self.counter} + + def on_load_checkpoint(self, state_dict: 
Dict) -> None: + self.counter = state_dict["counter"] class CustomCallback(Callback): @@ -542,7 +563,21 @@ def custom_hook(self, trainer, pl_module): loop = CustomLoop() model = BoringModel() cb = CustomCallback() - trainer = Trainer(default_root_dir=tmpdir, callbacks=cb) - trainer.run_loop(model, loop=loop) + cb2 = ModelCheckpoint(dirpath=tmpdir, save_last=True) + trainer_kwargs = dict( + default_root_dir=tmpdir, max_epochs=1, limit_train_batches=5, limit_val_batches=0, callbacks=[cb, cb2] + ) + trainer = Trainer(**trainer_kwargs) + + with suppress(TestException): + trainer.run_loop(model, external_loop=loop) + + loop = CustomLoop() + model = BoringModel() + cb = CustomCallback() + trainer_kwargs["resume_from_checkpoint"] = cb2.last_model_path + trainer = Trainer(**trainer_kwargs) + trainer.run_loop(model, external_loop=loop) cb.has_called = True loop.counter = 5 + assert loop.has_restarted From de5da36183373c4daa20da3610eee2585e1374f0 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 5 Aug 2021 12:04:16 +0200 Subject: [PATCH 13/17] update tests --- pl_examples/loops_customisation/k_fold.py | 3 +++ tests/loops/test_loops.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index ccb5ad8585744..4548a9744e45e 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -192,6 +192,9 @@ def on_advance_end(self) -> None: self.increment_max_epochs(self.num_epochs) # stored best weight path for this fold self.best_model_paths.append(self.trainer.checkpoint_callback.best_model_path) + # bug: Should be reset + self.trainer.train_dataloader = None + self.trainer.val_dataloaders = None def on_save_checkpoint(self) -> Dict: return {"current_fold": self.current_fold} diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index ff6f36800106f..eeb859cd80662 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -541,8 +541,7 @@ def done(self): def advance(self, *args: Any, **kwargs: Any) -> None: self.trainer.call_hook("custom_hook") self.counter += 1 - self.reload_train_dataloader() - self.trainer.fit(self.lightning_module, train_dataloader=self.train_dataloader) + self.trainer.fit(self.trainer.lightning_module, train_dataloader=self.trainer.train_dataloader) if self.counter == 3: raise TestException From d86e7af2780309bdc3f86755b1292b4f6658b6d1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 5 Aug 2021 12:13:03 +0200 Subject: [PATCH 14/17] add some typing --- pl_examples/loops_customisation/k_fold.py | 44 +++++++++++------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index 4548a9744e45e..d8267efb3b0c5 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -50,71 +50,71 @@ def __init__(self): self._processed_predict_dataset: Optional[Dataset] = None @property - def train_dataset(self): + def train_dataset(self) -> Optional[Dataset]: return self._train_dataset @property - def val_dataset(self): + def val_dataset(self) -> Optional[Dataset]: return self._val_dataset @property - def test_dataset(self): + def test_dataset(self) -> Optional[Dataset]: return self._test_dataset @property - def predict_dataset(self): + def predict_dataset(self) -> Optional[Dataset]: return self._predict_dataset @property - def processed_train_dataset(self): + def 
processed_train_dataset(self) -> Optional[Dataset]: return self._processed_train_dataset or self.train_dataset @property - def processed_val_dataset(self): + def processed_val_dataset(self) -> Optional[Dataset]: return self._processed_val_dataset or self.val_dataset @property - def processed_test_dataset(self): + def processed_test_dataset(self) -> Optional[Dataset]: return self._processed_test_dataset or self.test_dataset @property - def processed_predict_dataset(self): + def processed_predict_dataset(self) -> Optional[Dataset]: return self._processed_predict_dataset or self.predict_dataset @processed_train_dataset.setter - def processed_train_dataset(self, processed_train_dataset): + def processed_train_dataset(self, processed_train_dataset) -> None: self._processed_train_dataset = processed_train_dataset @processed_val_dataset.setter - def processed_val_dataset(self, processed_val_dataset): + def processed_val_dataset(self, processed_val_dataset) -> None: self._processed_val_dataset = processed_val_dataset @processed_val_dataset.setter - def processed_val_dataset(self, processed_val_dataset): + def processed_val_dataset(self, processed_val_dataset) -> None: self._processed_val_dataset = processed_val_dataset @processed_test_dataset.setter - def processed_test_dataset(self, processed_test_dataset): + def processed_test_dataset(self, processed_test_dataset) -> None: self._processed_test_dataset = processed_test_dataset - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: return DataLoader(self.processed_train_dataset) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: return DataLoader(self.processed_val_dataset) - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: return DataLoader(self.processed_test_dataset) - def predict_dataloader(self): + def predict_dataloader(self) -> DataLoader: return DataLoader(self.processed_predict_dataset) class BoringDataModule(BaseDataModule): - def prepare_data(self): + def prepare_data(self) -> None: self.random_full = RandomDataset(32, 64 * 4) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: Optional[str] = None) -> None: if stage == "fit" or stage is None: self._train_dataset = Subset(self.random_full, indices=range(64)) self.dims = self._train_dataset[0].shape @@ -166,13 +166,13 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: # temporary hack self.trainer.datamodule.setup("fit") - def process_dataset(self, stage: str, dataset: Dataset): + def process_dataset(self, stage: str, dataset: Dataset) -> Subset: kfold = KFold(self.num_folds, random_state=42, shuffle=True) train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold] indices = train_indices if stage == "train" else validation_indices return Subset(dataset, indices.tolist()) - def on_advance_start(self): + def on_advance_start(self) -> None: self.trainer.datamodule.processed_train_dataset = self.process_dataset( "train", self.trainer.datamodule.train_dataset ) @@ -180,7 +180,7 @@ def on_advance_start(self): self.trainer.call_hook("on_fold_start", self.current_fold) self.trainer.lightning_module.reset_parameters() - def advance(self): + def advance(self) -> Any: return self.trainer.fit( self.trainer.lightning_module, train_dataloader=self.trainer.train_dataloader, @@ -208,7 +208,7 @@ class KFoldCallback(KFoldLoop.loop_base_callback()): """This callback demonstrates how to implement your own callback API.""" @rank_zero_only - def on_fold_start(self, trainer, pl_module, 
counter): + def on_fold_start(self, trainer, pl_module, counter) -> None: log.info(f"Starting to train on fold {counter}") From 4b42c07718c5cc693a7db5936a7947c8bdef150f Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 5 Aug 2021 12:15:04 +0200 Subject: [PATCH 15/17] update --- pytorch_lightning/trainer/data_loading.py | 61 ----------------------- 1 file changed, 61 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 29550bd9cc2cd..361d64569505d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -252,67 +252,6 @@ def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[Runnin dataloader = dl_cls(**dl_kwargs) return dataloader - def apply_user_function(self, dataloader: DataLoader, user_function: Callable) -> DataLoader: - if not isinstance(dataloader, DataLoader): - raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") - - # get the dataloader instance attributes - attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")} - # not part of `vars` - attrs["multiprocessing_context"] = dataloader.multiprocessing_context - - # get the dataloader instance `__init__` parameters - params = dict(inspect.signature(dataloader.__init__).parameters) - - # keep only the params whose default is different to the current attr value - non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]} - # add `dataset` as it might have been replaced with `*args` - non_defaults.add("dataset") - - # kwargs to re-construct the dataloader - dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults} - dl_kwargs.update(self._resolve_batch_sampler(dataloader, dataloader.sampler, mode=RunningStage.translate)) - - required_args = { - p.name - for p in params.values() - if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) - and p.default is p.empty - and p.name not in dl_kwargs - } - # the dataloader has required args which we could not extract from the existing attributes - if required_args: - required_args = sorted(required_args) - dataloader_cls_name = dataloader.__class__.__name__ - raise MisconfigurationException( - f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. " - "This would fail as some of the `__init__` arguments are not available as instance attributes. " - f"The missing attributes are {required_args}. " - f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or " - "manually add the `DistributedSampler` as: " - f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`." - ) - - has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values()) - if not has_variadic_kwargs: - # the dataloader signature does not allow keyword arguments that need to be passed - missing_kwargs = dl_kwargs.keys() - params.keys() - if missing_kwargs: - missing_kwargs = sorted(missing_kwargs) - dataloader_cls_name = dataloader.__class__.__name__ - raise MisconfigurationException( - f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. " - "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. 
" - f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or " - "manually add the `DistributedSampler` as: " - f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`." - ) - - dl_cls = type(dataloader) - dataloader = dl_cls(**user_function(dl_kwargs)) - return dataloader - def _get_distributed_sampler( self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None ) -> DistributedSampler: From f853b60bee67800e5fd5acf78c7f0fcca8add709 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 9 Aug 2021 13:02:05 +0200 Subject: [PATCH 16/17] update on comments --- pl_examples/loops_customisation/k_fold.py | 44 +++++---- pytorch_lightning/loops/base.py | 20 ----- pytorch_lightning/loops/external_loop.py | 104 ++++++++++++++++++++++ pytorch_lightning/trainer/trainer.py | 66 -------------- tests/loops/test_loops.py | 10 +-- 5 files changed, 130 insertions(+), 114 deletions(-) create mode 100644 pytorch_lightning/loops/external_loop.py diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index d8267efb3b0c5..df6e12907a04b 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -19,14 +19,15 @@ from typing import Any, Dict, List, Optional, Type +import numpy as np from sklearn.model_selection import KFold from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset, Subset from pytorch_lightning import _logger as log -from pytorch_lightning import LightningDataModule, seed_everything, Trainer +from pytorch_lightning import LightningDataModule, seed_everything from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.loops.base import ExternalLoop +from pytorch_lightning.loops.external_loop import ExternalLoop from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.boring_model import BoringModel, RandomDataset @@ -135,12 +136,11 @@ class KFoldLoop(ExternalLoop): def __init__( self, num_folds: int, - num_epochs: int = 10, best_model_paths: List[str] = [], restarting: bool = False, ): + super().__init__() self.num_folds = num_folds - self.num_epochs = num_epochs self.best_model_paths = best_model_paths self.restarting = restarting @@ -160,42 +160,40 @@ def done(self) -> bool: def reset(self) -> None: if not self.restarting: self.current_fold = 0 - self.set_max_epochs(self.num_epochs) def on_run_start(self, *args: Any, **kwargs: Any) -> None: # temporary hack self.trainer.datamodule.setup("fit") - def process_dataset(self, stage: str, dataset: Dataset) -> Subset: - kfold = KFold(self.num_folds, random_state=42, shuffle=True) - train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold] - indices = train_indices if stage == "train" else validation_indices - return Subset(dataset, indices.tolist()) - def on_advance_start(self) -> None: - self.trainer.datamodule.processed_train_dataset = self.process_dataset( - "train", self.trainer.datamodule.train_dataset - ) - self.trainer.datamodule.processed_val_dataset = self.process_dataset("val", self.trainer.datamodule.val_dataset) + # re-create a new trainer + self.create_trainer(max_epochs=np.random.randint(10)) + dm = self.trainer.datamodule + + dm.processed_train_dataset = self.process_dataset("train", dm.train_dataset) + dm.processed_val_dataset = self.process_dataset("val", dm.val_dataset) self.trainer.call_hook("on_fold_start", self.current_fold) 
self.trainer.lightning_module.reset_parameters() def advance(self) -> Any: - return self.trainer.fit( - self.trainer.lightning_module, - train_dataloader=self.trainer.train_dataloader, - val_dataloaders=self.trainer.val_dataloaders, - ) + # dataloaders will be automatically reloaded + return self.trainer.fit(self.trainer.lightning_module, datamodule=self.trainer.datamodule) def on_advance_end(self) -> None: self.current_fold += 1 - self.increment_max_epochs(self.num_epochs) # stored best weight path for this fold self.best_model_paths.append(self.trainer.checkpoint_callback.best_model_path) # bug: Should be reset self.trainer.train_dataloader = None self.trainer.val_dataloaders = None + # utilities for creating a hold + def process_dataset(self, stage: str, dataset: Dataset) -> Subset: + kfold = KFold(self.num_folds, random_state=42, shuffle=True) + train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold] + indices = train_indices if stage == "train" else validation_indices + return Subset(dataset, indices.tolist()) + def on_save_checkpoint(self) -> Dict: return {"current_fold": self.current_fold} @@ -215,5 +213,5 @@ def on_fold_start(self, trainer, pl_module, counter) -> None: loop = KFoldLoop(5) model = BoringModel() datamodule = BoringDataModule() -trainer = Trainer(callbacks=KFoldCallback()) -trainer.run_loop(model, datamodule=datamodule, external_loop=loop) +loop.connect_trainer(max_epochs=10, callbacks=KFoldCallback()) +loop.run(model, datamodule=datamodule) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index a9e21bb3c205b..d4e29bce9a38f 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -241,23 +241,3 @@ def _load_from_state_dict( self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True - - -class ExternalLoop(Loop): - """This Loop is meant wrap trainer calls""" - - def __init__(self): - super().__init__() - warning_cache.warn("The ExternalLoop API is a `pre-alpha release` and breaking API changes are expected.") - - def set_max_epochs(self, max_epochs: int): - self.trainer.fit_loop.max_epochs = max_epochs - - def increment_max_epochs(self, max_epochs: int): - self.trainer.fit_loop.max_epochs += max_epochs - - def set_max_steps(self, max_steps: int): - self.trainer.fit_loop.max_steps = max_steps - - def increment_max_steps(self, max_steps: int): - self.trainer.fit_loop.max_steps += max_steps diff --git a/pytorch_lightning/loops/external_loop.py b/pytorch_lightning/loops/external_loop.py new file mode 100644 index 0000000000000..28782a463a680 --- /dev/null +++ b/pytorch_lightning/loops/external_loop.py @@ -0,0 +1,104 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import functools +from typing import Any, Callable, Dict, Optional + +import pytorch_lightning as pl +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() + + +class ExternalLoop(Loop): + """This Loop is meant wrap trainer calls""" + + def __init__(self): + super().__init__() + warning_cache.warn("The ExternalLoop API is a `pre-alpha release` and breaking API changes are expected.") + self.create_trainer = self._wrap_trainer_wrapper(self.create_trainer) + self._has_setup = False + self._restore_external_loop = True + + def _wrap_trainer_wrapper(self, create_trainer: Callable) -> Callable: + @functools.wraps(create_trainer) + def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]: + trainer = create_trainer(*args, trainer_kwargs=self.trainer_kwargs, **kwargs) + if not isinstance(trainer, pl.Trainer): + raise MisconfigurationException("The `create_trainer` hook should return a Trainer") + self.trainer = trainer + self.trainer.external_loop = self + + self.trainer.accelerator.connect(self.__lightning_module) + + # links data to the trainer + self.trainer.data_connector.attach_data( + self.trainer.lightning_module, + train_dataloaders=self.__train_dataloader, + val_dataloaders=self.__val_dataloaders, + test_dataloaders=self.__test_dataloaders, + predict_dataloaders=self.__predict_dataloaders, + datamodule=self.__datamodule, + ) + + # attach model to the training type plugin + self.trainer.data_connector.prepare_data() + + self.trainer.checkpoint_connector.resume_start() + self.trainer.checkpoint_connector.restore_loops(restore_external_loop=self._restore_external_loop) + return trainer + + return wrapped_func + + def connect_trainer(self, **trainer_kwargs: Dict[str, Any]) -> None: + self.trainer_kwargs = trainer_kwargs + + def create_trainer(self, *args, trainer_kwargs: Dict[str, Any] = {}, **kwargs) -> "pl.Trainer": + trainer_kwargs.update(kwargs) + return pl.Trainer(*args, **trainer_kwargs) + + def run( + self, + model: "pl.LightningModule", + train_dataloader=None, + val_dataloaders=None, + test_dataloaders=None, + predict_dataloaders=None, + datamodule=None, + ): + + self.__lightning_module = model + self.__train_dataloader = train_dataloader + self.__val_dataloaders = val_dataloaders + self.__test_dataloaders = test_dataloaders + self.__predict_dataloaders = predict_dataloaders + self.__datamodule = datamodule + + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, pl.LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + + if train_dataloader is not None and datamodule: + raise MisconfigurationException("You cannot pass both `loop.run(dataloaders=..., datamodule=...)`") + + if model is None: + raise MisconfigurationException("`model` must be provided to `loop.run()`") + + if self._trainer is None: + self.create_trainer() + self._restore_external_loop = False + + return super().run() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 26d71258002c3..a6d9f348abc63 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,7 +29,6 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop -from 
pytorch_lightning.loops.base import ExternalLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop from pytorch_lightning.loops.fit_loop import FitLoop @@ -836,71 +835,6 @@ def tune( return result - def run_loop( - self, - model: "pl.LightningModule", - train_dataloader: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - test_dataloaders: Optional[EVAL_DATALOADERS] = None, - predict_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - external_loop: ExternalLoop = None, - ): - - # -------------------- - # SETUP HOOK - # -------------------- - # FIXME: hack to not break - self.state.fn = TrainerFn.FITTING - self.state.status = TrainerStatus.RUNNING - self.training = True - - Trainer._log_api_event("run_loop") - - # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None - if train_dataloader is not None and datamodule: - raise MisconfigurationException("You cannot pass both `trainer.run_loop(dataloaders=..., datamodule=...)`") - - if external_loop is None or not isinstance(external_loop, ExternalLoop): - raise MisconfigurationException( - "You should provide an `ExternalLoop` object as `trainer.run_loop(loop=...)`" - ) - - model = model or self.lightning_module - if model is None: - raise MisconfigurationException( - "`model` must be provided to `trainer.predict()` when it hasn't been passed in a previous run" - ) - - # links data to the trainer - self.data_connector.attach_data( - model, - train_dataloaders=train_dataloader, - val_dataloaders=val_dataloaders, - test_dataloaders=test_dataloaders, - predict_dataloaders=predict_dataloaders, - datamodule=datamodule, - ) - - # connect loop and trainer - external_loop.trainer = self - self.external_loop = external_loop - - # attach model to the training type plugin - self.accelerator.connect(model) - self.data_connector.prepare_data() - - self.checkpoint_connector.resume_start() - self.checkpoint_connector.restore_loops(restore_external_loop=True) - - results = external_loop.run() - - del self.external_loop - return results - def _restore_modules_and_callbacks(self) -> None: # restore modules after setup if self.state.fn == TrainerFn.FITTING: diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index eeb859cd80662..c82a9d604437d 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -26,7 +26,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.loops import Loop, TrainingBatchLoop -from pytorch_lightning.loops.base import ExternalLoop +from pytorch_lightning.loops.external_loop import ExternalLoop from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 from tests.helpers import BoringModel @@ -566,17 +566,17 @@ def custom_hook(self, trainer, pl_module): trainer_kwargs = dict( default_root_dir=tmpdir, max_epochs=1, limit_train_batches=5, limit_val_batches=0, callbacks=[cb, cb2] ) - trainer = Trainer(**trainer_kwargs) + loop.connect_trainer(**trainer_kwargs) with suppress(TestException): - trainer.run_loop(model, external_loop=loop) + loop.run(model) loop = CustomLoop() model = BoringModel() cb = 
CustomCallback() trainer_kwargs["resume_from_checkpoint"] = cb2.last_model_path - trainer = Trainer(**trainer_kwargs) - trainer.run_loop(model, external_loop=loop) + loop.connect_trainer(**trainer_kwargs) + loop.run(model) cb.has_called = True loop.counter = 5 assert loop.has_restarted From 8d667f121c5889aaa9bfa06c9e801840c03f1419 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 9 Aug 2021 13:06:35 +0200 Subject: [PATCH 17/17] update --- pl_examples/loops_customisation/k_fold.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pl_examples/loops_customisation/k_fold.py b/pl_examples/loops_customisation/k_fold.py index df6e12907a04b..52e9439193454 100644 --- a/pl_examples/loops_customisation/k_fold.py +++ b/pl_examples/loops_customisation/k_fold.py @@ -166,13 +166,15 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: self.trainer.datamodule.setup("fit") def on_advance_start(self) -> None: - # re-create a new trainer + # more reproducible as re-creating a different trainer. self.create_trainer(max_epochs=np.random.randint(10)) + # reload dataset for the current fold dm = self.trainer.datamodule - dm.processed_train_dataset = self.process_dataset("train", dm.train_dataset) dm.processed_val_dataset = self.process_dataset("val", dm.val_dataset) + # call user hook self.trainer.call_hook("on_fold_start", self.current_fold) + # reset model parameters self.trainer.lightning_module.reset_parameters() def advance(self) -> Any: @@ -183,9 +185,6 @@ def on_advance_end(self) -> None: self.current_fold += 1 # stored best weight path for this fold self.best_model_paths.append(self.trainer.checkpoint_callback.best_model_path) - # bug: Should be reset - self.trainer.train_dataloader = None - self.trainer.val_dataloaders = None # utilities for creating a hold def process_dataset(self, stage: str, dataset: Dataset) -> Subset:
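
For orientation, the usage pattern introduced by PATCH 16/17 (subclass ExternalLoop, pass Trainer arguments through connect_trainer(), then drive everything with run()) can be sketched as follows. This is a minimal illustrative sketch that mirrors the CustomLoop from tests/loops/test_loops.py and the k_fold.py example; the class name SimpleCounterLoop and the specific Trainer arguments are made up here, and the ExternalLoop API is pre-alpha, so details are expected to change.

from typing import Any, Dict

from torch.utils.data import DataLoader

from pytorch_lightning.loops.external_loop import ExternalLoop
from pytorch_lightning.utilities.boring_model import BoringModel, RandomDataset


class SimpleCounterLoop(ExternalLoop):
    # Illustrative only: calls ``trainer.fit`` a fixed number of times.

    def __init__(self, stop_counter: int = 3) -> None:
        super().__init__()
        self.counter = 0
        self.stop_counter = stop_counter

    @property
    def done(self) -> bool:
        return self.counter >= self.stop_counter

    def reset(self) -> None:
        # keep the counter when restoring from a checkpoint
        if not self.restarting:
            self.counter = 0

    def advance(self, *args: Any, **kwargs: Any) -> None:
        # same fit call pattern as the CustomLoop test above
        self.trainer.fit(self.trainer.lightning_module, train_dataloader=self.trainer.train_dataloader)
        self.counter += 1

    def on_save_checkpoint(self) -> Dict:
        return {"counter": self.counter}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        self.counter = state_dict["counter"]


loop = SimpleCounterLoop()
loop.connect_trainer(max_epochs=1, limit_train_batches=5, limit_val_batches=0)
loop.run(BoringModel(), train_dataloader=DataLoader(RandomDataset(32, 64)))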