diff --git a/src/lightning_fabric/connector.py b/src/lightning_fabric/connector.py index 9fb4fe09f6ec3..aa09d58f9d477 100644 --- a/src/lightning_fabric/connector.py +++ b/src/lightning_fabric/connector.py @@ -13,10 +13,10 @@ # limitations under the License. import os from collections import Counter -from typing import Any, Dict, List, Optional, Union +from typing import Any, cast, Dict, List, Optional, Union import torch -from typing_extensions import Literal +from typing_extensions import get_args from lightning_fabric.accelerators import ACCELERATOR_REGISTRY from lightning_fabric.accelerators.accelerator import Accelerator @@ -41,6 +41,7 @@ ) from lightning_fabric.plugins.precision.double import DoublePrecision from lightning_fabric.plugins.precision.fsdp import FSDPPrecision +from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR from lightning_fabric.strategies import ( DDPShardedStrategy, DDPStrategy, @@ -59,7 +60,6 @@ _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO] _PLUGIN_INPUT = Union[_PLUGIN, str] -_PRECISION_INPUT = Literal[16, 32, 64, "bf16"] class _Connector: @@ -113,14 +113,13 @@ def __init__( # Get registered strategies, built-in accelerators and precision plugins self._registered_strategies = STRATEGY_REGISTRY.available_strategies() self._registered_accelerators = ACCELERATOR_REGISTRY.available_accelerators() - self._precision_types = ("16", "32", "64", "bf16") # Raise an exception if there are conflicts between flags # Set each valid flag to `self._x_flag` after validation # For devices: Assign gpus, etc. to the accelerator flag and devices flag self._strategy_flag: Optional[Union[Strategy, str]] = None self._accelerator_flag: Optional[Union[Accelerator, str]] = None - self._precision_input: Optional[_PRECISION_INPUT] = None + self._precision_input: _PRECISION_INPUT_STR = "32" self._precision_instance: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None self._parallel_devices: List[Union[int, torch.device, str]] = [] @@ -206,12 +205,10 @@ def _check_config_and_set_final_flags( self._accelerator_flag = accelerator - if precision is not None: - if str(precision) not in self._precision_types: - raise ValueError( - f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}" - ) - self._precision_input = precision + supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + if precision not in supported_precision: + raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}") + self._precision_input = cast(_PRECISION_INPUT_STR, str(precision)) if plugins: plugins_flags_types: Dict[str, int] = Counter() @@ -442,10 +439,10 @@ def _check_and_init_precision(self) -> Precision: return self._precision_instance if isinstance(self.accelerator, TPUAccelerator): - if self._precision_input == 32: + if self._precision_input == "32": return TPUPrecision() - elif self._precision_input in (16, "bf16"): - if self._precision_input == 16: + elif self._precision_input in ("16", "bf16"): + if self._precision_input == "16": rank_zero_warn( "You passed `Fabric(accelerator='tpu', precision=16)` but AMP" " is not supported with TPUs. Using `precision='bf16'` instead." 
@@ -454,22 +451,22 @@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecision(self._precision_input) # type: ignore - if self._precision_input == 32: + if self._precision_input == "32": return Precision() - if self._precision_input == 64: + if self._precision_input == "64": return DoublePrecision() - if self._precision_input == 16 and self._accelerator_flag == "cpu": + if self._precision_input == "16" and self._accelerator_flag == "cpu": rank_zero_warn( "You passed `Fabric(accelerator='cpu', precision=16)` but native AMP is not supported on CPU." " Using `precision='bf16'` instead." ) self._precision_input = "bf16" - if self._precision_input in (16, "bf16"): + if self._precision_input in ("16", "bf16"): rank_zero_info( "Using 16-bit Automatic Mixed Precision (AMP)" - if self._precision_input == 16 + if self._precision_input == "16" else "Using bfloat16 Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" @@ -483,7 +480,7 @@ def _check_and_init_precision(self) -> Precision: def _validate_precision_choice(self) -> None: """Validate the combination of choices for precision, and accelerator.""" if isinstance(self.accelerator, TPUAccelerator): - if self._precision_input == 64: + if self._precision_input == "64": raise NotImplementedError( "`Fabric(accelerator='tpu', precision=64)` is not implemented." " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`" @@ -536,16 +533,12 @@ def _lazy_init_strategy(self) -> None: @staticmethod def _argument_from_env(name: str, current: Any, default: Any) -> Any: - env_value: Optional[Union[str, int]] = os.environ.get("LT_" + name.upper()) + env_value: Optional[str] = os.environ.get("LT_" + name.upper()) if env_value is None: return current - if name == "precision": - # TODO: support precision input as string, then this special handling is not needed - env_value = int(env_value) if env_value in ("16", "32", "64") else env_value - - if env_value is not None and env_value != current and current != default: + if env_value is not None and env_value != str(current) and str(current) != str(default): raise ValueError( f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value " f"`--{name}={current}` set through the CLI. " diff --git a/src/lightning_fabric/plugins/precision/deepspeed.py b/src/lightning_fabric/plugins/precision/deepspeed.py index 6cf1cf50ebc5a..68aa84b236a43 100644 --- a/src/lightning_fabric/plugins/precision/deepspeed.py +++ b/src/lightning_fabric/plugins/precision/deepspeed.py @@ -11,22 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING, Union import torch from lightning_utilities.core.imports import RequirementCache from torch import Tensor -from typing_extensions import Literal +from typing_extensions import get_args, Literal from lightning_fabric.plugins.precision.precision import Precision from lightning_fabric.plugins.precision.utils import _convert_fp_tensor -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.types import Steppable _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") if TYPE_CHECKING and _DEEPSPEED_AVAILABLE: import deepspeed +_PRECISION_INPUT_INT = Literal[32, 16] +_PRECISION_INPUT_STR = Literal["32", "16", "bf16"] +_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] + class DeepSpeedPrecision(Precision): """Precision plugin for DeepSpeed integration. @@ -39,19 +42,17 @@ class DeepSpeedPrecision(Precision): If unsupported ``precision`` is provided. """ - def __init__(self, precision: Literal[16, 32, "bf16"]) -> None: - supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT) + def __init__(self, precision: _PRECISION_INPUT) -> None: + supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) if precision not in supported_precision: raise ValueError( f"`precision={precision!r})` is not supported in DeepSpeed." - f" `precision` must be one of: {(x.value for x in supported_precision)}." + f" `precision` must be one of: {supported_precision}." ) - - super().__init__() - self.precision = precision + self.precision = cast(_PRECISION_INPUT_STR, str(precision)) def convert_input(self, data: Tensor) -> Tensor: - precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32} + precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16, "32": torch.float32} dst_type = precision_to_type[self.precision] return _convert_fp_tensor(data, dst_type) diff --git a/src/lightning_fabric/plugins/precision/double.py b/src/lightning_fabric/plugins/precision/double.py index 630f00b44182e..0bdc786a4fe51 100644 --- a/src/lightning_fabric/plugins/precision/double.py +++ b/src/lightning_fabric/plugins/precision/double.py @@ -17,6 +17,7 @@ import torch from torch import Tensor from torch.nn import Module +from typing_extensions import Literal from lightning_fabric.plugins.precision.precision import Precision from lightning_fabric.plugins.precision.utils import _convert_fp_tensor @@ -25,7 +26,7 @@ class DoublePrecision(Precision): """Plugin for training with double (``torch.float64``) precision.""" - precision: int = 64 + precision: Literal["64"] = "64" def convert_module(self, module: Module) -> Module: return module.double() diff --git a/src/lightning_fabric/plugins/precision/fsdp.py b/src/lightning_fabric/plugins/precision/fsdp.py index 392a83a1c3d92..3e8cb6346e230 100644 --- a/src/lightning_fabric/plugins/precision/fsdp.py +++ b/src/lightning_fabric/plugins/precision/fsdp.py @@ -17,7 +17,6 @@ from typing_extensions import Literal from lightning_fabric.plugins.precision.native_amp import MixedPrecision -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12 if TYPE_CHECKING: @@ -29,7 +28,7 @@ class FSDPPrecision(MixedPrecision): """AMP for Fully Sharded Data Parallel training.""" def __init__( - self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None + self, precision: Literal["16", 16, 
"bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None ) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise NotImplementedError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") @@ -39,16 +38,16 @@ def __init__( super().__init__( precision=precision, device=device, - scaler=(ShardedGradScaler() if scaler is None and precision == 16 else None), + scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None), ) @property def mixed_precision_config(self) -> "TorchMixedPrecision": from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision - if self.precision == PrecisionType.HALF: + if self.precision == "16": dtype = torch.float16 - elif self.precision == PrecisionType.BFLOAT: + elif self.precision == "bf16": dtype = torch.bfloat16 else: raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.") diff --git a/src/lightning_fabric/plugins/precision/native_amp.py b/src/lightning_fabric/plugins/precision/native_amp.py index 7ee3829f4d98d..e86a61c490a24 100644 --- a/src/lightning_fabric/plugins/precision/native_amp.py +++ b/src/lightning_fabric/plugins/precision/native_amp.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Dict, Generator, Optional +from typing import Any, cast, Dict, Generator, Optional import torch from torch import Tensor @@ -36,16 +36,15 @@ class MixedPrecision(Precision): """ def __init__( - self, precision: Literal[16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None + self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None ) -> None: - super().__init__() - if scaler is None and precision == 16: + self.precision = cast(Literal["16", "bf16"], str(precision)) + if scaler is None and self.precision == "16": with _patch_cuda_is_available(): # if possible, we defer CUDA initialization to support strategies that will attempt forks scaler = torch.cuda.amp.GradScaler() - if scaler is not None and precision == "bf16": + if scaler is not None and self.precision == "bf16": raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.") - self.precision = precision self.device = device self.scaler = scaler @@ -55,7 +54,7 @@ def forward_context(self) -> Generator[None, None, None]: yield def convert_input(self, data: Tensor) -> Tensor: - precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16} + precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16} dst_type = precision_to_type[self.precision] return _convert_fp_tensor(data, dst_type) diff --git a/src/lightning_fabric/plugins/precision/precision.py b/src/lightning_fabric/plugins/precision/precision.py index db355316895b7..2a18eeca3d750 100644 --- a/src/lightning_fabric/plugins/precision/precision.py +++ b/src/lightning_fabric/plugins/precision/precision.py @@ -18,10 +18,15 @@ from torch import Tensor from torch.nn import Module from torch.optim import Optimizer +from typing_extensions import Literal from lightning_fabric.plugins.precision.utils import _convert_fp_tensor from lightning_fabric.utilities.types import _PARAMETERS, Optimizable +_PRECISION_INPUT_INT = Literal[64, 32, 16] +_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"] +_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] + class Precision: """Base class for all plugins handling the precision-specific 
parts of the training. @@ -29,7 +34,7 @@ class Precision: The class attribute precision must be overwritten in child classes. The default value reflects fp32 training. """ - precision: Union[str, int] = 32 + precision: _PRECISION_INPUT_STR = "32" def convert_module(self, module: Module) -> Module: """Convert the module parameters to the precision type this plugin handles. diff --git a/src/lightning_fabric/plugins/precision/tpu_bf16.py b/src/lightning_fabric/plugins/precision/tpu_bf16.py index 84c4d7eeec8cb..36be3d4d2b96e 100644 --- a/src/lightning_fabric/plugins/precision/tpu_bf16.py +++ b/src/lightning_fabric/plugins/precision/tpu_bf16.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from lightning_fabric.plugins.precision import TPUPrecision from lightning_fabric.plugins.precision.utils import _convert_fp_tensor @@ -23,7 +24,7 @@ class TPUBf16Precision(TPUPrecision): """Plugin that enables bfloats on TPUs.""" - precision: str = "bf16" + precision: Literal["bf16"] = "bf16" def __init__(self) -> None: super().__init__() diff --git a/src/lightning_fabric/strategies/deepspeed.py b/src/lightning_fabric/strategies/deepspeed.py index 3b3379c7302a1..3faaacf6d2fca 100644 --- a/src/lightning_fabric/strategies/deepspeed.py +++ b/src/lightning_fabric/strategies/deepspeed.py @@ -31,7 +31,6 @@ from lightning_fabric.strategies.ddp import DDPStrategy from lightning_fabric.strategies.strategy import _Sharded from lightning_fabric.utilities.distributed import log -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only from lightning_fabric.utilities.seed import reset_seed from lightning_fabric.utilities.types import _PATH @@ -349,9 +348,9 @@ def module_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: assert self._config_initialized - if self.precision.precision == PrecisionType.HALF: + if self.precision.precision == "16": dtype = torch.float16 - elif self.precision.precision == PrecisionType.BFLOAT: + elif self.precision.precision == "bf16": dtype = torch.bfloat16 else: dtype = torch.float32 @@ -499,7 +498,7 @@ def _format_config(self) -> None: def _format_precision_config(self) -> None: assert isinstance(self.config, dict) - if self.precision.precision == PrecisionType.HALF: + if self.precision.precision == "16": if "fp16" not in self.config: # FP16 is a DeepSpeed standalone AMP implementation rank_zero_info("Enabling DeepSpeed FP16.") @@ -511,7 +510,7 @@ def _format_precision_config(self) -> None: "hysteresis": self.hysteresis, "min_loss_scale": self.min_loss_scale, } - elif "bf16" not in self.config and self.precision.precision == PrecisionType.BFLOAT: + elif "bf16" not in self.config and self.precision.precision == "bf16": rank_zero_info("Enabling DeepSpeed BF16.") self.config["bf16"] = {"enabled": True} diff --git a/src/lightning_fabric/strategies/fairscale.py b/src/lightning_fabric/strategies/fairscale.py index e78be8ec83083..a31f1f6e9c1ea 100644 --- a/src/lightning_fabric/strategies/fairscale.py +++ b/src/lightning_fabric/strategies/fairscale.py @@ -26,7 +26,6 @@ from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout from lightning_fabric.strategies.ddp import DDPStrategy from lightning_fabric.strategies.strategy import _BackwardSyncControl -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.imports import _IS_WINDOWS _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and 
module_available("fairscale.nn") @@ -116,7 +115,7 @@ def _reinit_optimizers_with_oss(optimizers: List[Optimizer], precision: Precisio if not isinstance(optimizer, OSS): optim_class = type(optimizer) zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) - is_fp16 = precision.precision in (PrecisionType.MIXED, PrecisionType.HALF) + is_fp16 = precision.precision == "16" # For multi-node training, compressing the model shards in fp16 before broadcasting # improves performance. When using PyTorch AMP, it will not degrade # the model performance. diff --git a/src/lightning_fabric/utilities/enums.py b/src/lightning_fabric/utilities/enums.py index cd8a3dd5bd062..e98796b7c6702 100644 --- a/src/lightning_fabric/utilities/enums.py +++ b/src/lightning_fabric/utilities/enums.py @@ -29,24 +29,6 @@ class LightningEnum(StrEnum, Enum): LightningEnum = StrEnum -class PrecisionType(LightningEnum): - """Type of precision used.""" - - HALF = "16" - FLOAT = "32" - FULL = "64" - BFLOAT = "bf16" - MIXED = "mixed" - - @staticmethod - def supported_type(precision: str | int) -> bool: - return any(x == precision for x in PrecisionType) - - @staticmethod - def supported_types() -> list[str]: - return [x.value for x in PrecisionType] - - class _StrategyType(LightningEnum): """Define type of training strategy.""" diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index e26bd60a5c069..23dda6a4da78e 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -292,7 +292,9 @@ def _to_lite_precision(plugin: Optional[PLPrecisionPlugin]) -> LitePrecision: return LiteDoublePrecision() if type(plugin) is PLDeepSpeedPrecisionPlugin: - return LiteDeepSpeedPrecision(precision=plugin.precision) # type: ignore[arg-type] + return LiteDeepSpeedPrecision( + precision=plugin.precision, # type: ignore[arg-type] + ) if type(plugin) is PLTPUPrecisionPlugin: return LiteTPUPrecision() diff --git a/src/pytorch_lightning/plugins/precision/colossalai.py b/src/pytorch_lightning/plugins/precision/colossalai.py index 6643c8f7f803a..4e177b927629f 100644 --- a/src/pytorch_lightning/plugins/precision/colossalai.py +++ b/src/pytorch_lightning/plugins/precision/colossalai.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, cast, Optional, Union from torch import Tensor from torch.optim import Optimizer +from typing_extensions import Literal import pytorch_lightning as pl from lightning_fabric.utilities.types import Steppable from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin -from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.rank_zero import WarningCache warning_cache = WarningCache() @@ -36,14 +36,13 @@ class ColossalAIPrecisionPlugin(PrecisionPlugin): If precison is not 16. """ - def __init__(self, precision: Union[str, int] = 16) -> None: - if not (precision == PrecisionType.HALF): + def __init__(self, precision: Literal["16", 16] = 16) -> None: + if precision not in ("16", 16): raise ValueError( f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported." " Consider setting `precision=16`." 
) - super().__init__() - self.precision = precision + self.precision = cast(Literal["16"], str(precision)) def backward( # type: ignore[override] self, diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py index 741220d1c4962..a9438886b9c57 100644 --- a/src/pytorch_lightning/plugins/precision/deepspeed.py +++ b/src/pytorch_lightning/plugins/precision/deepspeed.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.optim import LBFGS, Optimizer +from typing_extensions import get_args, Literal import pytorch_lightning as pl -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.types import Steppable from pytorch_lightning.plugins.precision.apex_amp import _APEX_AVAILABLE from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin @@ -33,19 +33,26 @@ warning_cache = WarningCache() +_PRECISION_INPUT_INT = Literal[32, 16] +_PRECISION_INPUT_STR = Literal["32", "16", "bf16"] +_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] + class DeepSpeedPrecisionPlugin(PrecisionPlugin): """Precision plugin for DeepSpeed integration. Args: - precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16). + precision: Full precision (32), half precision (16) or bfloat16 precision (bf16). Raises: ValueError: If unsupported ``precision`` is provided. """ def __init__( - self, precision: Union[str, int], amp_type: Optional[str] = None, amp_level: Optional[str] = None + self, + precision: Literal["32", 32, "16", 16, "bf16"], + amp_type: Optional[str] = None, + amp_level: Optional[str] = None, ) -> None: if amp_type == "apex": # TODO: remove in v1.10.0 @@ -73,15 +80,13 @@ def __init__( f" in v1.10.0. This argument is no longer necessary." ) - supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT) + supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) if precision not in supported_precision: raise ValueError( f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported." - f" `precision` must be one of: {(x.value for x in supported_precision)}." + f" `precision` must be one of: {supported_precision}." 
) - - super().__init__() - self.precision = precision + self.precision = cast(_PRECISION_INPUT_STR, str(precision)) self.amp_type = amp_type self.amp_level = amp_level diff --git a/src/pytorch_lightning/plugins/precision/double.py b/src/pytorch_lightning/plugins/precision/double.py index 6b52c83b88781..fb519912dcdc0 100644 --- a/src/pytorch_lightning/plugins/precision/double.py +++ b/src/pytorch_lightning/plugins/precision/double.py @@ -19,6 +19,7 @@ from lightning_utilities.core.apply_func import apply_to_collection from torch import FloatTensor, Tensor from torch.optim import Optimizer +from typing_extensions import Literal import pytorch_lightning as pl from lightning_fabric.plugins.precision.utils import _convert_fp_tensor @@ -72,7 +73,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class DoublePrecisionPlugin(PrecisionPlugin): """Plugin for training with double (``torch.float64``) precision.""" - precision: int = 64 + precision: Literal["64"] = "64" def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] diff --git a/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py b/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py index 875da77ddd35b..5489aec2679a0 100644 --- a/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Optional import torch +from typing_extensions import Literal -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.plugins.precision.native_amp import MixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -31,12 +31,16 @@ class FullyShardedNativeNativeMixedPrecisionPlugin(MixedPrecisionPlugin): """Native AMP for Fully Sharded Native Training.""" - def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None: + def __init__( + self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None + ) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise MisconfigurationException( "`FullyShardedNativeNativeMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards." 
) - super().__init__(precision, device, scaler=ShardedGradScaler() if scaler is None and precision == 16 else None) + super().__init__( + precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None) + ) def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ @@ -51,9 +55,9 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: @property def mixed_precision_config(self) -> Optional[MixedPrecision]: assert MixedPrecision is not None - if self.precision == PrecisionType.HALF: + if self.precision == "16": dtype = torch.float16 - elif self.precision == PrecisionType.BFLOAT: + elif self.precision == "bf16": dtype = torch.bfloat16 else: raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.") diff --git a/src/pytorch_lightning/plugins/precision/hpu.py b/src/pytorch_lightning/plugins/precision/hpu.py index 61fc078475460..0fbab4efefc4d 100644 --- a/src/pytorch_lightning/plugins/precision/hpu.py +++ b/src/pytorch_lightning/plugins/precision/hpu.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import cast, Optional, Union + +from typing_extensions import get_args, Literal -from lightning_fabric.utilities.enums import PrecisionType from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _HPU_AVAILABLE @@ -21,6 +22,10 @@ if _HPU_AVAILABLE: from habana_frameworks.torch.hpex import hmp +_PRECISION_INPUT_INT = Literal[32, 16] +_PRECISION_INPUT_STR = Literal["32", "16", "bf16"] +_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] + class HPUPrecisionPlugin(PrecisionPlugin): """Plugin that enables bfloat/half support on HPUs. @@ -35,7 +40,7 @@ class HPUPrecisionPlugin(PrecisionPlugin): def __init__( self, - precision: Union[str, int], + precision: _PRECISION_INPUT, opt_level: str = "O2", bf16_file_path: Optional[str] = None, fp32_file_path: Optional[str] = None, @@ -43,15 +48,14 @@ def __init__( ) -> None: if not _HPU_AVAILABLE: raise MisconfigurationException("HPU precision plugin requires HPU devices.") - supported_precision_values = (16, 32, "bf16") - if precision not in supported_precision_values: + supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + if precision not in supported_precision: raise ValueError( f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported." - f" `precision` must be one of: {supported_precision_values}." + f" `precision` must be one of: {supported_precision}." 
) - super().__init__() - self.precision = precision - if precision in (PrecisionType.HALF, PrecisionType.BFLOAT): + self.precision = cast(_PRECISION_INPUT_STR, str(precision)) + if self.precision in ("16", "bf16"): hmp.convert( opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose ) diff --git a/src/pytorch_lightning/plugins/precision/ipu.py b/src/pytorch_lightning/plugins/precision/ipu.py index 30e5947413af3..aec99b28e17aa 100644 --- a/src/pytorch_lightning/plugins/precision/ipu.py +++ b/src/pytorch_lightning/plugins/precision/ipu.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Union +from typing import Any, Callable, cast, Union from torch import Tensor from torch.optim import LBFGS, Optimizer +from typing_extensions import get_args, Literal import pytorch_lightning as pl -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.types import Optimizable from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType @@ -27,6 +27,10 @@ warning_cache = WarningCache() +_PRECISION_INPUT_INT = Literal[32, 16] +_PRECISION_INPUT_STR = Literal["32", "16"] +_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] + class IPUPrecisionPlugin(PrecisionPlugin): """Precision plugin for IPU integration. @@ -36,15 +40,14 @@ class IPUPrecisionPlugin(PrecisionPlugin): If the precision is neither 16 nor 32. """ - def __init__(self, precision: int) -> None: - supported_precision_values = (PrecisionType.HALF, PrecisionType.FLOAT) - if precision not in supported_precision_values: + def __init__(self, precision: Literal["32", 32, "16", 16]) -> None: + supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + if precision not in supported_precision: raise ValueError( f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported." - f" `precision` must be one of: {supported_precision_values}." + f" `precision` must be one of: {supported_precision}." ) - super().__init__() - self.precision = precision + self.precision = cast(_PRECISION_INPUT_STR, str(precision)) def backward( # type: ignore[override] self, diff --git a/src/pytorch_lightning/plugins/precision/native_amp.py b/src/pytorch_lightning/plugins/precision/native_amp.py index 7d104e5affa91..92dc323a29238 100644 --- a/src/pytorch_lightning/plugins/precision/native_amp.py +++ b/src/pytorch_lightning/plugins/precision/native_amp.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, Optional, Union +from typing import Any, Callable, cast, Dict, Generator, Optional, Union import torch from torch import Tensor from torch.optim import LBFGS, Optimizer +from typing_extensions import Literal import pytorch_lightning as pl from lightning_fabric.accelerators.cuda import _patch_cuda_is_available @@ -37,16 +38,15 @@ class MixedPrecisionPlugin(PrecisionPlugin): """ def __init__( - self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None + self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None ) -> None: - super().__init__() - if scaler is None and precision == 16: + self.precision = cast(Literal["16", "bf16"], str(precision)) + if scaler is None and self.precision == "16": with _patch_cuda_is_available(): # if possible, we defer CUDA initialization to support strategies that will attempt forks scaler = torch.cuda.amp.GradScaler() - if scaler is not None and precision == "bf16": + if scaler is not None and self.precision == "bf16": raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.") - self.precision = precision self.device = device self.scaler = scaler diff --git a/src/pytorch_lightning/plugins/precision/precision_plugin.py b/src/pytorch_lightning/plugins/precision/precision_plugin.py index a259de0942932..ab4447011ca4e 100644 --- a/src/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/src/pytorch_lightning/plugins/precision/precision_plugin.py @@ -33,8 +33,6 @@ class PrecisionPlugin(LitePrecision, CheckpointHooks): The class attribute precision must be overwritten in child classes. The default value reflects fp32 training. """ - precision: Union[str, int] = 32 - def connect( self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any] ) -> Tuple[Module, List[Optimizer], List[Any]]: diff --git a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py index f5d2e656fdbc5..785a80a86bc0d 100644 --- a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -13,6 +13,8 @@ # limitations under the License. from typing import Optional, Union +from typing_extensions import Literal + from lightning_fabric.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.precision.native_amp import MixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -27,13 +29,17 @@ class ShardedNativeMixedPrecisionPlugin(MixedPrecisionPlugin): """Native AMP for Sharded Training.""" - def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None: + def __init__( + self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None + ) -> None: if not _FAIRSCALE_AVAILABLE: raise MisconfigurationException( "You have asked for sharded AMP but you have not installed it." 
" Install `fairscale` using this guide: https://https://github.com/facebookresearch/fairscale" ) - super().__init__(precision, device, scaler=ShardedGradScaler() if scaler is None and precision == 16 else None) + super().__init__( + precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None) + ) def clip_grad_by_norm(self, optimizer: "OSS", clip_val: Union[int, float]) -> None: optimizer.clip_grad_norm(clip_val) diff --git a/src/pytorch_lightning/plugins/precision/tpu_bf16.py b/src/pytorch_lightning/plugins/precision/tpu_bf16.py index 94254313b85be..3a765b7e8d3d5 100644 --- a/src/pytorch_lightning/plugins/precision/tpu_bf16.py +++ b/src/pytorch_lightning/plugins/precision/tpu_bf16.py @@ -16,6 +16,7 @@ import torch.nn as nn from torch.optim import Optimizer +from typing_extensions import Literal from pytorch_lightning.plugins.precision import TPUPrecisionPlugin @@ -23,7 +24,7 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin): """Plugin that enables bfloats on TPUs.""" - precision: str = "bf16" + precision: Literal["bf16"] = "bf16" def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] diff --git a/src/pytorch_lightning/strategies/colossalai.py b/src/pytorch_lightning/strategies/colossalai.py index de91655743fb5..caed48bb92cae 100644 --- a/src/pytorch_lightning/strategies/colossalai.py +++ b/src/pytorch_lightning/strategies/colossalai.py @@ -32,7 +32,6 @@ from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -326,7 +325,7 @@ def setup_precision_plugin(self) -> None: def setup(self, trainer: "pl.Trainer") -> None: precision = self.precision_plugin.precision - if not (precision == PrecisionType.HALF): + if precision != "16": raise ValueError( f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported." " Consider setting `precision=16`." 
diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 50ad4d3245173..0b4b67244f9ec 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -30,7 +30,6 @@ import pytorch_lightning as pl from lightning_fabric.plugins import ClusterEnvironment -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.optimizer import _optimizers_to_device from lightning_fabric.utilities.seed import reset_seed from lightning_fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau @@ -515,9 +514,9 @@ def model_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: assert self._config_initialized - if self.precision_plugin.precision == PrecisionType.HALF: + if self.precision_plugin.precision == "16": dtype = torch.float16 - elif self.precision_plugin.precision == PrecisionType.BFLOAT: + elif self.precision_plugin.precision == "bf16": dtype = torch.bfloat16 else: dtype = torch.float32 @@ -652,7 +651,7 @@ def _auto_select_batch_size(self) -> int: def _format_precision_config(self) -> None: assert isinstance(self.config, dict) - if self.precision_plugin.precision == PrecisionType.HALF: + if self.precision_plugin.precision == "16": if "fp16" not in self.config and self.precision_plugin.amp_type == "native": # FP16 is a DeepSpeed standalone AMP implementation rank_zero_info("Enabling DeepSpeed FP16.") @@ -667,7 +666,7 @@ def _format_precision_config(self) -> None: elif "amp" not in self.config and self.precision_plugin.amp_type == "apex": rank_zero_info("Enabling DeepSpeed APEX Implementation.") self.config["amp"] = {"enabled": True, "opt_level": self.precision_plugin.amp_level} - elif "bf16" not in self.config and self.precision_plugin.precision == PrecisionType.BFLOAT: + elif "bf16" not in self.config and self.precision_plugin.precision == "bf16": rank_zero_info("Enabling DeepSpeed BF16.") self.config["bf16"] = {"enabled": True} diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index 5570ee2cf9e30..d876b3523002e 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -20,7 +20,6 @@ import pytorch_lightning as pl from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment from lightning_fabric.strategies.fairscale import _FAIRSCALE_AVAILABLE, _optimizer_has_flat_params -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.optimizer import _optimizers_to_device from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -217,7 +216,7 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel: cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, flatten_parameters=self.flatten_parameters, - mixed_precision=(self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)), + mixed_precision=(self.precision_plugin.precision == "16"), reshard_after_forward=self.reshard_after_forward, fp32_reduce_scatter=self.fp32_reduce_scatter, compute_dtype=self.compute_dtype, @@ -240,7 +239,7 @@ def wrap_policy(*args: Any, **kwargs: Any) -> Any: cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, flatten_parameters=self.flatten_parameters, - mixed_precision=(precision in (PrecisionType.MIXED, PrecisionType.HALF)), + 
mixed_precision=(precision == "16"), reshard_after_forward=self.reshard_after_forward, fp32_reduce_scatter=self.fp32_reduce_scatter, compute_dtype=self.compute_dtype, diff --git a/src/pytorch_lightning/strategies/hivemind.py b/src/pytorch_lightning/strategies/hivemind.py index 7fa9f0e63a868..e3c9e9f6499d0 100644 --- a/src/pytorch_lightning/strategies/hivemind.py +++ b/src/pytorch_lightning/strategies/hivemind.py @@ -8,7 +8,6 @@ from torch import Tensor import pytorch_lightning as pl -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.types import LRScheduler, ReduceLROnPlateau from pytorch_lightning.strategies.strategy import Strategy, TBroadcast from pytorch_lightning.utilities.data import extract_batch_size @@ -193,7 +192,7 @@ def is_global_zero(self) -> bool: def setup(self, trainer: "pl.Trainer") -> None: self.model_to_device() super().setup(trainer) - if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED): + if self.precision_plugin.precision == "16": self.precision_plugin.scaler = hivemind.GradScaler() def _initialize_hivemind(self) -> None: diff --git a/src/pytorch_lightning/strategies/utils.py b/src/pytorch_lightning/strategies/utils.py index 643d1aeb1b3ae..46bb385260591 100644 --- a/src/pytorch_lightning/strategies/utils.py +++ b/src/pytorch_lightning/strategies/utils.py @@ -17,10 +17,10 @@ import torch from torch import Tensor +from typing_extensions import Literal from lightning_fabric.plugins.precision.utils import _convert_fp_tensor from lightning_fabric.strategies import _StrategyRegistry -from lightning_fabric.utilities.enums import PrecisionType from lightning_fabric.utilities.registry import _is_register_method_overridden from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation @@ -41,9 +41,9 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) -> mod.register_strategies(registry) -def _fp_to_half(tensor: Tensor, precision: PrecisionType) -> Tensor: - if precision == PrecisionType.HALF: +def _fp_to_half(tensor: Tensor, precision: Literal["64", 64, "32", 32, "16", 16, "bf16"]) -> Tensor: + if str(precision) == "16": return _convert_fp_tensor(tensor, torch.half) - if precision == PrecisionType.BFLOAT: + if precision == "bf16": return _convert_fp_tensor(tensor, torch.bfloat16) return tensor diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index 3a4654a64b5c5..a616680d89953 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -15,10 +15,10 @@ import logging import os from collections import Counter -from typing import Dict, List, Optional, Union +from typing import cast, Dict, List, Optional, Union import torch -from typing_extensions import Literal +from typing_extensions import get_args, Literal from lightning_fabric.plugins.environments import ( ClusterEnvironment, @@ -90,6 +90,9 @@ import horovod.torch as hvd _LITERAL_WARN = Literal["warn"] +_PRECISION_INPUT_INT = Literal[64, 32, 16] +_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"] +_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] class AcceleratorConnector: @@ -100,7 +103,7 @@ def __init__( accelerator: Optional[Union[str, Accelerator]] = None, strategy: Optional[Union[str, Strategy]] = None, plugins: Optional[Union[PLUGIN_INPUT, 
List[PLUGIN_INPUT]]] = None, - precision: Union[int, str] = 32, + precision: _PRECISION_INPUT = 32, amp_type: Optional[str] = None, amp_level: Optional[str] = None, sync_batchnorm: bool = False, @@ -162,14 +165,13 @@ def __init__( # Get registered strategies, built-in accelerators and precision plugins self._registered_strategies = StrategyRegistry.available_strategies() self._accelerator_types = AcceleratorRegistry.available_accelerators() - self._precision_types = ("16", "32", "64", "bf16") # Raise an exception if there are conflicts between flags # Set each valid flag to `self._x_flag` after validation # For devices: Assign gpus, ipus, etc. to the accelerator flag and devices flag self._strategy_flag: Optional[Union[Strategy, str]] = None self._accelerator_flag: Optional[Union[Accelerator, str]] = None - self._precision_flag: Optional[Union[int, str]] = None + self._precision_flag: _PRECISION_INPUT_STR = "32" self._precision_plugin_flag: Optional[PrecisionPlugin] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None self._parallel_devices: List[Union[int, torch.device, str]] = [] @@ -236,7 +238,7 @@ def _check_config_and_set_final_flags( self, strategy: Optional[Union[str, Strategy]], accelerator: Optional[Union[str, Accelerator]], - precision: Union[int, str], + precision: _PRECISION_INPUT, plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]], amp_type: Optional[str], amp_level: Optional[str], @@ -285,12 +287,12 @@ def _check_config_and_set_final_flags( self._accelerator_flag = accelerator - if precision is not None: - if str(precision) not in self._precision_types: - raise MisconfigurationException( - f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}" - ) - self._precision_flag = precision + supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + if precision not in supported_precision: + raise MisconfigurationException( + f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}" + ) + self._precision_flag = cast(_PRECISION_INPUT_STR, str(precision)) if plugins: plugins_flags_types: Dict[str, int] = Counter() @@ -697,10 +699,10 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if isinstance(self.accelerator, HPUAccelerator): return HPUPrecisionPlugin(self._precision_flag) # type: ignore if isinstance(self.accelerator, TPUAccelerator): - if self._precision_flag == 32: + if self._precision_flag == "32": return TPUPrecisionPlugin() - elif self._precision_flag in (16, "bf16"): - if self._precision_flag == 16: + elif self._precision_flag in ("16", "bf16"): + if self._precision_flag == "16": rank_zero_warn( "You passed `Trainer(accelerator='tpu', precision=16)` but AMP" " is not supported with TPUs. Using `precision='bf16'` instead." @@ -712,22 +714,22 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecisionPlugin(self._precision_flag, self._amp_type_flag, self._amp_level_flag) - if self._precision_flag == 32: + if self._precision_flag == "32": return PrecisionPlugin() - if self._precision_flag == 64: + if self._precision_flag == "64": return DoublePrecisionPlugin() - if self._precision_flag == 16 and self._accelerator_flag == "cpu": + if self._precision_flag == "16" and self._accelerator_flag == "cpu": rank_zero_warn( "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU." " Using `precision='bf16'` instead." 
) self._precision_flag = "bf16" - if self._precision_flag in (16, "bf16"): + if self._precision_flag in ("16", "bf16"): rank_zero_info( f"Using 16bit {self._amp_type_flag} Automatic Mixed Precision (AMP)" - if self._precision_flag == 16 + if self._precision_flag == "16" else "Using bfloat16 Automatic Mixed Precision (AMP)" ) @@ -751,7 +753,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin: def _validate_precision_choice(self) -> None: """Validate the combination of choices for precision, AMP type, and accelerator.""" if isinstance(self.accelerator, TPUAccelerator): - if self._precision_flag == 64: + if self._precision_flag == "64": raise MisconfigurationException( "`Trainer(accelerator='tpu', precision=64)` is not implemented." " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`" @@ -765,12 +767,12 @@ def _validate_precision_choice(self) -> None: f" found: {self._precision_plugin_flag}." ) if isinstance(self.accelerator, HPUAccelerator): - if self._precision_flag not in (16, "bf16", 32): + if self._precision_flag not in ("16", "bf16", "32"): raise MisconfigurationException( f"`Trainer(accelerator='hpu', precision={self._precision_flag!r})` is not supported." ) if ( - self._precision_flag == 16 + self._precision_flag == "16" and isinstance(self.accelerator, CPUAccelerator) and self._amp_type_flag == "apex" ): @@ -778,7 +780,7 @@ def _validate_precision_choice(self) -> None: "You passed `Trainer(accelerator='cpu', precision=16, amp_type='apex')`" " but apex AMP not supported on CPU." ) - if self._precision_flag in (16, "bf16") and self._amp_type_flag == "apex": + if self._precision_flag in ("16", "bf16") and self._amp_type_flag == "apex": if self._precision_flag == "bf16": raise MisconfigurationException( "You passed `Trainer(amp_type='apex', precision='bf16')` but it's not supported." 
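
Once the connector has normalized the flag, the strategy-side hunks (DeepSpeed, FSDP, the sharded plugins) only ever compare against the string literals and translate them into a torch dtype. A hedged sketch of that mapping, assuming the normalized string values; `precision_to_dtype` is an illustrative helper, not an API introduced by the PR.

    # Sketch of how strategies map the canonical string precision to a torch dtype
    # after normalization to "64" / "32" / "16" / "bf16".
    import torch
    from typing_extensions import Literal


    def precision_to_dtype(precision: Literal["64", "32", "16", "bf16"]) -> torch.dtype:
        mapping = {
            "64": torch.float64,
            "32": torch.float32,
            "16": torch.float16,
            "bf16": torch.bfloat16,
        }
        try:
            return mapping[precision]
        except KeyError:
            raise ValueError(f"Was unable to infer precision type, received {precision!r}.") from None


    assert precision_to_dtype("bf16") is torch.bfloat16
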
diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 1dae21db02550..9d405e5fb136e 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -149,7 +149,8 @@ def attach_data( def _copy_trainer_model_properties(self, model: "pl.LightningModule") -> None: model.trainer = proxy(self.trainer) - model.precision = self.trainer.precision + # for backward compatibility + model.precision = int(self.trainer.precision) if self.trainer.precision != "bf16" else "bf16" def attach_dataloaders( self, diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 591a584d6694d..00e070d36e33d 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -69,7 +69,12 @@ ) from pytorch_lightning.trainer import call, setup from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations -from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector +from pytorch_lightning.trainer.connectors.accelerator_connector import ( + _LITERAL_WARN, + _PRECISION_INPUT, + _PRECISION_INPUT_STR, + AcceleratorConnector, +) from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector @@ -146,7 +151,7 @@ def __init__( accelerator: Optional[Union[str, Accelerator]] = None, strategy: Optional[Union[str, Strategy]] = None, sync_batchnorm: bool = False, - precision: Union[int, str] = 32, + precision: _PRECISION_INPUT = 32, enable_model_summary: bool = True, num_sanity_val_steps: int = 2, resume_from_checkpoint: Optional[Union[Path, str]] = None, @@ -1768,7 +1773,7 @@ def amp_backend(self) -> Optional[str]: return None @property - def precision(self) -> Union[str, int]: + def precision(self) -> _PRECISION_INPUT_STR: return self.strategy.precision_plugin.precision @property diff --git a/src/pytorch_lightning/utilities/argparse.py b/src/pytorch_lightning/utilities/argparse.py index 8b1872ee7b643..a4c050b77ef20 100644 --- a/src/pytorch_lightning/utilities/argparse.py +++ b/src/pytorch_lightning/utilities/argparse.py @@ -140,7 +140,15 @@ def get_init_arguments_and_types(cls: _ARGPARSE_CLS) -> List[Tuple[str, Tuple, A arg_type = cls_default_params[arg].annotation arg_default = cls_default_params[arg].default try: - arg_types = tuple(arg_type.__args__) + if type(arg_type).__name__ == "_LiteralGenericAlias": + # Special case: Literal[a, b, c, ...] + arg_types = tuple({type(a) for a in arg_type.__args__}) + elif "typing.Literal" in str(arg_type) or "typing_extensions.Literal" in str(arg_type): + # Special case: Union[Literal, ...] + arg_types = tuple({type(a) for union_args in arg_type.__args__ for a in union_args.__args__}) + else: + # Special case: ComposedType[type0, type1, ...] 
+ arg_types = tuple(arg_type.__args__) except (AttributeError, TypeError): arg_types = (arg_type,) diff --git a/src/pytorch_lightning/utilities/enums.py b/src/pytorch_lightning/utilities/enums.py index 460735c5d0a58..e6e473988d070 100644 --- a/src/pytorch_lightning/utilities/enums.py +++ b/src/pytorch_lightning/utilities/enums.py @@ -18,7 +18,7 @@ from enum import Enum, EnumMeta from typing import Any -from lightning_fabric.utilities.enums import LightningEnum, PrecisionType # noqa: F401 +from lightning_fabric.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation @@ -48,6 +48,30 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: return obj +class PrecisionType(LightningEnum, metaclass=_DeprecatedEnumMeta): + """Type of precision used.""" + + HALF = "16" + FLOAT = "32" + FULL = "64" + BFLOAT = "bf16" + MIXED = "mixed" + + @staticmethod + def supported_type(precision: str | int) -> bool: + return any(x == precision for x in PrecisionType) + + @staticmethod + def supported_types() -> list[str]: + return [x.value for x in PrecisionType] + + def deprecate(self) -> None: + rank_zero_deprecation( + f"The `{type(self).__name__}` enum has been deprecated in v1.9.0 and will be removed in v1.10.0." + f" Use the string value `{self.value!r}` instead." + ) + + class AMPType(LightningEnum, metaclass=_DeprecatedEnumMeta): """Type of Automatic Mixed Precision used for training.""" diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py index 10ceb937d0497..75c04fe2e9d6f 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py @@ -21,7 +21,7 @@ @RunIf(deepspeed=True) @pytest.mark.parametrize("precision", ["bf16", 16, 32]) -def test_deepspeed_precision_choice(precision, tmpdir): +def test_deepspeed_precision_choice(precision): """Test to ensure precision plugin is correctly chosen. DeepSpeed handles precision via custom DeepSpeedPrecision. 
@@ -34,4 +34,4 @@ def test_deepspeed_precision_choice(precision, tmpdir): assert isinstance(connector.strategy, DeepSpeedStrategy) assert isinstance(connector.strategy.precision, DeepSpeedPrecision) - assert connector.strategy.precision.precision == precision + assert connector.strategy.precision.precision == str(precision) diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index 99249bd247305..fc3eefd719eb5 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -47,7 +47,7 @@ def _step(lite, model, batch): assert isinstance(forward_module, FullyShardedDataParallel) assert isinstance(lite._precision, FSDPPrecision) - precision = torch.float16 if lite._precision.precision == 16 else torch.bfloat16 + precision = torch.float16 if lite._precision.precision == "16" else torch.bfloat16 assert forward_module.mixed_precision.param_dtype == precision assert forward_module.mixed_precision.reduce_dtype == precision assert forward_module.mixed_precision.buffer_dtype == precision diff --git a/tests/tests_fabric/utilities/test_enums.py b/tests/tests_fabric/utilities/test_enums.py deleted file mode 100644 index 30f34546d91fb..0000000000000 --- a/tests/tests_fabric/utilities/test_enums.py +++ /dev/null @@ -1,9 +0,0 @@ -from lightning_fabric.utilities.enums import PrecisionType - - -def test_precision_supported_types(): - assert PrecisionType.supported_types() == ["16", "32", "64", "bf16", "mixed"] - assert PrecisionType.supported_type(16) - assert PrecisionType.supported_type("16") - assert not PrecisionType.supported_type(1) - assert not PrecisionType.supported_type("invalid") diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index 238e348d105ba..b43bc14a2b52f 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -197,7 +197,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non default_root_dir=tmpdir, fast_dev_run=True, accelerator="ipu", devices=1, precision=16, callbacks=TestCallback() ) assert isinstance(trainer.strategy.precision_plugin, IPUPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == 16 + assert trainer.strategy.precision_plugin.precision == "16" with pytest.raises(SystemExit): trainer.fit(model) @@ -206,7 +206,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non def test_pure_half_precision(tmpdir): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - assert trainer.strategy.precision_plugin.precision == 16 + assert trainer.strategy.precision_plugin.precision == "16" for param in trainer.strategy.model.parameters(): assert param.dtype == torch.float16 raise SystemExit @@ -219,7 +219,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: assert isinstance(trainer.strategy, IPUStrategy) assert isinstance(trainer.strategy.precision_plugin, IPUPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == 16 + assert trainer.strategy.precision_plugin.precision == "16" changed_dtypes = [torch.float, torch.float64] data = [torch.zeros((1), dtype=dtype) for dtype in changed_dtypes] @@ -557,7 +557,7 @@ def test_precision_plugin(): """Ensure precision plugin value is set correctly.""" plugin = IPUPrecisionPlugin(precision=16) - assert plugin.precision == 16 + 
assert plugin.precision == "16" @RunIf(ipu=True) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 3ca94e19c8037..100157787a284 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -70,7 +70,7 @@ sync_ddp_if_available, tpu_distributed, ) -from pytorch_lightning.utilities.enums import AMPType +from pytorch_lightning.utilities.enums import AMPType, PrecisionType from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device from pytorch_lightning.utilities.seed import pl_worker_init_function, reset_seed, seed_everything from pytorch_lightning.utilities.xla_device import inner_f, pl_multi_process, XLADeviceUtils @@ -431,3 +431,8 @@ def test_pick_single_gpu(_): RuntimeError ): pick_single_gpu([]) + + +def test_deprecated_precision_type(): + with pytest.deprecated_call(match="PrecisionType` enum has been deprecated in v1.9"): + _ = PrecisionType.HALF diff --git a/tests/tests_pytorch/helpers/deterministic_model.py b/tests/tests_pytorch/helpers/deterministic_model.py index 25c6a3aa9afd2..fff8445f618dd 100644 --- a/tests/tests_pytorch/helpers/deterministic_model.py +++ b/tests/tests_pytorch/helpers/deterministic_model.py @@ -112,7 +112,7 @@ def configure_optimizers__lr_on_plateau_step(self): def backward(self, loss, optimizer, optimizer_idx): if self.assert_backward: - if self.trainer.precision == 16: + if self.trainer.precision == "16": assert loss > 171 * 1000 else: assert loss == 171.0 diff --git a/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py b/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py index d28fb3434c629..6131befa24427 100644 --- a/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py +++ b/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py @@ -65,7 +65,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non def test_pure_half_precision(tmpdir, hmp_params: dict): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - assert trainer.strategy.model.precision == 16 + assert trainer.strategy.model.precision == "16" for param in trainer.strategy.model.parameters(): assert param.dtype == torch.float16 raise SystemExit @@ -83,7 +83,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: assert isinstance(trainer.strategy, SingleHPUStrategy) assert isinstance(trainer.strategy.precision_plugin, HPUPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == 16 + assert trainer.strategy.precision_plugin.precision == "16" with pytest.raises(RuntimeError, match=r"float16/half is not supported on Gaudi."): trainer.fit(model) diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index c639b8b92dc5f..c1d2761937b2c 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -161,7 +161,7 @@ def test_deepspeed_precision_choice(cuda_count_1, amp_backend, tmpdir): assert isinstance(trainer.strategy, DeepSpeedStrategy) assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == 16 + assert trainer.strategy.precision_plugin.precision == "16" @RunIf(deepspeed=True) diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py 
b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index fa3d527ff920e..b35ee92ff1852 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -650,7 +650,7 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch): monkeypatch.setattr(ipu, "_IPU_AVAILABLE", True) with pytest.raises(ValueError, match=r"accelerator='ipu', precision='bf16'\)` is not supported"): Trainer(accelerator="ipu", precision="bf16") - with pytest.raises(ValueError, match=r"accelerator='ipu', precision=64\)` is not supported"): + with pytest.raises(ValueError, match=r"accelerator='ipu', precision='64'\)` is not supported"): Trainer(accelerator="ipu", precision=64)
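
The test updates capture the user-visible contract after this change: `Trainer.precision` and every `precision_plugin.precision` now return the canonical string, while `LightningModule.precision` keeps the legacy int for "16"/"32"/"64" via the backward-compatibility shim in `_copy_trainer_model_properties`. A small sketch of that contract, assuming the behavior shown in the data_connector hunk; `trainer_precision` and `module_precision` are hypothetical helpers used only for illustration.

    from typing import Union

    from typing_extensions import Literal


    def trainer_precision(raw: Union[int, str]) -> str:
        # what Trainer.precision / precision_plugin.precision now expose
        return str(raw)


    def module_precision(trainer_value: str) -> Union[int, Literal["bf16"]]:
        # mirrors the backward-compatibility shim in _copy_trainer_model_properties
        return int(trainer_value) if trainer_value != "bf16" else "bf16"


    assert trainer_precision(16) == "16"
    assert module_precision("16") == 16
    assert module_precision("bf16") == "bf16"
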