
Commit ef4c67e

Author: Sisil Mehta (committed)
[FSDP] Adding Native FSDP Strategy
1 parent c099c8b commit ef4c67e

File tree: 4 files changed (+415 −1 lines)

pytorch_lightning/strategies/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pathlib import Path
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+
 
 from pytorch_lightning.strategies.bagua import BaguaStrategy  # noqa: F401
 from pytorch_lightning.strategies.ddp import DDPStrategy  # noqa: F401
 from pytorch_lightning.strategies.ddp2 import DDP2Strategy  # noqa: F401
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy  # noqa: F401
 from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy  # noqa: F401
 from pytorch_lightning.strategies.dp import DataParallelStrategy  # noqa: F401
+if _TORCH_GREATER_EQUAL_1_11:
+    from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy  # noqa: F401
 from pytorch_lightning.strategies.fully_sharded import DDPFullyShardedStrategy  # noqa: F401
 from pytorch_lightning.strategies.horovod import HorovodStrategy  # noqa: F401
 from pytorch_lightning.strategies.ipu import IPUStrategy  # noqa: F401
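The new strategy is only exposed when the installed torch ships the native FSDP module, which is what the `_TORCH_GREATER_EQUAL_1_11` flag gates. As a rough, hypothetical sketch of such a version guard (the helper below is illustrative, not the actual implementation in `pytorch_lightning.utilities.imports`):

# Hypothetical sketch of a torch-version guard; illustrative only, not the
# helper actually used in pytorch_lightning.utilities.imports.
from packaging.version import Version

import torch

# torch.distributed.fsdp was introduced in torch 1.11, so expose the native
# FSDP strategy only on 1.11+ installs (local build tags like "+cu113" are stripped).
_TORCH_GREATER_EQUAL_1_11 = Version(torch.__version__.split("+")[0]) >= Version("1.11.0")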
pytorch_lightning/strategies/fully_sharded_native.py

Lines changed: 270 additions & 0 deletions (new file)

@@ -0,0 +1,270 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import os
from typing import Union, Any, Generator, Dict, List, Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import (
    ClusterEnvironment,
)
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn  # needed for the FITTING check in setup()
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
    sync_ddp_if_available,
    ReduceOp,
    group as _group,
    _get_process_group_backend_from_env,
    distributed_available,
    get_default_process_group_backend_for_device,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.seed import reset_seed
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel,
)
from torch.distributed.fsdp.wrap import enable_wrap


log = logging.getLogger(__name__)


class DDPFullyShardedNativeStrategy(ParallelStrategy):

    strategy_name = "fsdp_native"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        cpu_offload: Optional[CPUOffload] = None,
        backward_prefetch: Optional[BackwardPrefetch] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
    ) -> None:
        """Plugin for Fully Sharded Data Parallel provided by PyTorch Distributed.

        Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
        size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
        at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
        to ZeRO-Stage 3.
        `For more information: https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/`.

        .. warning:: ``DDPFullyShardedNativeStrategy`` is in beta and subject to change. The interface can
            bring breaking changes and new features with the next release of PyTorch.

        Defaults have been set and options have been exposed, but may require configuration
        based on your level of memory/speed efficiency. We suggest having a look at this tutorial for
        more information.
        `https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html`

        Arguments:
            cpu_offload (Optional[CPUOffload]):
                CPU offloading config. Currently, only parameter and gradient CPU
                offload is supported. It can be enabled via passing in
                ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
                currently implicitly enables gradient offloading to CPU in order for
                params and grads to be on the same device to work with the optimizer. This
                API is subject to change. Default is ``None``, in which case there
                will be no offloading.
            backward_prefetch (Optional[BackwardPrefetch]):
                This is an experimental feature that is subject to change in
                the near future. It allows users to enable two different backward prefetching
                algorithms to help backward communication and computation overlapping.
                The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
        """
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self._process_group = None
        self.num_processes = (
            len(self.parallel_devices) if self.parallel_devices is not None else 0
        )
        self._has_loaded_state_dict: bool = False
        self._process_group_backend: Optional[str] = process_group_backend
        self.cpu_offload = cpu_offload
        self.backward_prefetch = backward_prefetch

    @property
    def root_device(self) -> torch.device:
        return self.parallel_devices[self.local_rank]

    @property
    def process_group(self):
        if self._process_group is None:
            # The plugin should have already initialized DDP in setup_environment()
            self._process_group = _get_default_group()
        return self._process_group

    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Set up optimizers only after the fully sharded model has been created
        return True

    @property
    def process_group_backend(self) -> Optional[str]:
        return self._process_group_backend

    def setup_environment(self) -> None:
        self.setup_distributed()
        super().setup_environment()

    def setup_distributed(self) -> None:
        if not self.root_device.type == "cuda":
            raise MisconfigurationException(
                "You selected strategy to be `ddp_fully_sharded_native`, but GPU is not available."
            )
        reset_seed()
        # set warning rank
        rank_zero_only.rank = self.global_rank

        self._process_group_backend = self._get_process_group_backend()
        init_dist_connection(self.cluster_environment, self._process_group_backend)

    def _get_process_group_backend(self) -> str:
        return (
            self._process_group_backend
            or _get_process_group_backend_from_env()
            or get_default_process_group_backend_for_device(self.root_device)
        )

    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        if not self.cpu_offload:
            # When using CPU offload, FSDP manages the CUDA movement for us.
            # Note: this would be problematic for large models (which could not fit in one GPU)
            # as FSDP module.to(device) would first summon all parameters.
            self.model_to_device()

        self.barrier()
        self.setup_optimizers(trainer)
        optimizers_to_device(self.optimizers, self.root_device)
        self.setup_precision_plugin()

    def model_to_device(self) -> None:
        # ensure we update the device type in the lightning module
        log.info(
            f"{self.__class__.__name__}: moving model to device [{self.root_device}]..."
        )
        self.lightning_module.to(self.root_device)

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")

        with enable_wrap(
            wrapper_cls=FullyShardedDataParallel,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            backward_prefetch=self.backward_prefetch,
        ):
            yield

    def barrier(self, *args, **kwargs) -> None:
        if not distributed_available():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self._determine_device_ids())
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        obj = [obj]
        if self.global_rank != src:
            obj = [None]
        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
        return obj[0]

    def reduce(
        self,
        tensor,
        group: Optional[Any] = None,
        reduce_op: Union[ReduceOp, str] = "mean",
    ) -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            reduced value, except when the input was not a tensor the output remains unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        """Run before precision plugin executes backward."""
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)

    def _determine_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def teardown(self) -> None:
        log.info(f"{self.__class__.__name__}: tearing down plugin...")
        super().teardown()

        if self._layer_sync:
            self.model = self._layer_sync.revert(self.model)

        if self.root_device.type == "cuda":
            # GPU teardown
            if not os.environ.get("PL_SKIP_CPU_COPY_ON_DDP_TEARDOWN"):
                log.info(f"{self.__class__.__name__}: moving model to CPU...")
                self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()

        self._has_loaded_state_dict = False

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "fsdp_native",
            cls,
            description="Fully Sharded Data Parallel training from pytorch.distributed.",
        )
        strategy_registry.register(
            "fsdp_native_full_shard_offload",
            cls,
            description="Native FSDP with Full Sharding and CPU Offloading",
            cpu_offload=CPUOffload(offload_params=True),
        )
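Because the strategy registers itself under the `fsdp_native` and `fsdp_native_full_shard_offload` aliases, it can be selected by name or constructed directly. A minimal usage sketch, assuming torch >= 1.11 and a multi-GPU CUDA machine (the LightningModule at the end is a hypothetical placeholder):

# Minimal usage sketch, assuming torch >= 1.11 and a multi-GPU CUDA machine.
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

# Select the strategy by its registered alias ...
trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native")

# ... or construct it explicitly, e.g. with parameter CPU offloading enabled
# (roughly what the "fsdp_native_full_shard_offload" alias registers).
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True)),
)

# trainer.fit(MyLightningModule())  # `MyLightningModule` is a hypothetical user-defined module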

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 1 deletion
@@ -50,6 +50,7 @@
 from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
 from pytorch_lightning.strategies import (
     DDP2Strategy,
+    DDPFullyShardedNativeStrategy,
     DDPFullyShardedStrategy,
     DDPShardedStrategy,
     DDPSpawnShardedStrategy,
@@ -688,6 +689,8 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
                 return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
             if isinstance(self.strategy, DDPFullyShardedStrategy):
                 return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
+            if isinstance(self.strategy, DDPFullyShardedNativeStrategy):
+                raise MisconfigurationException("DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision")
             return NativeMixedPrecisionPlugin(self._precision_flag, device)

         if self._amp_type_flag == AMPType.APEX:
@@ -727,7 +730,7 @@ def _validate_precision_choice(self) -> None:
                     "it's not supported. Try using `amp_type='native'` instead."
                 )
         if self._precision_flag in (16, "bf16") and self._amp_type_flag == AMPType.APEX:
-            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy, DDPFullyShardedStrategy)):
+            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy, DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
                 raise MisconfigurationException(
                     "Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
                 )
@@ -813,6 +816,7 @@ def is_distributed(self) -> bool:
             DDPStrategy,
             DDPSpawnShardedStrategy,
             DDPShardedStrategy,
+            DDPFullyShardedNativeStrategy,
             DDPFullyShardedStrategy,
             DDPSpawnStrategy,
             DeepSpeedStrategy,
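As a result of the connector changes above, requesting 16-bit precision together with the native FSDP strategy is rejected at `Trainer` construction time. A minimal sketch of the expected behaviour, assuming a machine with available GPUs and torch >= 1.11 (error text taken from the diff):

# Sketch of the failure mode added in this commit: combining 16-bit precision
# with the native FSDP strategy should raise during Trainer initialization.
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    Trainer(accelerator="gpu", devices=2, strategy="fsdp_native", precision=16)
except MisconfigurationException as err:
    print(err)  # "DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"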
