3 changes: 0 additions & 3 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -1108,9 +1108,6 @@ Combine hooks for accumulated benefit:

When using Post-localSGD, you must also pass ``model_averaging_period`` to allow for model parameter averaging:

.. note::
Post-localSGD support requires PyTorch>=1.10.0

.. code-block:: python

from pytorch_lightning import Trainer
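For reference, the truncated ``.. code-block:: python`` above configures Post-localSGD through ``DDPStrategy``. A minimal sketch of such a configuration follows; the hyperparameter values are illustrative assumptions, not the documented ones:

```python
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Sketch: enable the Post-localSGD communication hook and periodic parameter averaging.
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DDPStrategy(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,  # illustrative: switch to local SGD after 8 steps
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,  # average model parameters every 4 optimizer steps
    ),
)
```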
5 changes: 1 addition & 4 deletions docs/source-pytorch/common/precision_intermediate.rst
@@ -115,8 +115,6 @@ BFloat16 Mixed Precision

.. warning::

BFloat16 requires PyTorch 1.10 or later and is only supported with PyTorch Native AMP.

BFloat16 is also experimental and may not provide significant speedups or memory improvements, offering better numerical stability.

Do note for GPUs, the most significant benefits require `Ampere <https://en.wikipedia.org/wiki/Ampere_(microarchitecture)>`__ based GPUs, such as A100s or 3090s.
@@ -126,14 +124,13 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however, it maintai
Under the hood, we use `torch.autocast <https://pytorch.org/docs/stable/amp.html>`__ with the dtype set to ``bfloat16``, with no gradient scaling.

.. testcode::
:skipif: not _TORCH_GREATER_EQUAL_1_10 or not torch.cuda.is_available()
:skipif: not torch.cuda.is_available()

Trainer(accelerator="gpu", devices=1, precision="bf16")

It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN under the hood.

.. testcode::
:skipif: not _TORCH_GREATER_EQUAL_1_10

Trainer(precision="bf16")

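The doc text above notes that bf16 mixed precision simply wraps computation in ``torch.autocast`` with no gradient scaling. A small stand-alone sketch of that behaviour (model and input names are made up for illustration):

```python
import torch

# Sketch of what Trainer(precision="bf16") does under the hood:
# autocast to bfloat16, no GradScaler involved.
model = torch.nn.Linear(32, 2)
batch = torch.randn(4, 32)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model(batch).sum()

loss.backward()  # gradients are computed without any loss scaling
```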
1 change: 0 additions & 1 deletion docs/source-pytorch/conf.py
@@ -400,7 +400,6 @@ def package_list_from_file(file):
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TORCH_GREATER_EQUAL_1_10,
)
from pytorch_lightning.loggers.neptune import _NEPTUNE_AVAILABLE
from pytorch_lightning.loggers.comet import _COMET_AVAILABLE
7 changes: 2 additions & 5 deletions src/lightning_lite/plugins/collectives/torch_collective.py
@@ -7,7 +7,7 @@
from typing_extensions import Self

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_13
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from lightning_lite.utilities.types import CollectibleGroup, RedOpType, ReduceOp

if dist.is_available():
@@ -86,10 +86,7 @@ def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]:
def broadcast_object_list(
self, object_list: List[Any], src: int, device: Optional[torch.device] = None
) -> List[Any]:
kwargs = {}
if _TORCH_GREATER_EQUAL_1_10:
kwargs["device"] = device
dist.broadcast_object_list(object_list, src, group=self.group, **kwargs)
dist.broadcast_object_list(object_list, src, group=self.group, device=device)
return object_list

def gather_object(self, obj: Any, object_gather_list: List[Any], dst: int = 0) -> List[Any]:
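Since ``torch.distributed.broadcast_object_list`` has accepted a ``device`` argument since 1.10, the wrapper can now pass it unconditionally. A hedged usage sketch of the simplified call (process-group setup omitted, names illustrative):

```python
import torch
import torch.distributed as dist

# Sketch: broadcast a Python object list from rank 0, placing any contained
# tensors on the given device. Assumes a process group has already been
# initialized (e.g. via dist.init_process_group).
objects = [{"epoch": 3, "lr": 1e-3}] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(objects, src=0, device=torch.device("cpu"))
# After the call, every rank sees the dict broadcast from rank 0 in objects[0].
```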
20 changes: 5 additions & 15 deletions src/lightning_lite/plugins/precision/native_amp.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Dict, Generator, Optional, Union
from typing import Any, Dict, Generator, Optional

import torch
from torch import Tensor
@@ -23,14 +23,8 @@
from lightning_lite.accelerators.cuda import _patch_cuda_is_available
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
from lightning_lite.utilities.types import Optimizable

if _TORCH_GREATER_EQUAL_1_10:
from torch import autocast as new_autocast
else:
from torch.cuda.amp import autocast as old_autocast


class NativeMixedPrecision(Precision):
"""Plugin for Native Mixed Precision (AMP) training with ``torch.autocast``.
@@ -45,8 +39,6 @@ def __init__(
self, precision: Literal[16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
super().__init__()
if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
raise ImportError("To use bfloat16 with native amp you must install torch greater or equal to 1.10.")
if scaler is None and precision == 16:
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
@@ -96,9 +88,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if self.scaler is not None:
self.scaler.load_state_dict(state_dict)

def _autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
if _TORCH_GREATER_EQUAL_1_10:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
return old_autocast()
def _autocast_context_manager(self) -> torch.autocast:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
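The rewritten ``_autocast_context_manager`` always builds a plain ``torch.autocast``. A small sketch of the precision-to-dtype mapping it performs (helper name and values are assumptions mirroring the code above):

```python
import torch

def make_autocast(device: str, precision) -> torch.autocast:
    # Mirrors the plugin logic: "bf16" -> bfloat16, 16 -> float16.
    # The dtype is passed explicitly because of
    # https://github.com/pytorch/pytorch/issues/67233.
    dtype = torch.bfloat16 if precision == "bf16" else torch.half
    return torch.autocast(device, dtype=dtype)

with make_autocast("cpu", "bf16"):
    out = torch.nn.Linear(8, 8)(torch.randn(2, 8))
```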
2 changes: 0 additions & 2 deletions src/lightning_lite/utilities/imports.py
@@ -27,8 +27,6 @@

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_TORCH_GREATER_EQUAL_1_10 = compare_version("torch", operator.ge, "1.10.0")
_TORCH_LESSER_EQUAL_1_10_2 = compare_version("torch", operator.le, "1.10.2")
_TORCH_GREATER_EQUAL_1_11 = compare_version("torch", operator.ge, "1.11.0")
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
20 changes: 5 additions & 15 deletions src/pytorch_lightning/callbacks/quantization.py
@@ -22,18 +22,14 @@

import torch
from torch import Tensor
from torch.ao.quantization.qconfig import QConfig
from torch.quantization import FakeQuantizeBase

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TORCH_GREATER_EQUAL_1_10:
from torch.ao.quantization.qconfig import QConfig
else:
from torch.quantization import QConfig

if _TORCH_GREATER_EQUAL_1_11:
from torch.ao.quantization import fuse_modules_qat as fuse_modules
else:
@@ -252,15 +248,9 @@ def _prepare_model(self, model: "pl.LightningModule") -> None:
if self._observer_type == "histogram":
model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
elif self._observer_type == "average":
extra_kwargs: Dict[str, Optional[int]] = {}
if _TORCH_GREATER_EQUAL_1_12:
extra_kwargs["version"] = 0
# version=None corresponds to using FakeQuantize rather than
# FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
# details in https://github.com/pytorch/pytorch/issues/64564
elif _TORCH_GREATER_EQUAL_1_10:
extra_kwargs["version"] = None
model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
model.qconfig = torch.quantization.get_default_qat_qconfig(
self._qconfig, version=0 if _TORCH_GREATER_EQUAL_1_12 else None
)

elif isinstance(self._qconfig, QConfig):
model.qconfig = self._qconfig # type: ignore [assignment]
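For context, the ``version`` argument selects which fake-quantize observer family ``get_default_qat_qconfig`` returns. The callback itself is typically used as below; this is a sketch and the argument values are illustrative:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import QuantizationAwareTraining

# Sketch: QAT with moving-average observers; the callback builds the qconfig
# shown in the diff above.
qat = QuantizationAwareTraining(qconfig="fbgemm", observer_type="average")
trainer = Trainer(max_epochs=3, callbacks=[qat])
# trainer.fit(model, datamodule)  # model/datamodule assumed to exist
```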
11 changes: 2 additions & 9 deletions src/pytorch_lightning/core/module.py
@@ -45,7 +45,7 @@
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.loggers import Logger
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
from pytorch_lightning.utilities import _IS_WINDOWS, GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn
@@ -1824,13 +1824,6 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non
input_sample = self._on_before_batch_transfer(input_sample)
input_sample = self._apply_batch_transfer_handler(input_sample)

if not _TORCH_GREATER_EQUAL_1_10 and "example_outputs" not in kwargs:
self.eval()
if isinstance(input_sample, tuple):
kwargs["example_outputs"] = self(*input_sample)
else:
kwargs["example_outputs"] = self(input_sample)

torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)

@@ -1938,7 +1931,7 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:

These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
"""
if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS or not torch.distributed.is_available():
if _IS_WINDOWS or not torch.distributed.is_available():
rank_zero_debug("Could not register sharded tensor state dict hooks")
return

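With the pre-1.10 ``example_outputs`` fallback gone, ``to_onnx`` forwards directly to ``torch.onnx.export``. A minimal usage sketch; the module and shapes are made-up placeholders:

```python
import torch
from pytorch_lightning import LightningModule

class TinyModel(LightningModule):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x)

model = TinyModel()
# Export to ONNX; extra kwargs are passed straight through to torch.onnx.export.
model.to_onnx("tiny_model.onnx", input_sample=torch.randn(1, 28 * 28), export_params=True)
```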
21 changes: 5 additions & 16 deletions src/pytorch_lightning/plugins/precision/native_amp.py
@@ -22,14 +22,9 @@
from lightning_lite.accelerators.cuda import _patch_cuda_is_available
from lightning_lite.utilities.types import Optimizable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType, GradClipAlgorithmType
from pytorch_lightning.utilities import AMPType, GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TORCH_GREATER_EQUAL_1_10:
from torch import autocast as new_autocast
else:
from torch.cuda.amp import autocast as old_autocast


class NativeMixedPrecisionPlugin(PrecisionPlugin):
"""Plugin for Native Mixed Precision (AMP) training with ``torch.autocast``.
@@ -46,10 +41,6 @@ def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
super().__init__()
if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
raise MisconfigurationException(
"To use bfloat16 with native amp you must install torch greater or equal to 1.10."
)
if scaler is None and precision == 16:
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
@@ -113,12 +104,10 @@ def clip_gradients(
)
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
if _TORCH_GREATER_EQUAL_1_10:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
return old_autocast()
def autocast_context_manager(self) -> torch.autocast:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
8 changes: 4 additions & 4 deletions src/pytorch_lightning/strategies/ddp.py
@@ -53,15 +53,15 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep

if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
else:
OSS = object
if _TORCH_GREATER_EQUAL_1_10 and torch.distributed.is_available():
if torch.distributed.is_available():
from torch.distributed.algorithms.model_averaging.averagers import ModelAverager

log = logging.getLogger(__name__)
@@ -181,7 +181,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.setup_optimizers(trainer)
_optimizers_to_device(self.optimizers, self.root_device)

if _TORCH_GREATER_EQUAL_1_10 and trainer_fn == TrainerFn.FITTING:
if trainer_fn == TrainerFn.FITTING:
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
@@ -279,7 +279,7 @@ def optimizer_step(
"""
optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)

if not _TORCH_GREATER_EQUAL_1_10 or self._model_averager is None:
if self._model_averager is None:
return optimizer_output

params = [param for group in optimizer.param_groups for param in group["params"] if param.grad is not None]
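``optimizer_step`` now only checks whether a model averager is configured; under the hood the averager periodically averages parameters across ranks. A rough stand-alone sketch of that primitive (numbers are illustrative, and an initialized process group is assumed):

```python
import torch
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager

# Sketch: average parameters across ranks every 4 steps, after a 100-step warmup.
# Requires torch.distributed to be initialized before running.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
averager = PeriodicModelAverager(period=4, warmup_steps=100)

for step in range(200):
    optimizer.zero_grad()
    model(torch.randn(8, 16)).sum().backward()
    optimizer.step()
    # Mirrors what the strategy does after each optimizer step:
    averager.average_parameters(model.parameters())
```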
1 change: 0 additions & 1 deletion src/pytorch_lightning/utilities/__init__.py
@@ -29,7 +29,6 @@
_IS_WINDOWS,
_OMEGACONF_AVAILABLE,
_POPTORCH_AVAILABLE,
_TORCH_GREATER_EQUAL_1_10,
_TORCH_GREATER_EQUAL_1_11,
_TORCH_GREATER_EQUAL_1_12,
_TORCH_QUANTIZE_AVAILABLE,
1 change: 0 additions & 1 deletion src/pytorch_lightning/utilities/imports.py
@@ -22,7 +22,6 @@
_IS_WINDOWS = platform.system() == "Windows"
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_TORCH_GREATER_EQUAL_1_10 = compare_version("torch", operator.ge, "1.10.0")
_TORCH_LESSER_EQUAL_1_10_2 = compare_version("torch", operator.le, "1.10.2")
_TORCH_GREATER_EQUAL_1_11 = compare_version("torch", operator.ge, "1.11.0")
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
3 changes: 1 addition & 2 deletions tests/tests_lite/helpers/runif.py
@@ -25,7 +25,6 @@
from lightning_lite.accelerators.mps import MPSAccelerator
from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10


class RunIf:
@@ -97,7 +96,7 @@ def __new__(

if bf16_cuda:
try:
cond = not (torch.cuda.is_available() and _TORCH_GREATER_EQUAL_1_10 and torch.cuda.is_bf16_supported())
cond = not (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
except (AssertionError, RuntimeError) as e:
# AssertionError: Torch not compiled with CUDA enabled
# RuntimeError: Found no NVIDIA driver on your system.
8 changes: 0 additions & 8 deletions tests/tests_lite/plugins/precision/test_native_amp.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock

import pytest
@@ -25,7 +24,6 @@ def test_native_amp_precision_default_scaler():
assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)


@mock.patch("lightning_lite.plugins.precision.native_amp._TORCH_GREATER_EQUAL_1_10", True)
def test_native_amp_precision_scaler_with_bf16():
with pytest.raises(ValueError, match="`precision='bf16'` does not use a scaler"):
NativeMixedPrecision(precision="bf16", device=Mock(), scaler=Mock())
Expand All @@ -34,12 +32,6 @@ def test_native_amp_precision_scaler_with_bf16():
assert precision.scaler is None


@mock.patch("lightning_lite.plugins.precision.native_amp._TORCH_GREATER_EQUAL_1_10", False)
def test_native_amp_precision_bf16_min_torch():
with pytest.raises(ImportError, match="you must install torch greater or equal to 1.10"):
NativeMixedPrecision(precision="bf16", device=Mock())


def test_native_amp_precision_forward_context():
"""Test to ensure that the context manager correctly is set to bfloat16 on CPU and CUDA."""
precision = NativeMixedPrecision(precision=16, device="cuda")
10 changes: 0 additions & 10 deletions tests/tests_lite/test_connector.py
@@ -763,22 +763,13 @@ def test_ddp_fork_on_unsupported_platform(_, strategy):
_Connector(strategy=strategy)


@mock.patch("lightning_lite.plugins.precision.native_amp._TORCH_GREATER_EQUAL_1_10", True)
def test_precision_selection_16_on_cpu_warns():
with pytest.warns(
UserWarning, match=r"precision=16\)` but native AMP is not supported on CPU. Using `precision='bf16"
):
_Connector(precision=16)


@mock.patch("lightning_lite.plugins.precision.native_amp._TORCH_GREATER_EQUAL_1_10", False)
def test_precision_selection_16_raises_torch_version(monkeypatch):
with pytest.raises(ImportError, match="must install torch greater or equal to 1.10"):
_Connector(accelerator="cpu", precision=16)
with pytest.raises(ImportError, match="must install torch greater or equal to 1.10"):
_Connector(accelerator="cpu", precision="bf16")


class MyNativeAMP(NativeMixedPrecision):
pass

@@ -789,7 +780,6 @@
"is_custom_plugin,plugin_cls",
[(False, NativeMixedPrecision), (True, MyNativeAMP)],
)
@mock.patch("lightning_lite.plugins.precision.native_amp._TORCH_GREATER_EQUAL_1_10", True)
def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin_cls):
plugin = None
if is_custom_plugin:
3 changes: 1 addition & 2 deletions tests/tests_pytorch/helpers/runif.py
@@ -36,7 +36,6 @@
_IPU_AVAILABLE,
_OMEGACONF_AVAILABLE,
_PSUTIL_AVAILABLE,
_TORCH_GREATER_EQUAL_1_10,
_TORCH_QUANTIZE_AVAILABLE,
)
from tests_pytorch.helpers.datamodules import _SKLEARN_AVAILABLE
@@ -162,7 +161,7 @@ def __new__(

if bf16_cuda:
try:
cond = not (torch.cuda.is_available() and _TORCH_GREATER_EQUAL_1_10 and torch.cuda.is_bf16_supported())
cond = not (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
except (AssertionError, RuntimeError) as e:
# AssertionError: Torch not compiled with CUDA enabled
# RuntimeError: Found no NVIDIA driver on your system.
11 changes: 0 additions & 11 deletions tests/tests_pytorch/plugins/test_amp_plugins.py
@@ -258,17 +258,6 @@ def test_precision_selection_raises(monkeypatch):
):
Trainer(amp_backend="apex", precision=16)

import pytorch_lightning.plugins.precision.native_amp as amp

monkeypatch.setattr(amp, "_TORCH_GREATER_EQUAL_1_10", False)
with pytest.warns(
UserWarning, match=r"precision=16\)` but native AMP is not supported on CPU. Using `precision='bf16"
), pytest.raises(MisconfigurationException, match="must install torch greater or equal to 1.10"):
Trainer(precision=16)

with pytest.raises(MisconfigurationException, match="must install torch greater or equal to 1.10"):
Trainer(precision="bf16")

with pytest.raises(MisconfigurationException, match=r"amp_type='apex', precision='bf16'\)` but it's not supported"):
Trainer(amp_backend="apex", precision="bf16")
