
Commit d419028

awaelchli and carmocca authored
Update optimizer_step methods in accelerator and plugins (#10023)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 5b30a65 commit d419028

File tree

9 files changed (+32 −18 lines)


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -213,6 +213,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
 * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
 * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
+* Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
+

 ### Changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 11 additions & 2 deletions
@@ -314,16 +314,25 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
 
         return closure_loss
 
-    def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
+    def optimizer_step(
+        self,
+        optimizer: Optimizer,
+        opt_idx: int,
+        lambda_closure: Callable[[], Any],
+        model: Optional[Union["pl.LightningModule", Module]] = None,
+        **kwargs: Any
+    ) -> None:
         """performs the actual optimizer step.
 
         Args:
             optimizer: the optimizer performing the step
             opt_idx: index of the current optimizer
             lambda_closure: closure calculating the loss value
+            model: reference to the model, optionally defining optimizer step related hooks
         """
+        model = model or self.lightning_module
         make_optimizer_step = self.precision_plugin.pre_optimizer_step(
-            self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
+            model, optimizer, opt_idx, lambda_closure, **kwargs
         )
         if make_optimizer_step:
             self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
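The net effect in the accelerator: `optimizer_step` now takes an optional `model` and falls back to `self.lightning_module` when it is omitted, so existing call sites keep their behavior. A minimal usage sketch, assuming `accelerator`, `optimizer` and `closure` already exist and `wrapped_module` is a hypothetical `torch.nn.Module` wrapper (e.g. a DeepSpeed engine):

    # Unchanged call: no `model` given, the accelerator's LightningModule reaches the precision plugin.
    accelerator.optimizer_step(optimizer, opt_idx=0, lambda_closure=closure)

    # New: hand the wrapped module to the precision plugin explicitly.
    accelerator.optimizer_step(optimizer, opt_idx=0, lambda_closure=closure, model=wrapped_module)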

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 4 additions & 3 deletions
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
 
 import pytorch_lightning as pl
@@ -97,7 +98,7 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq
 
     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -112,7 +113,7 @@ def pre_optimizer_step(
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
-        if not model.automatic_optimization or not skipped_backward:
+        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
             # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
             optimizer.step(**kwargs)
         return False

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 3 additions & 3 deletions
@@ -48,7 +48,7 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any
 
     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -63,12 +63,12 @@ def pre_optimizer_step(
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
             raise MisconfigurationException(
                 "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
             )
         # DeepSpeed handles the optimizer step internally
-        deepspeed_engine = model.trainer.model
+        deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
         deepspeed_engine.step()
         return False
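In other words, the DeepSpeed plugin no longer assumes it must reach the engine through the trainer: if the engine itself is passed as `model`, it is stepped directly. A rough sketch of the two paths, where `model` stands in for whatever was handed to `pre_optimizer_step` and is used here for illustration only:

    import pytorch_lightning as pl

    if isinstance(model, pl.LightningModule):
        deepspeed_engine = model.trainer.model  # previous behavior: the engine hangs off the trainer
    else:
        deepspeed_engine = model  # new: the DeepSpeed engine was passed in directly
    deepspeed_engine.step()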

pytorch_lightning/plugins/precision/ipu_precision.py

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@ def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> No
 
     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable[[], Any],
@@ -55,7 +55,7 @@ def pre_optimizer_step(
         closure_result = lambda_closure()
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
             # we lack coverage here and IPUs are (currently) limited - something to explore if there's demand
             raise MisconfigurationException(
                 "Skipping backward by returning `None` from your `training_step` is not implemented for IPUs."

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 2 additions & 2 deletions
@@ -66,7 +66,7 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any
 
     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -84,7 +84,7 @@ def pre_optimizer_step(
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
-        if not model.automatic_optimization or not skipped_backward:
+        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
             # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
             self.scaler.step(optimizer)
             self.scaler.update()

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 3 additions & 2 deletions
@@ -99,14 +99,15 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any
 
     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
-        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        if isinstance(model, pl.LightningModule):
+            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         return True
 
     def clip_gradients(
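Taken together with the plugin changes above, the guard in the base class means a plain `torch.nn.Module` passed through `pre_optimizer_step` skips every LightningModule-only branch: no trainer hook call and no `automatic_optimization` handling. A condensed, hypothetical helper (not part of the diff) that captures the condition used by the apex, native AMP, DeepSpeed, IPU and TPU plugins:

    import pytorch_lightning as pl

    def _uses_automatic_optimization(model) -> bool:
        # Hypothetical helper for illustration: only a LightningModule can opt into the
        # automatic-optimization code paths; a plain torch.nn.Module never does.
        return isinstance(model, pl.LightningModule) and model.automatic_optimization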

pytorch_lightning/plugins/precision/tpu.py

Lines changed: 4 additions & 3 deletions
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable
+from typing import Any, Callable, Union
 
+from torch.nn import Module
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -27,7 +28,7 @@
 class TPUPrecisionPlugin(PrecisionPlugin):
     def pre_optimizer_step(
         self,
-        model: "pl.LightningModule",
+        model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable[[], Any],
@@ -37,7 +38,7 @@ def pre_optimizer_step(
         closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": lambda_closure, **kwargs})
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
-        if model.automatic_optimization and skipped_backward:
+        if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
             # we lack coverage here so disable this - something to explore if there's demand
             raise MisconfigurationException(
                 "Skipping backward by returning `None` from your `training_step` is not implemented for TPUs."

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -250,7 +250,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
     def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
         return trainer.init_optimizers(model)
 
-    def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
+    def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs: Any) -> None:
         optimizer.step(closure=lambda_closure, **kwargs)
 
     @property
