
Commit cc65718 (1 parent: 80d24fe)

Commit message: imports
4 files changed: 6 additions, 1 deletion

src/lightning_lite/plugins/__init__.py (1 addition, 0 deletions)

@@ -18,6 +18,7 @@
 from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO
 from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
 from lightning_lite.plugins.precision.double import DoublePrecision
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.tpu import TPUPrecision
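With this re-export, callers can import the plugin from the plugins package root rather than spelling out the full module path. A minimal sketch, assuming PyTorch >= 1.12 (which FSDPPrecision itself requires, per the fsdp.py diff below):

    # After this commit, the short path works:
    from lightning_lite.plugins import FSDPPrecision

    # ...and is equivalent to the full module path:
    # from lightning_lite.plugins.precision.fsdp import FSDPPrecision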

src/lightning_lite/plugins/precision/__init__.py (1 addition, 0 deletions)

@@ -13,6 +13,7 @@
 # limitations under the License.
 from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
 from lightning_lite.plugins.precision.double import DoublePrecision
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.tpu import TPUPrecision

src/lightning_lite/plugins/precision/fsdp.py (3 additions, 1 deletion)

@@ -27,7 +27,9 @@
 class FSDPPrecision(NativeMixedPrecision):
     """AMP for Fully Sharded Data Parallel training."""

-    def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(
+        self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None
+    ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
             raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.")
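The change above only rewraps the constructor signature for line length; behavior is unchanged. A minimal usage sketch based on that signature, assuming PyTorch >= 1.12 (older versions raise the RuntimeError shown above); the argument values are illustrative:

    from lightning_lite.plugins import FSDPPrecision

    # 16-bit AMP for FSDP; `scaler` defaults to None per the signature,
    # and may be an optional ShardedGradScaler.
    amp_16 = FSDPPrecision(precision=16, device="cuda")

    # The Literal type also allows bfloat16:
    amp_bf16 = FSDPPrecision(precision="bf16", device="cuda")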

src/lightning_lite/strategies/__init__.py (1 addition, 0 deletions)

@@ -17,6 +17,7 @@
 from lightning_lite.strategies.dp import DataParallelStrategy  # noqa: F401
 from lightning_lite.strategies.fairscale import DDPShardedStrategy  # noqa: F401
 from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy  # noqa: F401
+from lightning_lite.strategies.fsdp import FSDPStrategy  # noqa: F401
 from lightning_lite.strategies.parallel import ParallelStrategy  # noqa: F401
 from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry
 from lightning_lite.strategies.single_device import SingleDeviceStrategy  # noqa: F401
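Like the precision plugin, the strategy is now re-exported from the subpackage root alongside its siblings (e.g. DDPShardedStrategy). A minimal sketch; the zero-argument construction is an assumption, since the constructor is not shown in this diff:

    from lightning_lite.strategies import FSDPStrategy

    # Hypothetical: construct the strategy with defaults and pair it with
    # the matching FSDPPrecision plugin; neither step is part of this commit.
    strategy = FSDPStrategy()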
