Merged
Changes from all commits
33 commits
0a2aee2
all changes
awaelchli Oct 19, 2021
b10017e
changelog
awaelchli Oct 19, 2021
8401611
remove
awaelchli Oct 19, 2021
3ba3bcf
update tpu
awaelchli Oct 19, 2021
295137e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2021
5c6a789
Merge branch 'master' into lightning-lite/refactors/optimizer-step
awaelchli Oct 19, 2021
98b27a4
Isolate optimizer step logic to the `PrecisionPlugin`
carmocca Oct 19, 2021
4391425
Update
carmocca Oct 19, 2021
cb258ab
Docs
carmocca Oct 19, 2021
6fff125
Docs
carmocca Oct 19, 2021
84ee13e
Add test
carmocca Oct 19, 2021
50e6e1b
Undo changes
carmocca Oct 19, 2021
4a5d336
Update error
carmocca Oct 20, 2021
b35210a
Merge branch 'master' into refactor/constain-step-logic
carmocca Oct 20, 2021
a9e9229
Merge branch 'master' into lightning-lite/refactors/optimizer-step
awaelchli Oct 20, 2021
877823d
revert removal of opt_idx
awaelchli Oct 20, 2021
f2801b2
remove unused arguments
awaelchli Oct 20, 2021
d38e98a
update all type hints
awaelchli Oct 20, 2021
18b09bc
mypy
awaelchli Oct 20, 2021
c15504f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2021
80a7eb5
none checks and signature fix (mypy)
awaelchli Oct 20, 2021
e363968
updates
awaelchli Oct 20, 2021
c700c12
update ipu
awaelchli Oct 20, 2021
a9d31d0
Merge branch 'refactor/constain-step-logic' into lightning-lite/refac…
awaelchli Oct 20, 2021
265189e
Merge branch 'master' into refactor/constain-step-logic
awaelchli Oct 20, 2021
86ff634
Merge branch 'refactor/constain-step-logic' into lightning-lite/refac…
awaelchli Oct 20, 2021
d0c3fa4
update carlos
awaelchli Oct 20, 2021
5b6233c
remove unused
awaelchli Oct 20, 2021
9df4887
Update pytorch_lightning/plugins/precision/ipu_precision.py
awaelchli Oct 20, 2021
8866a08
Update pytorch_lightning/accelerators/accelerator.py
awaelchli Oct 20, 2021
e602f00
Update pytorch_lightning/accelerators/accelerator.py
awaelchli Oct 20, 2021
5610612
update deepspeed logic for Lite
awaelchli Oct 20, 2021
0f02616
Merge branch 'master' into lightning-lite/refactors/optimizer-step
carmocca Oct 20, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -213,6 +213,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
* Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
* Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
* Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))


### Changed

13 changes: 11 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -314,16 +314,25 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:

return closure_loss

def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
def optimizer_step(
self,
optimizer: Optimizer,
opt_idx: int,
lambda_closure: Callable[[], Any],
model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any
) -> None:
"""performs the actual optimizer step.

Args:
optimizer: the optimizer performing the step
opt_idx: index of the current optimizer
lambda_closure: closure calculating the loss value
model: reference to the model, optionally defining optimizer step related hooks
"""
model = model or self.lightning_module
make_optimizer_step = self.precision_plugin.pre_optimizer_step(
self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs
model, optimizer, opt_idx, lambda_closure, **kwargs
)
if make_optimizer_step:
self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
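For context, a self-contained sketch of the new dispatch (toy classes, not code from this PR): `model` is now optional on `Accelerator.optimizer_step`, falls back to the LightningModule when omitted, and is forwarded to the precision plugin, which can guard LightningModule-only behaviour with an `isinstance` check. The `Toy*` names below are made up for illustration.

```python
# Toy stand-ins that mirror the dispatch in the hunk above; not pytorch_lightning code.
from typing import Any, Callable, Optional

import torch
from torch import nn
from torch.optim import SGD, Optimizer


class ToyLightningModule(nn.Module):
    """Stand-in for pl.LightningModule with just the attribute the guard cares about."""

    automatic_optimization = True


class ToyPrecisionPlugin:
    def pre_optimizer_step(
        self, model: nn.Module, optimizer: Optimizer, opt_idx: int, closure: Callable[[], Any]
    ) -> bool:
        if isinstance(model, ToyLightningModule):
            print("LightningModule path: LightningModule-only hooks would run here")
        closure()  # run forward/backward
        return True  # tell the accelerator to go ahead with the optimizer step


class ToyAccelerator:
    def __init__(self, lightning_module: ToyLightningModule) -> None:
        self.lightning_module = lightning_module
        self.precision_plugin = ToyPrecisionPlugin()

    def optimizer_step(
        self,
        optimizer: Optimizer,
        opt_idx: int,
        lambda_closure: Callable[[], Any],
        model: Optional[nn.Module] = None,
    ) -> None:
        model = model or self.lightning_module  # same fallback as in the hunk above
        if self.precision_plugin.pre_optimizer_step(model, optimizer, opt_idx, lambda_closure):
            optimizer.step()


plain_module = nn.Linear(2, 2)  # e.g. a module handed over by Lightning Lite
optimizer = SGD(plain_module.parameters(), lr=0.1)


def closure() -> torch.Tensor:
    optimizer.zero_grad()
    loss = plain_module(torch.randn(4, 2)).sum()
    loss.backward()
    return loss


accelerator = ToyAccelerator(ToyLightningModule())
accelerator.optimizer_step(optimizer, opt_idx=0, lambda_closure=closure)  # default: LightningModule path
accelerator.optimizer_step(optimizer, opt_idx=0, lambda_closure=closure, model=plain_module)  # raw-module path
```

The point of the widened type is that a caller such as Lightning Lite can drive the same optimizer-step path with a plain module, while LightningModule-only hooks are skipped automatically.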
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, Optional, Sequence, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
@@ -97,7 +98,7 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq

def pre_optimizer_step(
self,
model: "pl.LightningModule",
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
@@ -112,7 +113,7 @@ def pre_optimizer_step(
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
skipped_backward = result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
# the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
optimizer.step(**kwargs)
return False
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -48,7 +48,7 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any

def pre_optimizer_step(
self,
model: "pl.LightningModule",
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
@@ -63,12 +63,12 @@ def pre_optimizer_step(
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
skipped_backward = result is None
# in manual optimization, the closure does not return a value
if model.automatic_optimization and skipped_backward:
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
raise MisconfigurationException(
"Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
)
# DeepSpeed handles the optimizer step internally
deepspeed_engine = model.trainer.model
deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
deepspeed_engine.step()
return False

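A minimal sketch of the engine selection in the DeepSpeed branch above, using made-up stand-ins (`EngineStub`, `TrainerStub`, `LightningStub` are not Lightning or DeepSpeed classes): a raw module passed through the Lite path is assumed to already be the DeepSpeed engine and is stepped directly, while a LightningModule reaches its engine via `model.trainer.model`.

```python
# Stand-in classes for illustration only; not Lightning or DeepSpeed code.
from torch import nn


class EngineStub(nn.Module):
    """Pretend DeepSpeed engine: the object whose .step() drives the optimizer."""

    def step(self) -> None:
        print("engine.step() called")


class TrainerStub:
    def __init__(self, model: EngineStub) -> None:
        self.model = model


class LightningStub(nn.Module):
    """Pretend LightningModule: its engine hangs off self.trainer.model."""

    def __init__(self, engine: EngineStub) -> None:
        super().__init__()
        self.trainer = TrainerStub(engine)


def select_engine(model: nn.Module) -> EngineStub:
    # mirrors: model.trainer.model if isinstance(model, pl.LightningModule) else model
    return model.trainer.model if isinstance(model, LightningStub) else model


engine = EngineStub()
select_engine(LightningStub(engine)).step()  # LightningModule path
select_engine(engine).step()  # raw-module (Lite) path
```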
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/ipu_precision.py
@@ -40,7 +40,7 @@ def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> No

def pre_optimizer_step(
self,
model: "pl.LightningModule",
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable[[], Any],
@@ -55,7 +55,7 @@ def pre_optimizer_step(
closure_result = lambda_closure()
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if model.automatic_optimization and skipped_backward:
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
# we lack coverage here and IPUs are (currently) limited - something to explore if there's demand
raise MisconfigurationException(
"Skipping backward by returning `None` from your `training_step` is not implemented for IPUs."
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -66,7 +66,7 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any

def pre_optimizer_step(
self,
model: "pl.LightningModule",
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
@@ -84,7 +84,7 @@ def pre_optimizer_step(
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
skipped_backward = result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
self.scaler.step(optimizer)
self.scaler.update()
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -99,14 +99,15 @@ def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any

def pre_optimizer_step(
self,
model: "pl.LightningModule",
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""Hook to do something before each optimizer step."""
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
return True

def clip_gradients(
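As a rough usage sketch (a hypothetical subclass, not part of this PR): with the widened signature, a custom precision plugin now receives either a `LightningModule` or a plain `Module`, and anything that needs a trainer should sit behind the same `isinstance` guard the base class uses.

```python
# Hypothetical subclass, shown only to illustrate the widened `model` type.
from typing import Any, Callable, Union

from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class LoggingPrecisionPlugin(PrecisionPlugin):
    def pre_optimizer_step(
        self,
        model: Union["pl.LightningModule", Module],
        optimizer: Optimizer,
        optimizer_idx: int,
        lambda_closure: Callable,
        **kwargs: Any,
    ) -> bool:
        # Only a LightningModule carries a trainer, so guard trainer-dependent logic.
        if isinstance(model, pl.LightningModule):
            print(f"stepping optimizer {optimizer_idx} at global_step={model.trainer.global_step}")
        return super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
```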
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/precision/tpu.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable
from typing import Any, Callable, Union

from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
@@ -27,7 +28,7 @@
class TPUPrecisionPlugin(PrecisionPlugin):
def pre_optimizer_step(
self,
model: "pl.LightningModule",
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable[[], Any],
@@ -37,7 +38,7 @@ def pre_optimizer_step(
closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": lambda_closure, **kwargs})
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if model.automatic_optimization and skipped_backward:
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
# we lack coverage here so disable this - something to explore if there's demand
raise MisconfigurationException(
"Skipping backward by returning `None` from your `training_step` is not implemented for TPUs."
@@ -250,7 +250,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
return trainer.init_optimizers(model)

def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs: Any) -> None:
optimizer.step(closure=lambda_closure, **kwargs)

@property