Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed MPS device being unrecognized ([#13992](https://github.com/Lightning-AI/lightning/pull/13992))


- Fixed `DeepSpeedStrategy` and `IPUStrategy` incorrectly handling the unsupported `precision="mixed"` value ([#14041](https://github.com/Lightning-AI/lightning/pull/14041))


- Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051))


Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona

amp_level = amp_level or "O2"

supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
Expand Down
3 changes: 2 additions & 1 deletion src/pytorch_lightning/plugins/precision/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache
Expand All @@ -35,7 +36,7 @@ class IPUPrecisionPlugin(PrecisionPlugin):
"""

def __init__(self, precision: int) -> None:
supported_precision_values = (16, 32)
supported_precision_values = (PrecisionType.HALF, PrecisionType.FLOAT)
if precision not in supported_precision_values:
raise ValueError(
f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported."
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def _auto_select_batch_size(self) -> int:

def _format_precision_config(self) -> None:
assert isinstance(self.config, dict)
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
if self.precision_plugin.precision == PrecisionType.HALF:
if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
# FP16 is a DeepSpeed standalone AMP implementation
rank_zero_info("Enabling DeepSpeed FP16.")
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self.precision = precision

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
if self.precision in (PrecisionType.MIXED, PrecisionType.HALF):
if self.precision == PrecisionType.HALF:
inputs = self._move_float_tensors_to_half(inputs)

return super().forward(*inputs, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def on_colab_kaggle() -> bool:

def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
if torch.is_floating_point(tensor):
if precision in (PrecisionType.MIXED, PrecisionType.HALF):
if precision == PrecisionType.HALF:
return tensor.half()
if precision == PrecisionType.BFLOAT:
return tensor.bfloat16()
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/accelerators/test_ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_optimization(tmpdir):


@RunIf(ipu=True)
def test_mixed_precision(tmpdir):
def test_half_precision(tmpdir):
class TestCallback(Callback):
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
assert trainer.strategy.model.precision == 16
Expand Down
13 changes: 6 additions & 7 deletions tests/tests_pytorch/strategies/test_deepspeed_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,11 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config):

@RunIf(deepspeed=True)
@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1)
@pytest.mark.parametrize("precision", [16, "mixed"])
@pytest.mark.parametrize(
"amp_backend",
["native", pytest.param("apex", marks=RunIf(amp_apex=True))],
)
def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir):
def test_deepspeed_precision_choice(_, amp_backend, tmpdir):
"""Test to ensure precision plugin is also correctly chosen.

DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin
Expand All @@ -188,16 +187,16 @@ def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir):
accelerator="gpu",
strategy="deepspeed",
amp_backend=amp_backend,
precision=precision,
precision=16,
)

assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin)
assert trainer.strategy.precision_plugin.precision == precision
assert trainer.strategy.precision_plugin.precision == 16


@RunIf(deepspeed=True)
def test_deepspeed_with_invalid_config_path(tmpdir):
def test_deepspeed_with_invalid_config_path():
"""Test to ensure if we pass an invalid config path we throw an exception."""

with pytest.raises(
Expand All @@ -218,7 +217,7 @@ def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config):


@RunIf(deepspeed=True)
def test_deepspeed_defaults(tmpdir):
def test_deepspeed_defaults():
"""Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed."""
strategy = DeepSpeedStrategy()
assert strategy.config is not None
Expand Down Expand Up @@ -663,7 +662,7 @@ def training_step(self, batch, batch_idx):


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
def test_deepspeed_multigpu_stage_3(tmpdir):
"""Test to ensure ZeRO Stage 3 works with a parallel model."""
model = ModelParallelBoringModel()
trainer = Trainer(
Expand Down