 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
+from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
+from pytorch_lightning.utilities.rank_zero import rank_zero_info
 from pytorch_lightning.utilities.seed import reset_seed
 
-if _TORCH_GREATER_EQUAL_1_11:
+if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         BackwardPrefetch,
         CPUOffload,
         FullyShardedDataParallel,
+        MixedPrecision,
     )
     from torch.distributed.fsdp.wrap import enable_wrap
-
+else:
+    MixedPrecision = None
+    BackwardPrefetch = None  # type: ignore[misc,assignment]
+    CPUOffload = None  # type: ignore[misc,assignment]
 
 log = logging.getLogger(__name__)
 
@@ -56,18 +63,20 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
     strategy_name = "fsdp_native"
     _registered_strategies: List[str] = []
 
-    def __init__(  # type: ignore[no-untyped-def]
+    def __init__(
         self,
         accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         process_group_backend: Optional[str] = None,
-        cpu_offload=None,
-        backward_prefetch=None,
+        cpu_offload: Optional[CPUOffload] = None,
+        backward_prefetch: Optional[BackwardPrefetch] = None,
+        mixed_precision: Optional[MixedPrecision] = None,
+        **kwargs: Any,
     ) -> None:
-        """Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
+        r"""Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
 
         Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
         size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
@@ -84,22 +93,29 @@ def __init__( # type: ignore[no-untyped-def]
         `https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html`
 
         Arguments:
-            cpu_offload (Optional[CPUOffload]):
+            cpu_offload:
                 CPU offloading config. Currently, only parameter and gradient CPU
                 offload is supported. It can be enabled via passing in
                 ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
                 currently implicitly enables gradient offloading to CPU in order for
                 params and grads to be on same device to work with optimizer. This
                 API is subject to change. Default is ``None`` in which case there
                 will be no offloading.
-            backward_prefetch: (Optional[BackwardPrefetch]):
+            backward_prefetch:
                 This is an experimental feature that is subject to change in
                 the near future. It allows users to enable two different backward_prefetch
                 algorithms to help backward communication and computation overlapping.
                 Pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
+            mixed_precision:
+                Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
+                or BF16 if ``precision=bf16`` unless a config is passed in.
+                This is only available in PyTorch 1.12 and later.
+            \**kwargs: Passed to the FSDP Context manager which will configure the FSDP class when wrapping modules.
         """
-        if not _TORCH_GREATER_EQUAL_1_11:
-            raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.11.0 onwards.")
+        if not _TORCH_GREATER_EQUAL_1_12:
+            raise MisconfigurationException(
+                "`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards."
+            )
 
         super().__init__(
             accelerator=accelerator,
@@ -109,16 +125,23 @@ def __init__( # type: ignore[no-untyped-def]
             precision_plugin=precision_plugin,
         )
         self._process_group = None
-        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
-        self._process_group_backend: Optional[str] = process_group_backend
-        self.cpu_offload: Optional[CPUOffload] = cpu_offload
-        self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch
+        self.num_nodes = 1
+        self._process_group_backend = process_group_backend
+        self.cpu_offload = cpu_offload
+        self.backward_prefetch = backward_prefetch
+        self.mixed_precision = mixed_precision
+        self._rank_0_will_call_children_scripts: bool = False
+        self.kwargs = kwargs
 
     @property
     def root_device(self) -> torch.device:
         assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
+    @property
+    def num_processes(self) -> int:
+        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
     @property
     def process_group(self) -> Optional[ProcessGroup]:
         if self._process_group is None:
@@ -130,10 +153,28 @@ def process_group(self) -> Optional[ProcessGroup]:
     def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend
 
+    @property
+    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+        if self.mixed_precision:
+            return self.mixed_precision
+        plugin = self.precision_plugin
+        if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin):
+            return plugin.mixed_precision_config
+
+    @property
+    def distributed_sampler_kwargs(self) -> Dict:
+        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
+
     def setup_environment(self) -> None:
+        log.detail(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
+
+        # determine which process we are and world size
+        self.set_world_ranks()
+
         # set warning rank
         rank_zero_only.rank = self.global_rank
+
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         init_dist_connection(self.cluster_environment, self._process_group_backend)
@@ -146,36 +187,51 @@ def _get_process_group_backend(self) -> str:
             or get_default_process_group_backend_for_device(self.root_device)
         )
 
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
+    def _configure_launcher(self) -> None:
+        assert self.cluster_environment is not None
+        if not self.cluster_environment.creates_processes_externally:
+            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
+            self._rank_0_will_call_children_scripts = True
+
     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
+        # share ddp pids to all processes
+        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
 
         if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
             assert self.model is not None
             self.model = self._layer_sync.apply(self.model)
 
-        if not self.cpu_offload:
-            self.model_to_device()
+        # we set the device so that optimizers can be created with distributed comms.
+        assert self.lightning_module is not None
+        self.lightning_module._device = self.root_device
 
         self.barrier()
         self.setup_optimizers(trainer)
         optimizers_to_device(self.optimizers, self.root_device)
         self.setup_precision_plugin()
 
     def model_to_device(self) -> None:
-        # ensure we update the device type in the lightning module
-        assert self.lightning_module is not None
-        log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
-        self.lightning_module.to(self.root_device)
+        pass
 
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
         log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
-
         with enable_wrap(
             wrapper_cls=FullyShardedDataParallel,
             process_group=self.process_group,
             cpu_offload=self.cpu_offload,
             backward_prefetch=self.backward_prefetch,
+            mixed_precision=self.mixed_precision_config,
+            device_id=self.root_device.index,
+            **self.kwargs,
         ):
             yield
 
@@ -219,7 +275,7 @@ def _determine_device_ids(self) -> List[int]:
         return [self.root_device.index]
 
     def teardown(self) -> None:
-        log.info(f"{self.__class__.__name__}: tearing down strategy...")
+        rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...")
         if (
             self.lightning_module is not None
             and self.lightning_module.trainer is not None
@@ -229,15 +285,18 @@ def teardown(self) -> None:
             assert self.model is not None
             self.model = self._layer_sync.revert(self.model)
 
-        super().teardown()
+        assert self.cluster_environment is not None
+        self.cluster_environment.teardown()
+        self.precision_plugin.teardown()
+        self.accelerator.teardown()
 
     @classmethod
     def get_registered_strategies(cls) -> List[str]:
         return cls._registered_strategies
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
-        if _TORCH_GREATER_EQUAL_1_11:
+        if _TORCH_GREATER_EQUAL_1_12:
             strategy_registry.register(
                 "fsdp_native",
                 cls,
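
With these changes the strategy accepts the native FSDP configuration objects directly and forwards any extra keyword arguments to the FSDP wrapper. A minimal usage sketch, assuming PyTorch >= 1.12 and a multi-GPU machine; it is illustrative only and not part of this commit:

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

# Configure the strategy with the options added in this commit: a CPU offload
# config and an explicit MixedPrecision config. Any additional keyword
# arguments would be forwarded to the FSDP wrapper via **kwargs.
strategy = DDPFullyShardedNativeStrategy(
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=MixedPrecision(param_dtype=torch.float16),
)
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=strategy)

# Alternatively, select the registered name and let Lightning derive the
# MixedPrecision config from the Trainer's precision flag (FP16 here).
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp_native", precision=16)

Either Trainer can then be used with trainer.fit(model) for a LightningModule of your choice.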