Merged
37 commits
ecca4f5  Move src/pytorch_lightning/lite to src/lightning_lite (#14735)  (awaelchli, Sep 17, 2022)
749f4c0  Add backward-compatibility for LightningLite in PL (#14735)  (awaelchli, Sep 19, 2022)
2d48318  integrate  (awaelchli, Sep 20, 2022)
c10c447  precision  (awaelchli, Sep 20, 2022)
80a7e9e  integrate  (awaelchli, Sep 20, 2022)
af35652  Merge branch 'master' into lite/integrate-precision  (awaelchli, Sep 20, 2022)
3302aac  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Sep 20, 2022)
02f9d25  fix imports  (awaelchli, Sep 20, 2022)
c6ac72c  Merge branch 'master' into lite/integrate-precision  (awaelchli, Sep 20, 2022)
90d6843  remove reference  (awaelchli, Sep 20, 2022)
ae6530d  revert circular import  (awaelchli, Sep 20, 2022)
f5c0b39  Merge branch 'master' into lite/integrate-precision  (awaelchli, Sep 20, 2022)
0822739  update changelog  (awaelchli, Sep 20, 2022)
977f175  update changelog  (awaelchli, Sep 20, 2022)
d3728c8  Merge branch 'master' into lite/integrate-precision  (carmocca, Sep 20, 2022)
0c6f213  Remove model from optimizer step  (carmocca, Sep 20, 2022)
0367575  Structural typing  (carmocca, Sep 20, 2022)
106f5ab  Comment  (carmocca, Sep 20, 2022)
5005a36  update the tests  (awaelchli, Sep 20, 2022)
d57aa18  fix deepspeed test  (awaelchli, Sep 20, 2022)
1ae4806  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Sep 20, 2022)
eb27a7b  Merge branch 'lite/remove-model-opt-step' into lite/integrate-precision  (awaelchli, Sep 21, 2022)
80ab2aa  wip  (awaelchli, Sep 21, 2022)
1c10b59  Merge branch 'master' into lite/integrate-precision  (awaelchli, Sep 21, 2022)
f859ce4  update types  (awaelchli, Sep 21, 2022)
1447723  address review  (awaelchli, Sep 21, 2022)
faf6f27  re-order  (awaelchli, Sep 21, 2022)
7f96a43  update changelog  (awaelchli, Sep 21, 2022)
6ed0fca  revert api change  (awaelchli, Sep 21, 2022)
1397cf5  update  (awaelchli, Sep 21, 2022)
31f52a2  address review  (awaelchli, Sep 22, 2022)
b9ed882  Merge branch 'master' into lite/integrate-precision  (Borda, Sep 22, 2022)
106996c  Merge branch 'master' into lite/integrate-precision  (carmocca, Sep 22, 2022)
a2a7927  Mention breaking change  (carmocca, Sep 22, 2022)
d7dc7af  Minor docstring and mypy changes  (carmocca, Sep 22, 2022)
04ea80f  Option (d)  (carmocca, Sep 22, 2022)
2a6fb59  update changelog  (awaelchli, Sep 22, 2022)
1 change: 0 additions & 1 deletion docs/source-pytorch/api_references.rst
@@ -189,7 +189,6 @@ precision
FullyShardedNativeNativeMixedPrecisionPlugin
HPUPrecisionPlugin
IPUPrecisionPlugin
MixedPrecisionPlugin
NativeMixedPrecisionPlugin
PrecisionPlugin
ShardedNativeMixedPrecisionPlugin
1 change: 0 additions & 1 deletion docs/source-pytorch/extensions/plugins.rst
@@ -59,7 +59,6 @@ The full list of built-in precision plugins is listed below.
FullyShardedNativeNativeMixedPrecisionPlugin
HPUPrecisionPlugin
IPUPrecisionPlugin
MixedPrecisionPlugin
NativeMixedPrecisionPlugin
PrecisionPlugin
ShardedNativeMixedPrecisionPlugin
2 changes: 1 addition & 1 deletion src/lightning_lite/plugins/precision/__init__.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_lite.plugins.precision.precision import Precision # isort:skip
from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
from lightning_lite.plugins.precision.double import DoublePrecision
from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.tpu import TPUPrecision
from lightning_lite.plugins.precision.tpu_bf16 import TPUBf16Precision

2 changes: 1 addition & 1 deletion src/lightning_lite/plugins/precision/double.py
@@ -16,7 +16,7 @@

import torch

from lightning_lite.plugins.precision import Precision
from lightning_lite.plugins.precision.precision import Precision


class DoublePrecision(Precision):
2 changes: 1 addition & 1 deletion src/lightning_lite/plugins/precision/native_amp.py
@@ -19,7 +19,7 @@
from torch.nn import Module
from torch.optim import LBFGS

from lightning_lite.plugins.precision import Precision
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
from lightning_lite.utilities.types import Steppable

6 changes: 3 additions & 3 deletions src/lightning_lite/plugins/precision/precision.py
@@ -34,7 +34,7 @@ def forward_context(self) -> Generator[None, None, None]:
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
yield

def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> None:
def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
"""Runs before precision plugin executes backward.

Args:
Expand All @@ -51,7 +51,7 @@ def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs
"""
tensor.backward(*args, **kwargs)

def post_backward(self, tensor: Tensor, module: Optional[Module]) -> None:
def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
"""Runs after precision plugin executes backward.

Args:
Expand All @@ -67,7 +67,7 @@ def optimizer_step(
"""Hook to run the optimizer step."""
return optimizer.step(**kwargs)

def get_main_params(self, optimizer: Optimizer) -> _PARAMETERS:
def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
"""The main params of the model.

Returns the plain model params here. Maybe different in other precision plugins.
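The hunks above define the Lite-side surface that PL now builds on: `pre_backward`/`post_backward` may return a value (`-> Any`), and `get_main_params` becomes `main_params`. Below is a minimal usage sketch, assuming the base `Precision` class is directly instantiable and that `optimizer_step` takes just the optimizer plus keyword arguments, as the hunks suggest; it is illustrative, not code from this PR.

```python
import torch

from lightning_lite.plugins.precision import Precision

precision = Precision()
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(2, 4)).sum()
precision.pre_backward(loss, model)   # may now return a value (-> Any)
precision.backward(loss, model)       # default implementation: loss.backward()
precision.post_backward(loss, model)
precision.optimizer_step(optimizer)   # default implementation: optimizer.step()

# `main_params` (renamed from `get_main_params`) yields the parameters the optimizer
# updates; mixed-precision plugins may return master FP32 parameters instead.
for param in precision.main_params(optimizer):
    assert param.grad is not None
```

Dropping the model argument from `optimizer_step` (the "Remove model from optimizer step" commit) is what allows PL and Lite to share this base class.
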
7 changes: 7 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -78,9 +78,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- In Lightning Lite, state-dict access to the module wrapper now gets passed through to the original module reference ([#14629](https://github.com/Lightning-AI/lightning/pull/14629))


- Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300))


- Integrated the Lite Precision plugins into the PL Precision plugins - the base class in PL now extends the `lightning_lite.precision.Precision` base class ([#14798](https://github.com/Lightning-AI/lightning/pull/14798))
* The `PrecisionPlugin.backward` signature changed: The `closure_loss` argument was renamed to `tensor`
* The `PrecisionPlugin.{pre_,post_}backward` signature changed: The `closure_loss` argument was renamed to `tensor` and moved as the first argument
* The `PrecisionPlugin.optimizer_step` signature changed: The `model`, `optimizer_idx` and `closure` arguments need to be passed as keyword arguments now


- Trainer queries the CUDA devices through NVML if available to avoid initializing CUDA before forking, which eliminates the need for the `PL_DISABLE_FORK` environment variable introduced in v1.7.4 ([#14631](https://github.com/Lightning-AI/lightning/issues/14631))


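For third-party code that subclasses or calls `PrecisionPlugin`, the changelog entry above translates roughly to the migration sketch below. The "before" signatures are reconstructed from the changelog bullets and the hunks in this PR, so treat the exact ordering and names as illustrative rather than an exact copy of the old API.

```python
from typing import Any, Callable, Optional

from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import PrecisionPlugin


class MyPrecisionPlugin(PrecisionPlugin):
    # Before: def backward(self, model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
    # After: the loss tensor moves to the front and is named `tensor`, mirroring lightning_lite.
    def backward(
        self,
        tensor: Tensor,
        model: "pl.LightningModule",
        optimizer: Optional[Optimizer],
        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().backward(tensor, model, optimizer, optimizer_idx, *args, **kwargs)

    # Before: def optimizer_step(self, model, optimizer, optimizer_idx, closure, **kwargs)
    # After: the optimizer comes first; model/optimizer_idx/closure are passed as keywords.
    def optimizer_step(
        self,
        optimizer: Optimizer,
        model: "pl.LightningModule",
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        return super().optimizer_step(
            optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
        )
```

Lightning's own subclasses in this PR mark these overrides with `# type: ignore[override]` because the Lite base class they now extend declares narrower signatures.
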
3 changes: 2 additions & 1 deletion src/pytorch_lightning/core/module.py
@@ -37,6 +37,7 @@
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning_lite.utilities.distributed import distributed_available, sync_ddp
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
@@ -1398,7 +1399,7 @@ def training_step(...):
self.trainer.strategy.backward(loss, None, None, *args, **kwargs)

def backward(
self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
self, loss: Tensor, optimizer: Optional[Steppable], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
) -> None:
"""Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your
own implementation if you need to.
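`LightningModule.backward` now annotates the optimizer as `Steppable`, a structural type imported from `lightning_lite.utilities.types` (see the "Structural typing" commit). Roughly, anything exposing a `step()` method qualifies, regardless of its base class. The sketch below is a hypothetical stand-in to show the idea; the actual definition lives in lightning_lite and may differ in detail.

```python
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class SteppableLike(Protocol):  # hypothetical stand-in for lightning_lite's Steppable
    """Anything with a ``step()`` method: a torch optimizer, a DeepSpeed engine, a wrapper."""

    def step(self, closure: Optional[Any] = None) -> Any:
        ...


class ToyOptimizer:
    """Satisfies the protocol without inheriting from torch.optim.Optimizer."""

    def step(self, closure: Optional[Any] = None) -> Any:
        return closure() if closure is not None else None


assert isinstance(ToyOptimizer(), SteppableLike)  # structural check, no subclassing required
```
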
2 changes: 0 additions & 2 deletions src/pytorch_lightning/plugins/precision/__init__.py
@@ -18,7 +18,6 @@
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -33,7 +32,6 @@
"FullyShardedNativeMixedPrecisionPlugin",
"HPUPrecisionPlugin",
"IPUPrecisionPlugin",
"MixedPrecisionPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
36 changes: 18 additions & 18 deletions src/pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,23 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional

from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.types import _PARAMETERS
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from lightning_lite.utilities.types import _PARAMETERS, Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _APEX_AVAILABLE:
from apex import amp


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
class ApexMixedPrecisionPlugin(PrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

backend = AMPType.APEX
@@ -55,31 +54,32 @@ def dispatch(self, trainer: "pl.Trainer") -> None:
self._connected = True
return super().dispatch(trainer)

def backward(
def backward( # type: ignore[override]
self,
tensor: Tensor,
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
optimizer: Optional[Steppable],
*args: Any,
**kwargs: Any,
) -> None:
"""Run before precision plugin executes backward.
r"""Run before precision plugin executes backward.

Args:
tensor: the loss value obtained from the closure
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: current optimizer being used. ``None`` if using manual optimization
optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization
\*args: Positional arguments intended for the actual function that performs the backward, like
:meth:`~torch.Tensor.backward`.
\**kwargs: Keyword arguments for the same purpose as ``*args``.
"""
opt = optimizer or model.trainer.optimizers
with amp.scale_loss(closure_loss, opt) as closure_loss:
super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
with amp.scale_loss(tensor, opt) as tensor:
super().backward(tensor, model, optimizer, *args, **kwargs)

def optimizer_step(
def optimizer_step( # type: ignore[override]
self,
model: Optional[Union["pl.LightningModule", Module]],
optimizer: Optimizer,
optimizer: Steppable,
model: "pl.LightningModule",
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
@@ -97,7 +97,7 @@ def optimizer_step(
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
if not model.automatic_optimization or not skipped_backward:
return optimizer.step(**kwargs)
return closure_result

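For context on the `amp.scale_loss` call kept above: the plugin wraps NVIDIA Apex's standard loss-scaling idiom, sketched below. This assumes Apex is installed and a CUDA device is available; it is background for the diff, not part of the change.

```python
import torch
from apex import amp  # requires NVIDIA Apex

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(2, 4, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # backward runs on the scaled loss
optimizer.step()
```
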
33 changes: 11 additions & 22 deletions src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -16,11 +16,11 @@
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -73,20 +73,20 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
self.amp_type = amp_type
self.amp_level = amp_level

def backward(
def backward( # type: ignore[override]
self,
tensor: Tensor,
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer: Optional[Steppable],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
r"""Performs back-propagation using DeepSpeed's engine.

Args:
tensor: the loss tensor
model: the model to be optimized
closure_loss: the loss tensor
optimizer: ignored for DeepSpeed
optimizer_idx: ignored for DeepSpeed
\*args: additional positional arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
@@ -98,19 +98,12 @@ def backward(
" the backward logic internally."
)
deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model
deepspeed_engine.backward(closure_loss, *args, **kwargs)
deepspeed_engine.backward(tensor, *args, **kwargs)

def _run_backward(
self, tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any
) -> None:
if model is None:
raise ValueError("Please provide the model as input to `backward`.")
model.backward(tensor, *args, **kwargs)

def optimizer_step(
def optimizer_step( # type: ignore[override]
self,
model: Optional[Union["pl.LightningModule", Module]],
optimizer: Optimizer,
optimizer: Steppable,
model: "pl.LightningModule",
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
@@ -123,16 +116,12 @@ def optimizer_step(
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
if model.automatic_optimization and skipped_backward:
raise MisconfigurationException(
"Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
)
# DeepSpeed handles the optimizer step internally
deepspeed_engine: "deepspeed.DeepSpeedEngine"
if isinstance(model, pl.LightningModule):
deepspeed_engine = model.trainer.model
else:
deepspeed_engine = model
deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model
return deepspeed_engine.step(**kwargs)

def clip_gradients(
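The simplification above relies on the DeepSpeed engine owning both the backward pass and the optimizer step, which is why the plugin forwards to `deepspeed_engine.backward(...)` and `.step()` and no longer accepts a raw module. A rough standalone sketch of that engine API follows, assuming `deepspeed` is installed and the script runs under the `deepspeed` launcher (or with the distributed environment variables already set).

```python
import deepspeed
import torch

model = torch.nn.Linear(4, 1)
engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config={"train_batch_size": 2},
)

batch = torch.randn(2, 4, device=engine.device)
loss = engine(batch).sum()
engine.backward(loss)  # replaces loss.backward()
engine.step()          # replaces optimizer.step()
```
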
21 changes: 14 additions & 7 deletions src/pytorch_lightning/plugins/precision/ipu.py
@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union

from lightning_utilities.core.rank_zero import WarningCache
from torch.nn import Module
from torch import Tensor
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -45,17 +46,23 @@ def __init__(self, precision: int) -> None:
super().__init__()
self.precision = precision

def backward(self, model: "pl.LightningModule", *_: Any, **__: Any) -> None:
def backward( # type: ignore[override]
self,
tensor: Tensor,
model: "pl.LightningModule",
*args: Any,
**kwargs: Any,
) -> None:
if is_overridden("backward", model):
warning_cache.warn(
"You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle"
" the backward logic internally."
)

def optimizer_step(
def optimizer_step( # type: ignore[override]
self,
model: Optional[Union["pl.LightningModule", Module]],
optimizer: Optimizer,
optimizer: Steppable,
model: "pl.LightningModule",
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
@@ -69,7 +76,7 @@ def optimizer_step(
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
if model.automatic_optimization and skipped_backward:
# we lack coverage here and IPUs are (currently) limited - something to explore if there's demand
raise MisconfigurationException(
"Skipping backward by returning `None` from your `training_step` is not implemented for IPUs."
26 changes: 0 additions & 26 deletions src/pytorch_lightning/plugins/precision/mixed.py

This file was deleted.
