Skip to content

Commit c569fb3

Browse files
awaelchli, carmocca, and rohitgr7
committed
Fix incorrect precision="mixed" being used with DeepSpeedStrategy and IPUStrategy (#14041)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 854fb31 commit c569fb3

File tree

8 files changed: 16 additions (+16) and 15 deletions (-15).

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1313
- Fixed the `NeptuneLogger` dependency being unrecognized ([#13988](https://github.com/Lightning-AI/lightning/pull/13988))
1414
- Fixed an issue where users would be warned about unset `max_epochs` even when `fast_dev_run` was set ([#13262](https://github.com/Lightning-AI/lightning/pull/13262))
1515
- Fixed MPS device being unrecognized ([#13992](https://github.com/Lightning-AI/lightning/pull/13992))
16+
- Fixed incorrect `precision="mixed"` being used with `DeepSpeedStrategy` and `IPUStrategy` ([#14041](https://github.com/Lightning-AI/lightning/pull/14041))
1617
- Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051))
1718
- Fixed a bug that caused `ddp_find_unused_parameters` to be set `False`, whereas the intended default is `True` ([#14095](https://github.com/Lightning-AI/lightning/pull/14095))
1819

@@ -21,8 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2122

2223
### Added
2324

24-
- Added ``ServableModule`` and its associated callback called ``ServableModuleValidator`` to ensure the model can served ([#13614](https://github.com/Lightning-AI/lightning/pull/13614))
25-
- Converted validation loop config warnings to `PossibleUserWarning` ([#13377](https://github.com/Lightning-AI/lightning/pull/13377))
25+
- Added ``ServableModule`` and its associated callback called ``ServableModuleValidator`` to ensure the model can be served ([#13614](https://github.com/Lightning-AI/lightning/pull/13614))
26+
- Converted validation loop config warnings to `PossibleUserWarning` ([#13377](https://github.com/Lightning-AI/lightning/pull/13377))
2627
- Added a flag named `log_rank_zero_only` to `EarlyStopping` to disable logging to non-zero rank processes ([#13233](https://github.com/Lightning-AI/lightning/pull/13233))
2728
- Added support for reloading the last checkpoint saved by passing `ckpt_path="last"` ([#12816](https://github.com/Lightning-AI/lightning/pull/12816))
2829
- Added `LightningDataModule.load_from_checkpoint` to support loading datamodules directly from checkpoint ([#12550](https://github.com/Lightning-AI/lightning/pull/12550))

src/pytorch_lightning/plugins/precision/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
6060

6161
amp_level = amp_level or "O2"
6262

63-
supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
63+
supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT)
6464
if precision not in supported_precision:
6565
raise ValueError(
6666
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."

src/pytorch_lightning/plugins/precision/ipu.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
import pytorch_lightning as pl
2020
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
2121
from pytorch_lightning.utilities import GradClipAlgorithmType
22+
from pytorch_lightning.utilities.enums import PrecisionType
2223
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2324
from pytorch_lightning.utilities.model_helpers import is_overridden
2425
from pytorch_lightning.utilities.warnings import WarningCache
@@ -35,7 +36,7 @@ class IPUPrecisionPlugin(PrecisionPlugin):
3536
"""
3637

3738
def __init__(self, precision: int) -> None:
38-
supported_precision_values = (16, 32)
39+
supported_precision_values = (PrecisionType.HALF, PrecisionType.FLOAT)
3940
if precision not in supported_precision_values:
4041
raise ValueError(
4142
f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported."

src/pytorch_lightning/strategies/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -695,7 +695,7 @@ def _auto_select_batch_size(self) -> int:
695695

696696
def _format_precision_config(self) -> None:
697697
assert isinstance(self.config, dict)
698-
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
698+
if self.precision_plugin.precision == PrecisionType.HALF:
699699
if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
700700
# FP16 is a DeepSpeed standalone AMP implementation
701701
rank_zero_info("Enabling DeepSpeed FP16.")

src/pytorch_lightning/strategies/ipu.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ def __init__(
5757
self.precision = precision
5858

5959
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
60-
if self.precision in (PrecisionType.MIXED, PrecisionType.HALF):
60+
if self.precision == PrecisionType.HALF:
6161
inputs = self._move_float_tensors_to_half(inputs)
6262

6363
return super().forward(*inputs, **kwargs)

src/pytorch_lightning/strategies/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ def on_colab_kaggle() -> bool:
2424

2525
def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
2626
if torch.is_floating_point(tensor):
27-
if precision in (PrecisionType.MIXED, PrecisionType.HALF):
27+
if precision == PrecisionType.HALF:
2828
return tensor.half()
2929
if precision == PrecisionType.BFLOAT:
3030
return tensor.bfloat16()

tests/tests_pytorch/accelerators/test_ipu.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def test_optimization(tmpdir):
185185

186186

187187
@RunIf(ipu=True)
188-
def test_mixed_precision(tmpdir):
188+
def test_half_precision(tmpdir):
189189
class TestCallback(Callback):
190190
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
191191
assert trainer.strategy.model.precision == 16

tests/tests_pytorch/strategies/test_deepspeed_strategy.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -172,12 +172,11 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config):
172172

173173
@RunIf(deepspeed=True)
174174
@mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1)
175-
@pytest.mark.parametrize("precision", [16, "mixed"])
176175
@pytest.mark.parametrize(
177176
"amp_backend",
178177
["native", pytest.param("apex", marks=RunIf(amp_apex=True))],
179178
)
180-
def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir):
179+
def test_deepspeed_precision_choice(_, amp_backend, tmpdir):
181180
"""Test to ensure precision plugin is also correctly chosen.
182181
183182
DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin
@@ -189,16 +188,16 @@ def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir):
189188
accelerator="gpu",
190189
strategy="deepspeed",
191190
amp_backend=amp_backend,
192-
precision=precision,
191+
precision=16,
193192
)
194193

195194
assert isinstance(trainer.strategy, DeepSpeedStrategy)
196195
assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin)
197-
assert trainer.strategy.precision_plugin.precision == precision
196+
assert trainer.strategy.precision_plugin.precision == 16
198197

199198

200199
@RunIf(deepspeed=True)
201-
def test_deepspeed_with_invalid_config_path(tmpdir):
200+
def test_deepspeed_with_invalid_config_path():
202201
"""Test to ensure if we pass an invalid config path we throw an exception."""
203202

204203
with pytest.raises(
@@ -219,7 +218,7 @@ def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config):
219218

220219

221220
@RunIf(deepspeed=True)
222-
def test_deepspeed_defaults(tmpdir):
221+
def test_deepspeed_defaults():
223222
"""Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed."""
224223
strategy = DeepSpeedStrategy()
225224
assert strategy.config is not None
@@ -664,7 +663,7 @@ def training_step(self, batch, batch_idx):
664663

665664

666665
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
667-
def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
666+
def test_deepspeed_multigpu_stage_3(tmpdir):
668667
"""Test to ensure ZeRO Stage 3 works with a parallel model."""
669668
model = ModelParallelBoringModel()
670669
trainer = Trainer(

0 commit comments

Comments
 (0)