
Commit 7903625

Author: Sisil Mehta (committed)

[FSDP] Adding Native FSDP Strategy

1 parent c099c8b · commit 7903625

File tree: 4 files changed (+407, -2 lines changed)


pytorch_lightning/strategies/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy  # noqa: F401
 from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy  # noqa: F401
 from pytorch_lightning.strategies.dp import DataParallelStrategy  # noqa: F401
+from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy  # noqa: F401
 from pytorch_lightning.strategies.fully_sharded import DDPFullyShardedStrategy  # noqa: F401
 from pytorch_lightning.strategies.horovod import HorovodStrategy  # noqa: F401
 from pytorch_lightning.strategies.ipu import IPUStrategy  # noqa: F401
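
As a small illustration (not part of the diff), the re-export added above is what makes the new class importable from the package-level namespace used elsewhere in this commit:

# Usage sketch: the line added above enables this import.
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy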

pytorch_lightning/strategies/fully_sharded_native.py (new file)

Lines changed: 268 additions & 0 deletions

@@ -0,0 +1,268 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import os
from typing import Union, Any, Generator, Dict, List, Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import (
    ClusterEnvironment,
)
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn  # required for the TrainerFn.FITTING check in setup()
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
    sync_ddp_if_available,
    ReduceOp,
    group as _group,
    _get_process_group_backend_from_env,
    distributed_available,
    get_default_process_group_backend_for_device,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.seed import reset_seed
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel,
)
from torch.distributed.fsdp.wrap import enable_wrap


log = logging.getLogger(__name__)


class DDPFullyShardedNativeStrategy(ParallelStrategy):

    strategy_name = "fsdp_native"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        cpu_offload: Optional[CPUOffload] = None,
        backward_prefetch: Optional[BackwardPrefetch] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
    ) -> None:
        """Plugin for Fully Sharded Data Parallel provided by torch.distributed.

        Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
        size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
        at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
        to ZeRO-Stage 3.
        `For more information: https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/`.

        .. warning:: ``DDPFullyShardedNativeStrategy`` is in beta and subject to change.

        Defaults have been set and options have been exposed, but may require configuration
        based on your level of memory/speed efficiency. We suggest having a look at this tutorial for
        more information.
        `https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html`

        Arguments:
            cpu_offload (Optional[CPUOffload]):
                CPU offloading config. Currently, only parameter and gradient CPU
                offload is supported. It can be enabled via passing in
                ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
                currently implicitly enables gradient offloading to CPU in order for
                params and grads to be on the same device to work with the optimizer. This
                API is subject to change. Default is ``None``, in which case there
                will be no offloading.
            backward_prefetch (Optional[BackwardPrefetch]):
                This is an experimental feature that is subject to change in the
                near future. It allows users to enable two different backward prefetching
                algorithms to help overlap backward communication and computation.
                The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
        """
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self._process_group = None
        self.num_processes = (
            len(self.parallel_devices) if self.parallel_devices is not None else 0
        )
        self._has_loaded_state_dict: bool = False
        self._process_group_backend: Optional[str] = process_group_backend
        self.cpu_offload = cpu_offload
        self.backward_prefetch = backward_prefetch

    @property
    def root_device(self) -> torch.device:
        return self.parallel_devices[self.local_rank]

    @property
    def process_group(self):
        if self._process_group is None:
            # The plugin should have already initialized DDP in setup_environment()
            self._process_group = _get_default_group()
        return self._process_group

    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Set up optimizers after the fully sharded model has been made
        return True

    @property
    def process_group_backend(self) -> Optional[str]:
        return self._process_group_backend

    def setup_environment(self) -> None:
        self.setup_distributed()
        super().setup_environment()

    def setup_distributed(self) -> None:
        if not self.root_device.type == "cuda":
            raise MisconfigurationException(
                "You selected strategy to be `ddp_fully_sharded_native`, but GPU is not available."
            )
        reset_seed()
        # set warning rank
        rank_zero_only.rank = self.global_rank

        self._process_group_backend = self._get_process_group_backend()
        init_dist_connection(self.cluster_environment, self._process_group_backend)

    def _get_process_group_backend(self) -> str:
        return (
            self._process_group_backend
            or _get_process_group_backend_from_env()
            or get_default_process_group_backend_for_device(self.root_device)
        )

    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        if not self.cpu_offload:
            # When using CPU offload, FSDP will manage the CUDA movement for us.
            # Note: this would be problematic for a large model (which could not fit in one GPU)
            # as FSDP module.to(device) would first summon all parameters.
            self.model_to_device()

        self.barrier()
        self.setup_optimizers(trainer)
        optimizers_to_device(self.optimizers, self.root_device)
        self.setup_precision_plugin()

    def model_to_device(self) -> None:
        # ensure we update the device type in the lightning module
        log.info(
            f"{self.__class__.__name__}: moving model to device [{self.root_device}]..."
        )
        self.lightning_module.to(self.root_device)

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")

        with enable_wrap(
            wrapper_cls=FullyShardedDataParallel,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            backward_prefetch=self.backward_prefetch,
        ):
            yield

    def barrier(self, *args, **kwargs) -> None:
        if not distributed_available():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self._determine_device_ids())
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        obj = [obj]
        if self.global_rank != src:
            obj = [None]
        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
        return obj[0]

    def reduce(
        self,
        tensor,
        group: Optional[Any] = None,
        reduce_op: Union[ReduceOp, str] = "mean",
    ) -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        """Run before precision plugin executes backward."""
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)

    def _determine_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def teardown(self) -> None:
        log.info(f"{self.__class__.__name__}: tearing down plugin...")
        super().teardown()

        if self._layer_sync:
            self.model = self._layer_sync.revert(self.model)

        if self.root_device.type == "cuda":
            # GPU teardown
            if not os.environ.get("PL_SKIP_CPU_COPY_ON_DDP_TEARDOWN"):
                log.info(f"{self.__class__.__name__}: moving model to CPU...")
                self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()

        self._has_loaded_state_dict = False

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "fsdp_native",
            cls,
            description="Fully Sharded Data Parallel training from pytorch.distributed.",
        )
        strategy_registry.register(
            "fsdp_native_full_shard_offload",
            cls,
            description="Native FSDP with Full Sharding and CPU Offloading",
            cpu_offload=CPUOffload(offload_params=True),
        )
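
A minimal usage sketch of the strategy added in this file, assuming a multi-GPU machine. The accelerator/device flags below are illustrative Trainer arguments and not part of this commit; the strategy names and constructor options come from the registration and __init__ signature above.

from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

# Option 1: select the strategy by its registered name.
# "fsdp_native_full_shard_offload" additionally enables CPUOffload(offload_params=True).
trainer = Trainer(strategy="fsdp_native", accelerator="gpu", devices=4)

# Option 2: construct the strategy directly to tune offloading and prefetching.
strategy = DDPFullyShardedNativeStrategy(
    cpu_offload=CPUOffload(offload_params=True),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=4)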

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 2 deletions

@@ -50,6 +50,7 @@
 from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
 from pytorch_lightning.strategies import (
     DDP2Strategy,
+    DDPFullyShardedNativeStrategy,
     DDPFullyShardedStrategy,
     DDPShardedStrategy,
     DDPSpawnShardedStrategy,

@@ -686,7 +687,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:

         if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
             return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
-        if isinstance(self.strategy, DDPFullyShardedStrategy):
+        if isinstance(self.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
             return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
         return NativeMixedPrecisionPlugin(self._precision_flag, device)

@@ -727,7 +728,7 @@ def _validate_precision_choice(self) -> None:
                 "it's not supported. Try using `amp_type='native'` instead."
             )
         if self._precision_flag in (16, "bf16") and self._amp_type_flag == AMPType.APEX:
-            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy, DDPFullyShardedStrategy)):
+            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy, DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
                 raise MisconfigurationException(
                     "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                 )

@@ -813,6 +814,7 @@ def is_distributed(self) -> bool:
             DDPStrategy,
             DDPSpawnShardedStrategy,
             DDPShardedStrategy,
+            DDPFullyShardedNativeStrategy,
             DDPFullyShardedStrategy,
             DDPSpawnStrategy,
             DeepSpeedStrategy,
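
A hedged sketch of what the connector change implies for users: with native AMP, the new strategy now resolves to FullyShardedNativeMixedPrecisionPlugin, while apex-backed AMP is rejected for sharded strategies. The Trainer flags below are illustrative assumptions, not part of this commit.

from pytorch_lightning import Trainer

# Native AMP + native FSDP: per the change in _check_and_init_precision,
# this combination resolves to FullyShardedNativeMixedPrecisionPlugin.
trainer = Trainer(strategy="fsdp_native", accelerator="gpu", devices=2, precision=16, amp_backend="native")

# Apex AMP is rejected for sharded strategies, including the new one:
# Trainer(strategy="fsdp_native", precision=16, amp_backend="apex")  # raises MisconfigurationException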
