diff --git a/CHANGELOG.md b/CHANGELOG.md
index efa420bebcd8c..5662901d1cb81 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -213,6 +213,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
     * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
     * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
+    * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
+

 ### Changed
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index f23e01c5fdd10..28145db13be00 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -314,16 +314,25 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
         return closure_loss

-    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
+    def optimizer_step(
+        self,
+        optimizer: Optimizer,
+        opt_idx: int,
+        lambda_closure: Callable[[], Any],
+        model: Optional[Union["pl.LightningModule", Module]] = None,
+        **kwargs: Any
+    ) -> None:
         """performs the actual optimizer step.

         Args:
             optimizer: the optimizer performing the step
             opt_idx: index of the current optimizer
             lambda_closure: closure calculating the loss value
+            model: reference to the model, optionally defining optimizer step related hooks
         """
+        model = model or self.lightning_module
         make_optimizer_step = self.precision_plugin.pre_optimizer_step(
-            self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
+            model, optimizer, opt_idx, lambda_closure, **kwargs
         )
         if make_optimizer_step:
             self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index 496dd47933e82..0d6bc02f2ba71 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Union

 import torch
 from torch import Tensor
+from torch.nn import Module
 from torch.optim import LBFGS, Optimizer

 import pytorch_lightning as pl
@@ -97,7 +98,7 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -112,7 +113,7 @@ def pre_optimizer_step(
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
-        if not model.automatic_optimization or not skipped_backward:
+        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
             # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
             optimizer.step(**kwargs)
         return False
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index ff33ccc690ef2..bd92607fd3b17 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -48,7 +48,7 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -63,12 +63,12 @@ def pre_optimizer_step(
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
             raise MisconfigurationException(
                 "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
             )
         # DeepSpeed handles the optimizer step internally
-        deepspeed_engine = model.trainer.model
+        deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
         deepspeed_engine.step()
         return False
diff --git a/pytorch_lightning/plugins/precision/ipu_precision.py b/pytorch_lightning/plugins/precision/ipu_precision.py
index f8dd77dcefcbe..092ba56ad44f6 100644
--- a/pytorch_lightning/plugins/precision/ipu_precision.py
+++ b/pytorch_lightning/plugins/precision/ipu_precision.py
@@ -40,7 +40,7 @@ def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> No

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable[[], Any],
@@ -55,7 +55,7 @@ def pre_optimizer_step(
         closure_result = lambda_closure()
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
             # we lack coverage here and IPUs are (currently) limited - something to explore if there's demand
             raise MisconfigurationException(
                 "Skipping backward by returning `None` from your `training_step` is not implemented for IPUs."
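
Across the Apex, DeepSpeed and IPU hunks above, the `automatic_optimization` check is now guarded by `isinstance(model, pl.LightningModule)`, so a plain `torch.nn.Module` (for example a wrapped or sharded model handed in through the new optional `model` argument of `Accelerator.optimizer_step`) can take the same code path without touching LightningModule-only attributes. Below is a standalone sketch of that guard pattern, not code from this PR; the function name and the toy model are illustrative only.

```python
# Illustrative sketch of the isinstance guard -- not Lightning's implementation.
from typing import Any, Callable, Union

import torch
from torch import nn
from torch.optim import Optimizer

import pytorch_lightning as pl


def pre_optimizer_step_sketch(
    model: Union[pl.LightningModule, nn.Module],
    optimizer: Optimizer,
    lambda_closure: Callable[[], Any],
) -> bool:
    result = lambda_closure()
    skipped_backward = result is None  # manual-optimization closures may return None
    # `or` short-circuits, so `automatic_optimization` is only read for LightningModules
    if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
        optimizer.step()
        return False
    return True


# A raw nn.Module now works where previously only a LightningModule was accepted:
net = nn.Linear(2, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.1)


def closure() -> float:
    opt.zero_grad()
    loss = net(torch.randn(4, 2)).pow(2).mean()
    loss.backward()
    return loss.item()


pre_optimizer_step_sketch(net, opt, closure)
```
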
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 646c38f763fc4..83f639becde58 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -66,7 +66,7 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -84,7 +84,7 @@ def pre_optimizer_step(
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
-        if not model.automatic_optimization or not skipped_backward:
+        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
             # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
             self.scaler.step(optimizer)
             self.scaler.update()
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 9ec127886396c..dc378e9cb195c 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -99,14 +99,15 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
-        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        if isinstance(model, pl.LightningModule):
+            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         return True

     def clip_gradients(
diff --git a/pytorch_lightning/plugins/precision/tpu.py b/pytorch_lightning/plugins/precision/tpu.py
index dc4c7c856cbc2..b6bed35f5944e 100644
--- a/pytorch_lightning/plugins/precision/tpu.py
+++ b/pytorch_lightning/plugins/precision/tpu.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable
+from typing import Any, Callable, Union

+from torch.nn import Module
 from torch.optim import Optimizer

 import pytorch_lightning as pl
@@ -27,7 +28,7 @@ class TPUPrecisionPlugin(PrecisionPlugin):

     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable[[], Any],
@@ -37,7 +38,7 @@ def pre_optimizer_step(
         closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": lambda_closure, **kwargs})
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
             # we lack coverage here so disable this - something to explore if there's demand
             raise MisconfigurationException(
                 "Skipping backward by returning `None` from your `training_step` is not implemented for TPUs."
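
In the `native_amp.py` hunk above, a model that passes the guard goes straight to `self.scaler.step(optimizer)` and `self.scaler.update()`. For reference, the plain `torch.cuda.amp` recipe this delegates to is sketched below; it is standard PyTorch rather than Lightning code, and the toy model, data and `enabled` flag are illustrative (the flag keeps the snippet runnable on CPU-only machines).

```python
# Standard torch.cuda.amp loop: scale the loss, step through the scaler, then update it.
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
net = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        loss = net(torch.randn(16, 8, device=device)).pow(2).mean()
    scaler.scale(loss).backward()
    # the scaler skips optimizer.step() when it finds non-finite gradients
    scaler.step(optimizer)
    scaler.update()
```
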
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index ffc96d4a7ece7..e28a7a963acee 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -250,7 +250,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
     def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
         return trainer.init_optimizers(model)

-    def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
+    def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs: Any) -> None:
         optimizer.step(closure=lambda_closure, **kwargs)

     @property
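
The default `TrainingTypePlugin.optimizer_step` shown above simply forwards the closure to `Optimizer.step(closure=lambda_closure, **kwargs)`. That closure protocol is plain PyTorch and can be exercised without a `Trainer`; the sketch below uses an illustrative LBFGS setup, since LBFGS re-evaluates the closure internally and makes the protocol visible.

```python
# Plain PyTorch closure-based stepping, the call the default plugin delegates to.
import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
x, y = torch.randn(32, 4), torch.randn(32, 1)


def lambda_closure() -> torch.Tensor:
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss


# mirrors `optimizer.step(closure=lambda_closure, **kwargs)` in the plugin
optimizer.step(closure=lambda_closure)
```
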