
Commit dd2a1c5

awaelchli, carmocca, and Borda authored
Integrate Lite Precision into PL (#14798)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 6df6dea commit dd2a1c5

17 files changed: +112 additions, -204 deletions


docs/source-pytorch/api_references.rst

Lines changed: 0 additions & 1 deletion
@@ -189,7 +189,6 @@ precision
     FullyShardedNativeNativeMixedPrecisionPlugin
     HPUPrecisionPlugin
     IPUPrecisionPlugin
-    MixedPrecisionPlugin
     NativeMixedPrecisionPlugin
     PrecisionPlugin
     ShardedNativeMixedPrecisionPlugin

docs/source-pytorch/extensions/plugins.rst

Lines changed: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ The full list of built-in precision plugins is listed below.
     FullyShardedNativeNativeMixedPrecisionPlugin
     HPUPrecisionPlugin
     IPUPrecisionPlugin
-    MixedPrecisionPlugin
     NativeMixedPrecisionPlugin
     PrecisionPlugin
     ShardedNativeMixedPrecisionPlugin

src/lightning_lite/plugins/precision/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from lightning_lite.plugins.precision.precision import Precision  # isort:skip
 from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
 from lightning_lite.plugins.precision.double import DoublePrecision
 from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
+from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.tpu import TPUPrecision
 from lightning_lite.plugins.precision.tpu_bf16 import TPUBf16Precision
 

src/lightning_lite/plugins/precision/double.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 
 import torch
 
-from lightning_lite.plugins.precision import Precision
+from lightning_lite.plugins.precision.precision import Precision
 
 
 class DoublePrecision(Precision):

src/lightning_lite/plugins/precision/native_amp.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from torch.nn import Module
 from torch.optim import LBFGS
 
-from lightning_lite.plugins.precision import Precision
+from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
 from lightning_lite.utilities.types import Steppable
 

src/lightning_lite/plugins/precision/precision.py

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ def forward_context(self) -> Generator[None, None, None]:
         """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
         yield
 
-    def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> None:
+    def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
         """Runs before precision plugin executes backward.
 
         Args:
@@ -51,7 +51,7 @@ def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs
         """
         tensor.backward(*args, **kwargs)
 
-    def post_backward(self, tensor: Tensor, module: Optional[Module]) -> None:
+    def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
         """Runs after precision plugin executes backward.
 
         Args:
@@ -67,7 +67,7 @@ def optimizer_step(
         """Hook to run the optimizer step."""
         return optimizer.step(**kwargs)
 
-    def get_main_params(self, optimizer: Optimizer) -> _PARAMETERS:
+    def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
         """The main params of the model.
 
         Returns the plain model params here. Maybe different in other precision plugins.
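
For orientation, the hooks touched above are the ones a custom Lite precision plugin can override. The snippet below is a minimal, hypothetical sketch that only relies on names visible in this diff (`Precision`, `pre_backward`, `post_backward`, `main_params`); the `VerbosePrecision` class and its print statements are illustrative, not part of Lightning.

from typing import Any, Optional

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from lightning_lite.plugins.precision.precision import Precision


class VerbosePrecision(Precision):
    """Hypothetical plugin that logs around the backward hooks shown above."""

    def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
        # Returning a value is possible now that the hook's return type widened to `Any`.
        print(f"backward starting on a tensor of shape {tuple(tensor.shape)}")
        return tensor

    def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
        print("backward finished")
        return tensor

    def main_params(self, optimizer: Optimizer):
        # Renamed from `get_main_params`; the base implementation returns the
        # plain parameters registered on the optimizer.
        return super().main_params(optimizer)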

src/pytorch_lightning/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
@@ -78,9 +78,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - In Lightning Lite, state-dict access to the module wrapper now gets passed through to the original module reference ([#14629](https://github.com/Lightning-AI/lightning/pull/14629))
 
+
 - Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300))
 
 
+- Integrated the Lite Precision plugins into the PL Precision plugins - the base class in PL now extends the `lightning_lite.precision.Precision` base class ([#14798](https://github.com/Lightning-AI/lightning/pull/14798))
+  * The `PrecisionPlugin.backward` signature changed: The `closure_loss` argument was renamed to `tensor`
+  * The `PrecisionPlugin.{pre_,post_}backward` signature changed: The `closure_loss` argument was renamed to `tensor` and moved as the first argument
+  * The `PrecisionPlugin.optimizer_step` signature changed: The `model`, `optimizer_idx` and `closure` arguments need to be passed as keyword arguments now
+
+
 - Trainer queries the CUDA devices through NVML if available to avoid initializing CUDA before forking, which eliminates the need for the `PL_DISABLE_FORK` environment variable introduced in v1.7.4 ([#14631](https://github.com/Lightning-AI/lightning/issues/14631))
 
 
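
The changelog entry above lists breaking signature changes. As a hedged illustration of the call-site impact, the sketch below wraps the new `optimizer_step` call in a helper; `run_optimizer_step` and its arguments are placeholder names, not Trainer internals, and the old positional order in the comment is taken from the Apex override further down in this commit.

from typing import Any, Callable

from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import PrecisionPlugin


def run_optimizer_step(
    precision_plugin: PrecisionPlugin,
    optimizer: Optimizer,
    model: "pl.LightningModule",
    closure: Callable[[], Any],
) -> Any:
    # Old call style: precision_plugin.optimizer_step(model, optimizer, 0, closure)
    # After this commit, `model`, `optimizer_idx` and `closure` are passed as keywords:
    return precision_plugin.optimizer_step(
        optimizer,
        model=model,
        optimizer_idx=0,
        closure=closure,
    )

Likewise, `backward` and `{pre_,post_}backward` now receive the loss under the name `tensor`, passed as the first argument.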

src/pytorch_lightning/core/module.py

Lines changed: 2 additions & 1 deletion
@@ -37,6 +37,7 @@
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning_lite.utilities.distributed import distributed_available, sync_ddp
+from lightning_lite.utilities.types import Steppable
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
@@ -1398,7 +1399,7 @@ def training_step(...):
         self.trainer.strategy.backward(loss, None, None, *args, **kwargs)
 
     def backward(
-        self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
+        self, loss: Tensor, optimizer: Optional[Steppable], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
     ) -> None:
         """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your
         own implementation if you need to.
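
The widened `Steppable` hint also applies to user overrides of `LightningModule.backward`. The sketch below is a hypothetical model for illustration; only the method signature is taken from the diff above, and a complete module would still need `training_step` and `configure_optimizers`.

from typing import Any, Optional

from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.types import Steppable


class MyModel(pl.LightningModule):
    def backward(
        self, loss: Tensor, optimizer: Optional[Steppable], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
    ) -> None:
        # Example override: keep the graph alive for a second backward pass.
        loss.backward(retain_graph=True)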

src/pytorch_lightning/plugins/precision/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -18,7 +18,6 @@
 from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
 from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
-from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -33,7 +32,6 @@
     "FullyShardedNativeMixedPrecisionPlugin",
     "HPUPrecisionPlugin",
     "IPUPrecisionPlugin",
-    "MixedPrecisionPlugin",
     "NativeMixedPrecisionPlugin",
     "PrecisionPlugin",
     "ShardedNativeMixedPrecisionPlugin",

src/pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 18 additions & 18 deletions
@@ -11,23 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional
 
 from torch import Tensor
-from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
 
 import pytorch_lightning as pl
-from lightning_lite.utilities.types import _PARAMETERS
-from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
+from lightning_lite.utilities.types import _PARAMETERS, Steppable
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _APEX_AVAILABLE:
     from apex import amp
 
 
-class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
+class ApexMixedPrecisionPlugin(PrecisionPlugin):
     """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""
 
     backend = AMPType.APEX
@@ -55,31 +54,32 @@ def dispatch(self, trainer: "pl.Trainer") -> None:
         self._connected = True
         return super().dispatch(trainer)
 
-    def backward(
+    def backward(  # type: ignore[override]
         self,
+        tensor: Tensor,
         model: "pl.LightningModule",
-        closure_loss: Tensor,
-        optimizer: Optional[Optimizer],
-        optimizer_idx: Optional[int],
+        optimizer: Optional[Steppable],
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        """Run before precision plugin executes backward.
+        r"""Run before precision plugin executes backward.
 
         Args:
+            tensor: the loss value obtained from the closure
            model: the model to be optimized
-            closure_loss: the loss value obtained from the closure
            optimizer: current optimizer being used. ``None`` if using manual optimization
-            optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization
+            \*args: Positional arguments intended for the actual function that performs the backward, like
+                :meth:`~torch.Tensor.backward`.
+            \**kwargs: Keyword arguments for the same purpose as ``*args``.
         """
         opt = optimizer or model.trainer.optimizers
-        with amp.scale_loss(closure_loss, opt) as closure_loss:
-            super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
+        with amp.scale_loss(tensor, opt) as tensor:
+            super().backward(tensor, model, optimizer, *args, **kwargs)
 
-    def optimizer_step(
+    def optimizer_step(  # type: ignore[override]
        self,
-        model: Optional[Union["pl.LightningModule", Module]],
-        optimizer: Optimizer,
+        optimizer: Steppable,
+        model: "pl.LightningModule",
         optimizer_idx: int,
         closure: Callable[[], Any],
         **kwargs: Any,
@@ -97,7 +97,7 @@ def optimizer_step(
         self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
-        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
+        if not model.automatic_optimization or not skipped_backward:
             return optimizer.step(**kwargs)
         return closure_result
 
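
Despite the base-class swap from `MixedPrecisionPlugin` to `PrecisionPlugin`, selecting Apex from user code is unchanged. The usage sketch below is hedged: it assumes the Trainer flags of this release (`amp_backend`, `amp_level`) and an environment with a CUDA device and NVIDIA Apex installed.

import pytorch_lightning as pl

# Select the reworked plugin through the Trainer flags available at the time;
# requires NVIDIA Apex to be installed and a CUDA device to be visible.
trainer = pl.Trainer(accelerator="gpu", devices=1, precision=16, amp_backend="apex", amp_level="O2")

# The resolved precision plugin should be the ApexMixedPrecisionPlugin refactored in this file.
print(type(trainer.precision_plugin).__name__)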
