 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, List, Optional, Union, TYPE_CHECKING
 
 import torch
 from torch import Tensor
 from torch.distributed import default_pg_timeout
 from torch.nn import Module
 
-import pytorch_lightning as pl
 from lightning_lite.accelerators import Accelerator
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device, distributed_available
 from lightning_lite.utilities.distributed import group as _group
 from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
-from lightning_lite.utilities.optimizer import optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
 from lightning_lite.plugins import Precision
-from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
-from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
+from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning_lite.strategies.parallel import ParallelStrategy
-from pytorch_lightning.strategies.strategy import TBroadcast
-from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
-from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
-from pytorch_lightning.utilities.types import ProcessGroup, STEP_OUTPUT
-
-_distributed_available = torch.distributed.is_available()
-_fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available
-if _fsdp_available:
+from lightning_lite.strategies.strategy import TBroadcast
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning_lite.utilities.rank_zero import rank_zero_only
+
+if TYPE_CHECKING:
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         BackwardPrefetch,
         CPUOffload,
         FullyShardedDataParallel,
         MixedPrecision,
     )
     from torch.distributed.fsdp.wrap import enable_wrap
-else:
-    FullyShardedDataParallel = None  # type: ignore[misc,assignment]
-    MixedPrecision = None  # type: ignore[misc,assignment]
-    BackwardPrefetch = None  # type: ignore[misc,assignment]
-    CPUOffload = None  # type: ignore[misc,assignment]
-
-if _distributed_available:
-    from torch.distributed.distributed_c10d import _get_default_group
-
-log = logging.getLogger(__name__)
 
 
 class FSDPStrategy(ParallelStrategy):
@@ -120,9 +101,7 @@ def __init__(
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
-            raise MisconfigurationException(
-                "`FSDPStrategy` is supported from PyTorch v1.12.0 onwards."
-            )
+            raise RuntimeError("`FSDPStrategy` is supported from PyTorch v1.12.0 onwards.")
 
         super().__init__(
             accelerator=accelerator,
@@ -169,13 +148,13 @@ def distributed_sampler_kwargs(self) -> Dict:
     def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend
 
-    # @property
-    # def mixed_precision_config(self) -> Optional[MixedPrecision]:
-    #     if self.mixed_precision:
-    #         return self.mixed_precision
-    #     plugin = self.precision_plugin
-    #     if isinstance(plugin, FullyShardedNativeNativeMixedPrecisionPlugin):
-    #         return plugin.mixed_precision_config
+    @property
+    def mixed_precision_config(self) -> Optional["MixedPrecision"]:
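+        # An explicitly passed MixedPrecision config takes precedence; otherwise fall
+        # back to the config provided by the FSDPPrecision plugin, if one is in use.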
+        if self.mixed_precision:
+            return self.mixed_precision
+        plugin = self.precision_plugin
+        if isinstance(plugin, FSDPPrecision):
+            return plugin.mixed_precision_config
 
     def _configure_launcher(self) -> None:
         assert self.cluster_environment is not None
@@ -189,6 +168,7 @@ def setup_environment(self) -> None:
     def setup_module(self, module: Module) -> FullyShardedDataParallel:
         """Wraps the model into a
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
+        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
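+        # Imported locally (the module-level FSDP imports are now TYPE_CHECKING-only),
+        # likely so this file stays importable when FSDP support is unavailable.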
         if (
             any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules())
             and "auto_wrap_policy" in self._ddp_kwargs
@@ -209,12 +189,14 @@ def module_to_device(self, module: Module) -> None:
 
     @contextmanager
     def module_sharded_context(self) -> Generator:
+        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+        from torch.distributed.fsdp.wrap import enable_wrap
+
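+        # Inside this context, submodules passed to ``wrap()`` are wrapped as
+        # ``FullyShardedDataParallel`` instances using the settings below.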
         with enable_wrap(
             wrapper_cls=FullyShardedDataParallel,
-            # process_group=self.process_group,
             cpu_offload=self.cpu_offload,
             backward_prefetch=self.backward_prefetch,
-            mixed_precision=self.precision_plugin.mixed_precision_config,
+            mixed_precision=self.mixed_precision_config,
             device_id=self.root_device.index,
             **self._ddp_kwargs,
         ):
@@ -244,6 +226,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
+        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+
         strategy_registry.register(
             "fsdp",
             cls,