Commit 80d24fe: typing fixes
1 parent a2130b9

3 files changed: 7 additions & 7 deletions


src/lightning_lite/connector.py
Lines changed: 1 addition & 1 deletion

@@ -414,7 +414,7 @@ def _check_strategy_and_fallback(self) -> None:
             strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy)
         ) and self._accelerator_flag not in ("cuda", "gpu"):
             raise ValueError(
-                f"You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`"
+                "You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`"
                 " to continue or select a different strategy."
             )
         if strategy_flag:
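
Note on the hunk above: the f prefix is dropped because the string contains no placeholders, so it behaves exactly like a plain literal (flake8 flags this as F541). A minimal illustration, not taken from the repository:

    # An f-string without placeholders is just a plain string with extra noise.
    message = f"FSDP is only available on GPU."   # flagged by flake8 (F541)
    message = "FSDP is only available on GPU."    # equivalent, preferred

    # The prefix is only needed when interpolating values:
    accelerator = "cpu"  # hypothetical value for illustration
    message = f"You selected accelerator={accelerator!r}, but FSDP requires a GPU."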

src/lightning_lite/plugins/precision/fsdp.py
Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@
 class FSDPPrecision(NativeMixedPrecision):
     """AMP for Fully Sharded Data Parallel training."""

-    def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
+    def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
             raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.")
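
Why the annotation is quoted: ShardedGradScaler lives in a module that only exists on newer PyTorch, so naming the class directly in the signature would require it to be importable whenever this file is imported. Quoting turns the annotation into a forward reference that only type checkers resolve. A minimal sketch of the usual pattern, assuming the scaler is exposed at torch.distributed.fsdp.sharded_grad_scaler as in PyTorch >= 1.12 (the class below is illustrative, not the real FSDPPrecision):

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Seen by static type checkers only; never executed at runtime,
        # so installs without this module still import the file fine.
        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


    class PrecisionSketch:
        def __init__(self, scaler: Optional["ShardedGradScaler"] = None) -> None:
            # The quoted name is stored as a string and never evaluated here.
            self.scaler = scaler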

src/lightning_lite/strategies/fsdp.py
Lines changed: 5 additions & 5 deletions

@@ -40,7 +40,7 @@
         FullyShardedDataParallel,
         MixedPrecision,
     )
-    from torch.distributed.fsdp.wrap import enable_wrap
+    from torch.distributed.fsdp.wrap import enable_wrap  # noqa: F401

 _FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload")

@@ -86,9 +86,9 @@ def __init__(
         precision_plugin: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
-        cpu_offload: Optional[CPUOffload] = None,
-        backward_prefetch: Optional[BackwardPrefetch] = None,
-        mixed_precision: Optional[MixedPrecision] = None,
+        cpu_offload: Optional["CPUOffload"] = None,
+        backward_prefetch: Optional["BackwardPrefetch"] = None,
+        mixed_precision: Optional["MixedPrecision"] = None,
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
@@ -156,7 +156,7 @@ def setup_environment(self) -> None:
         self._setup_distributed()
         super().setup_environment()

-    def setup_module(self, module: Module) -> FullyShardedDataParallel:
+    def setup_module(self, module: Module) -> "FullyShardedDataParallel":
         """Wraps the model into a
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
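
The same idea applies here: the FSDP classes (CPUOffload, BackwardPrefetch, MixedPrecision, FullyShardedDataParallel) are only importable on PyTorch >= 1.12, so the signatures reference them as quoted strings and the real import happens inside the method, after the version check has passed. The added # noqa: F401 simply suppresses flake8's unused-import warning for enable_wrap. A hedged sketch of the pattern (illustrative only, not the actual FSDPStrategy):

    from typing import TYPE_CHECKING, Optional

    from torch.nn import Module

    if TYPE_CHECKING:
        # Static-analysis-only imports; may be absent at runtime on older PyTorch.
        from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel


    class StrategySketch:
        def __init__(self, cpu_offload: Optional["CPUOffload"] = None) -> None:
            self.cpu_offload = cpu_offload

        def setup_module(self, module: Module) -> "FullyShardedDataParallel":
            # Deferred import: only runs when the method is called, by which
            # point FSDP availability has been verified.
            from torch.distributed.fsdp import FullyShardedDataParallel

            return FullyShardedDataParallel(module, cpu_offload=self.cpu_offload)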
