         FullyShardedDataParallel,
         MixedPrecision,
     )
-    from torch.distributed.fsdp.wrap import enable_wrap
+    from torch.distributed.fsdp.wrap import enable_wrap  # noqa: F401

 _FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload")

@@ -86,9 +86,9 @@ def __init__(
         precision_plugin: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
-        cpu_offload: Optional[CPUOffload] = None,
-        backward_prefetch: Optional[BackwardPrefetch] = None,
-        mixed_precision: Optional[MixedPrecision] = None,
+        cpu_offload: Optional["CPUOffload"] = None,
+        backward_prefetch: Optional["BackwardPrefetch"] = None,
+        mixed_precision: Optional["MixedPrecision"] = None,
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
@@ -156,7 +156,7 @@ def setup_environment(self) -> None:
         self._setup_distributed()
         super().setup_environment()

-    def setup_module(self, module: Module) -> FullyShardedDataParallel:
+    def setup_module(self, module: Module) -> "FullyShardedDataParallel":
        """Wraps the model into a
        :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
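The hunks above switch the FSDP annotations to string form because the names (`CPUOffload`, `BackwardPrefetch`, `MixedPrecision`, `FullyShardedDataParallel`) are only imported when torch is new enough (note the `_TORCH_GREATER_EQUAL_1_12` guard in the diff). A minimal sketch of that forward-reference pattern, assuming a hypothetical `configure_fsdp` function and a `TYPE_CHECKING` guard rather than the actual Lightning source:

```python
# Sketch only: the function name and TYPE_CHECKING guard are illustrative
# assumptions, not the Lightning code shown in the diff above.
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    # Seen by static type checkers but never executed at runtime, so importing
    # this module works even when torch does not provide these names.
    from torch.distributed.fsdp import CPUOffload, MixedPrecision


def configure_fsdp(
    cpu_offload: Optional["CPUOffload"] = None,          # quoted: stored as a string,
    mixed_precision: Optional["MixedPrecision"] = None,  # not evaluated at def time
    **kwargs: Any,
) -> None:
    """Accepts FSDP config objects without needing the guarded import to have run."""
    ...
```

With unquoted annotations, Python would evaluate `CPUOffload` and `MixedPrecision` when the function is defined and raise `NameError` whenever the guarded import did not run; quoting defers that lookup, which is the same effect the diff achieves for the strategy's `__init__` and `setup_module` signatures.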