Merged
50 commits
c0846da  different precision types  (awaelchli, Sep 22, 2022)
d9b413f  imports  (awaelchli, Sep 22, 2022)
7099e05  reset  (awaelchli, Dec 26, 2022)
a42d8d9  update  (awaelchli, Dec 26, 2022)
e7a6ac7  Revert "reset"  (awaelchli, Dec 26, 2022)
408136b  update  (awaelchli, Dec 26, 2022)
dfc4f3f  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 26, 2022)
d0a25e0  update  (awaelchli, Dec 26, 2022)
89456bb  update  (awaelchli, Dec 26, 2022)
863b7a8  update  (awaelchli, Dec 26, 2022)
9c946c3  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 26, 2022)
050ad67  precision type  (awaelchli, Dec 26, 2022)
6c07f4b  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 26, 2022)
e9f0c14  fix  (awaelchli, Dec 26, 2022)
76259a1  update  (awaelchli, Dec 26, 2022)
958481a  fixes  (awaelchli, Dec 26, 2022)
5de00e3  fixes  (awaelchli, Dec 26, 2022)
707040c  fix  (awaelchli, Dec 26, 2022)
0fca84e  fix  (awaelchli, Dec 26, 2022)
d9be90b  update  (awaelchli, Dec 26, 2022)
e963ba7  ordering  (awaelchli, Jan 3, 2023)
94ebb9c  update  (awaelchli, Jan 4, 2023)
a222744  fix deepspeed  (awaelchli, Jan 4, 2023)
51ec174  update  (awaelchli, Jan 4, 2023)
efef3b4  fixes  (awaelchli, Jan 4, 2023)
f7f2f01  update  (awaelchli, Jan 4, 2023)
684f3eb  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jan 4, 2023)
dae8807  boring model  (awaelchli, Jan 5, 2023)
700bd80  update  (awaelchli, Jan 5, 2023)
23bef59  literal support  (awaelchli, Jan 5, 2023)
e46db02  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jan 5, 2023)
760368a  special case comment  (awaelchli, Jan 5, 2023)
4d42a87  Avoid type ignores in the connector  (carmocca, Jan 5, 2023)
32bd574  Same changes for the deepspeed precision at fabric  (carmocca, Jan 5, 2023)
492851a  Base class should use literal too  (carmocca, Jan 5, 2023)
1137f33  Apply the same changes to PL  (carmocca, Jan 5, 2023)
ecdbeaa  Fix mypy  (carmocca, Jan 5, 2023)
0715ea4  Hack to avoid protected import failing  (carmocca, Jan 5, 2023)
bb946f4  Hack to support Union[Literal]  (carmocca, Jan 5, 2023)
3b31665  define precision inputs in precision base file  (awaelchli, Jan 6, 2023)
ae41864  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jan 6, 2023)
2953300  fix import error  (awaelchli, Jan 6, 2023)
1a7f20c  Merge branch 'master' into lite/precision-str-int  (awaelchli, Jan 6, 2023)
eea4631  notebook  (awaelchli, Jan 6, 2023)
f4747c6  notebook  (awaelchli, Jan 6, 2023)
88ce471  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jan 6, 2023)
c03b41f  typing why..  (awaelchli, Jan 6, 2023)
3f38e3e  unused import  (awaelchli, Jan 6, 2023)
0fd6c64  deprecation  (awaelchli, Jan 6, 2023)
72dd301  python 3.7  (awaelchli, Jan 6, 2023)
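Note: the diffs below replace the integer-only precision flag with string-backed Literal types, so both the int and string spellings are accepted and normalized to strings internally. A minimal usage sketch of the behavior this enables; the import path is an assumption based on this PR's src/lightning_fabric layout, and the fallback behavior follows the connector code shown below:

    from lightning_fabric import Fabric

    # Both spellings are accepted after this change; the connector normalizes them to "16".
    # (On CPU, 16-bit AMP falls back to "bf16" with a warning, per the connector diff below.)
    Fabric(accelerator="cpu", precision=16)
    Fabric(accelerator="cpu", precision="16")

    # String-only values such as "bf16" keep working as before.
    Fabric(accelerator="cpu", precision="bf16")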
45 changes: 19 additions & 26 deletions src/lightning_fabric/connector.py
@@ -13,10 +13,10 @@
# limitations under the License.
import os
from collections import Counter
from typing import Any, Dict, List, Optional, Union
from typing import Any, cast, Dict, List, Optional, Union

import torch
from typing_extensions import Literal
from typing_extensions import get_args

from lightning_fabric.accelerators import ACCELERATOR_REGISTRY
from lightning_fabric.accelerators.accelerator import Accelerator
@@ -41,6 +41,7 @@
)
from lightning_fabric.plugins.precision.double import DoublePrecision
from lightning_fabric.plugins.precision.fsdp import FSDPPrecision
from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
from lightning_fabric.strategies import (
DDPShardedStrategy,
DDPStrategy,
@@ -59,7 +60,6 @@

_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN_INPUT = Union[_PLUGIN, str]
_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]


class _Connector:
@@ -113,14 +113,13 @@ def __init__(
# Get registered strategies, built-in accelerators and precision plugins
self._registered_strategies = STRATEGY_REGISTRY.available_strategies()
self._registered_accelerators = ACCELERATOR_REGISTRY.available_accelerators()
self._precision_types = ("16", "32", "64", "bf16")

# Raise an exception if there are conflicts between flags
# Set each valid flag to `self._x_flag` after validation
# For devices: Assign gpus, etc. to the accelerator flag and devices flag
self._strategy_flag: Optional[Union[Strategy, str]] = None
self._accelerator_flag: Optional[Union[Accelerator, str]] = None
self._precision_input: Optional[_PRECISION_INPUT] = None
self._precision_input: _PRECISION_INPUT_STR = "32"
self._precision_instance: Optional[Precision] = None
self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -206,12 +205,10 @@ def _check_config_and_set_final_flags(

self._accelerator_flag = accelerator

if precision is not None:
if str(precision) not in self._precision_types:
raise ValueError(
f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}"
)
self._precision_input = precision
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
if precision not in supported_precision:
raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
self._precision_input = cast(_PRECISION_INPUT_STR, str(precision))

if plugins:
plugins_flags_types: Dict[str, int] = Counter()
@@ -442,10 +439,10 @@ def _check_and_init_precision(self) -> Precision:
return self._precision_instance

if isinstance(self.accelerator, TPUAccelerator):
if self._precision_input == 32:
if self._precision_input == "32":
return TPUPrecision()
elif self._precision_input in (16, "bf16"):
if self._precision_input == 16:
elif self._precision_input in ("16", "bf16"):
if self._precision_input == "16":
rank_zero_warn(
"You passed `Fabric(accelerator='tpu', precision=16)` but AMP"
" is not supported with TPUs. Using `precision='bf16'` instead."
@@ -454,22 +451,22 @@
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore

if self._precision_input == 32:
if self._precision_input == "32":
return Precision()
if self._precision_input == 64:
if self._precision_input == "64":
return DoublePrecision()

if self._precision_input == 16 and self._accelerator_flag == "cpu":
if self._precision_input == "16" and self._accelerator_flag == "cpu":
rank_zero_warn(
"You passed `Fabric(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
" Using `precision='bf16'` instead."
)
self._precision_input = "bf16"

if self._precision_input in (16, "bf16"):
if self._precision_input in ("16", "bf16"):
rank_zero_info(
"Using 16-bit Automatic Mixed Precision (AMP)"
if self._precision_input == 16
if self._precision_input == "16"
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
@@ -483,7 +480,7 @@
def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, and accelerator."""
if isinstance(self.accelerator, TPUAccelerator):
if self._precision_input == 64:
if self._precision_input == "64":
raise NotImplementedError(
"`Fabric(accelerator='tpu', precision=64)` is not implemented."
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
@@ -536,16 +533,12 @@ def _lazy_init_strategy(self) -> None:

@staticmethod
def _argument_from_env(name: str, current: Any, default: Any) -> Any:
env_value: Optional[Union[str, int]] = os.environ.get("LT_" + name.upper())
env_value: Optional[str] = os.environ.get("LT_" + name.upper())

if env_value is None:
return current

if name == "precision":
# TODO: support precision input as string, then this special handling is not needed
env_value = int(env_value) if env_value in ("16", "32", "64") else env_value

if env_value is not None and env_value != current and current != default:
if env_value is not None and env_value != str(current) and str(current) != str(default):
raise ValueError(
f"Your code has `Fabric({name}={current!r}, ...)` but it conflicts with the value "
f"`--{name}={current}` set through the CLI. "
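Note: the validation above relies on typing_extensions.get_args returning the allowed values of each Literal alias. A standalone sketch of the same normalization step, with the alias names copied from the diff, the surrounding connector machinery omitted, and a hypothetical helper name:

    from typing import cast, Union

    from typing_extensions import Literal, get_args

    _PRECISION_INPUT_INT = Literal[64, 32, 16]
    _PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]
    _PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]

    def _normalize_precision(precision: _PRECISION_INPUT) -> _PRECISION_INPUT_STR:
        # Concatenating get_args() of both aliases yields every accepted spelling:
        # ("64", "32", "16", "bf16", 64, 32, 16).
        supported = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
        if precision not in supported:
            raise ValueError(f"Precision {precision!r} is invalid. Allowed precision values: {supported}")
        # Every input is stored in its string form, e.g. 16 -> "16".
        return cast(_PRECISION_INPUT_STR, str(precision))

    assert _normalize_precision(16) == "16"
    assert _normalize_precision("bf16") == "bf16"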
21 changes: 11 additions & 10 deletions src/lightning_fabric/plugins/precision/deepspeed.py
@@ -11,22 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, TYPE_CHECKING
from typing import Any, cast, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import Literal
from typing_extensions import get_args, Literal

from lightning_fabric.plugins.precision.precision import Precision
from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
from lightning_fabric.utilities.enums import PrecisionType
from lightning_fabric.utilities.types import Steppable

_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
import deepspeed

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]


class DeepSpeedPrecision(Precision):
"""Precision plugin for DeepSpeed integration.
@@ -39,19 +42,17 @@ class DeepSpeedPrecision(Precision):
If unsupported ``precision`` is provided.
"""

def __init__(self, precision: Literal[16, 32, "bf16"]) -> None:
supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT)
def __init__(self, precision: _PRECISION_INPUT) -> None:
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in DeepSpeed."
f" `precision` must be one of: {(x.value for x in supported_precision)}."
f" `precision` must be one of: {supported_precision}."
)

super().__init__()
self.precision = precision
self.precision = cast(_PRECISION_INPUT_STR, str(precision))

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32}
precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16, "32": torch.float32}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)

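Note: after this change DeepSpeedPrecision stores its precision attribute as a string, so the dtype lookup in convert_input is keyed on strings only. A small, DeepSpeed-free sketch of that lookup; the standalone function approximates the plugin method, where _convert_fp_tensor only casts floating-point tensors:

    import torch

    # Mirrors the mapping used by DeepSpeedPrecision.convert_input after this PR.
    _PRECISION_TO_TYPE = {"bf16": torch.bfloat16, "16": torch.float16, "32": torch.float32}

    def convert_input(precision: str, data: torch.Tensor) -> torch.Tensor:
        # `precision` has already been normalized to "16", "32", or "bf16" by __init__.
        dst_type = _PRECISION_TO_TYPE[precision]
        return data.to(dst_type) if data.is_floating_point() else data

    x = torch.randn(2, 2)
    assert convert_input("bf16", x).dtype is torch.bfloat16
    assert convert_input("16", torch.arange(4)).dtype is torch.int64  # non-float tensors pass through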
3 changes: 2 additions & 1 deletion src/lightning_fabric/plugins/precision/double.py
@@ -17,6 +17,7 @@
import torch
from torch import Tensor
from torch.nn import Module
from typing_extensions import Literal

from lightning_fabric.plugins.precision.precision import Precision
from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
@@ -25,7 +26,7 @@
class DoublePrecision(Precision):
"""Plugin for training with double (``torch.float64``) precision."""

precision: int = 64
precision: Literal["64"] = "64"

def convert_module(self, module: Module) -> Module:
return module.double()
9 changes: 4 additions & 5 deletions src/lightning_fabric/plugins/precision/fsdp.py
@@ -17,7 +17,6 @@
from typing_extensions import Literal

from lightning_fabric.plugins.precision.native_amp import MixedPrecision
from lightning_fabric.utilities.enums import PrecisionType
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if TYPE_CHECKING:
@@ -29,7 +28,7 @@ class FSDPPrecision(MixedPrecision):
"""AMP for Fully Sharded Data Parallel training."""

def __init__(
self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None
) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
raise NotImplementedError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.")
@@ -39,16 +38,16 @@ def __init__(
super().__init__(
precision=precision,
device=device,
scaler=(ShardedGradScaler() if scaler is None and precision == 16 else None),
scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None),
)

@property
def mixed_precision_config(self) -> "TorchMixedPrecision":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

if self.precision == PrecisionType.HALF:
if self.precision == "16":
dtype = torch.float16
elif self.precision == PrecisionType.BFLOAT:
elif self.precision == "bf16":
dtype = torch.bfloat16
else:
raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
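Note: with the PrecisionType enum comparisons removed, FSDPPrecision.mixed_precision_config switches purely on the string value. A sketch of that dtype selection; the TorchMixedPrecision construction at the end is an assumption, since the actual return statement sits below the collapsed hunk, and the import requires PyTorch 1.12+ as the guard in the diff notes:

    import torch
    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

    def mixed_precision_config(precision: str) -> TorchMixedPrecision:
        if precision == "16":
            dtype = torch.float16
        elif precision == "bf16":
            dtype = torch.bfloat16
        else:
            raise ValueError(f"Was unable to infer precision type, received {precision!r}.")
        # Assumed: apply the selected dtype to params, reductions, and buffers alike.
        return TorchMixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)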
13 changes: 6 additions & 7 deletions src/lightning_fabric/plugins/precision/native_amp.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Dict, Generator, Optional
from typing import Any, cast, Dict, Generator, Optional

import torch
from torch import Tensor
@@ -36,16 +36,15 @@ class MixedPrecision(Precision):
"""

def __init__(
self, precision: Literal[16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
super().__init__()
if scaler is None and precision == 16:
self.precision = cast(Literal["16", "bf16"], str(precision))
if scaler is None and self.precision == "16":
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and precision == "bf16":
if scaler is not None and self.precision == "bf16":
raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
self.precision = precision
self.device = device
self.scaler = scaler

@@ -55,7 +54,7 @@ def forward_context(self) -> Generator[None, None, None]:
yield

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16}
precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)

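Note: because MixedPrecision.__init__ now casts to the string form before checking, passing 16 or "16" both lead to a GradScaler being created, while "bf16" must not come with one. A condensed sketch of that decision as a hypothetical helper, with the _patch_cuda_is_available guard left out for brevity:

    from typing import cast, Optional

    import torch
    from typing_extensions import Literal

    def _init_scaler(
        precision: Literal["16", 16, "bf16"],
        scaler: Optional[torch.cuda.amp.GradScaler] = None,
    ) -> Optional[torch.cuda.amp.GradScaler]:
        normalized = cast(Literal["16", "bf16"], str(precision))
        if scaler is None and normalized == "16":
            # fp16 AMP needs loss scaling; bf16 does not.
            scaler = torch.cuda.amp.GradScaler()
        if scaler is not None and normalized == "bf16":
            raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
        return scaler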
7 changes: 6 additions & 1 deletion src/lightning_fabric/plugins/precision/precision.py
@@ -18,18 +18,23 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import Literal

from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
from lightning_fabric.utilities.types import _PARAMETERS, Optimizable

_PRECISION_INPUT_INT = Literal[64, 32, 16]
_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]


class Precision:
"""Base class for all plugins handling the precision-specific parts of the training.

The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.
"""

precision: Union[str, int] = 32
precision: _PRECISION_INPUT_STR = "32"

def convert_module(self, module: Module) -> Module:
"""Convert the module parameters to the precision type this plugin handles.
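Note: the Union[Literal, Literal] shape above is also why the connector calls get_args on each alias separately: get_args on the union returns the Literal types themselves, not their values. A quick illustration of that typing behavior, independent of Lightning (printed output shown approximately in the comments):

    from typing import Union

    from typing_extensions import Literal, get_args

    _PRECISION_INPUT_INT = Literal[64, 32, 16]
    _PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]
    _PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]

    # The union yields its member Literal types...
    print(get_args(_PRECISION_INPUT))
    # (Literal[64, 32, 16], Literal['64', '32', '16', 'bf16'])

    # ...so the accepted values come from flattening each alias explicitly.
    print(get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR))
    # (64, 32, 16, '64', '32', '16', 'bf16')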
3 changes: 2 additions & 1 deletion src/lightning_fabric/plugins/precision/tpu_bf16.py
@@ -15,6 +15,7 @@

import torch
from torch import Tensor
from typing_extensions import Literal

from lightning_fabric.plugins.precision import TPUPrecision
from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
@@ -23,7 +24,7 @@
class TPUBf16Precision(TPUPrecision):
"""Plugin that enables bfloats on TPUs."""

precision: str = "bf16"
precision: Literal["bf16"] = "bf16"

def __init__(self) -> None:
super().__init__()
9 changes: 4 additions & 5 deletions src/lightning_fabric/strategies/deepspeed.py
@@ -31,7 +31,6 @@
from lightning_fabric.strategies.ddp import DDPStrategy
from lightning_fabric.strategies.strategy import _Sharded
from lightning_fabric.utilities.distributed import log
from lightning_fabric.utilities.enums import PrecisionType
from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning_fabric.utilities.seed import reset_seed
from lightning_fabric.utilities.types import _PATH
@@ -349,9 +348,9 @@ def module_sharded_context(self) -> Generator[None, None, None]:
if self.zero_stage_3:
assert self._config_initialized

if self.precision.precision == PrecisionType.HALF:
if self.precision.precision == "16":
dtype = torch.float16
elif self.precision.precision == PrecisionType.BFLOAT:
elif self.precision.precision == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
@@ -499,7 +498,7 @@ def _format_config(self) -> None:

def _format_precision_config(self) -> None:
assert isinstance(self.config, dict)
if self.precision.precision == PrecisionType.HALF:
if self.precision.precision == "16":
if "fp16" not in self.config:
# FP16 is a DeepSpeed standalone AMP implementation
rank_zero_info("Enabling DeepSpeed FP16.")
@@ -511,7 +510,7 @@
"hysteresis": self.hysteresis,
"min_loss_scale": self.min_loss_scale,
}
elif "bf16" not in self.config and self.precision.precision == PrecisionType.BFLOAT:
elif "bf16" not in self.config and self.precision.precision == "bf16":
rank_zero_info("Enabling DeepSpeed BF16.")
self.config["bf16"] = {"enabled": True}

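Note: for reference, a sketch of what _format_precision_config produces for each string value after this change. The helper below is hypothetical and keeps only the keys visible in the diff; the fp16 loss-scale fields hidden by the collapsed hunk are omitted:

    def format_precision_config(config: dict, precision: str) -> dict:
        if precision == "16" and "fp16" not in config:
            # FP16 is a DeepSpeed standalone AMP implementation.
            config["fp16"] = {"enabled": True}  # plus the loss-scale options elided here
        elif precision == "bf16" and "bf16" not in config:
            config["bf16"] = {"enabled": True}
        return config

    print(format_precision_config({}, "bf16"))  # {'bf16': {'enabled': True}}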
3 changes: 1 addition & 2 deletions src/lightning_fabric/strategies/fairscale.py
@@ -26,7 +26,6 @@
from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning_fabric.strategies.ddp import DDPStrategy
from lightning_fabric.strategies.strategy import _BackwardSyncControl
from lightning_fabric.utilities.enums import PrecisionType
from lightning_fabric.utilities.imports import _IS_WINDOWS

_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")
@@ -116,7 +115,7 @@ def _reinit_optimizers_with_oss(optimizers: List[Optimizer], precision: Precisio
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
is_fp16 = precision.precision in (PrecisionType.MIXED, PrecisionType.HALF)
is_fp16 = precision.precision == "16"
# For multi-node training, compressing the model shards in fp16 before broadcasting
# improves performance. When using PyTorch AMP, it will not degrade
# the model performance.