# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import os
from typing import Any, Dict, Generator, List, Optional, Union

import pytorch_lightning as pl
import torch
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import (
    ClusterEnvironment,
)
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
    sync_ddp_if_available,
    ReduceOp,
    group as _group,
    _get_process_group_backend_from_env,
    distributed_available,
    get_default_process_group_backend_for_device,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.seed import reset_seed
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel,
)
from torch.distributed.fsdp.wrap import enable_wrap


log = logging.getLogger(__name__)


class DDPFullyShardedNativeStrategy(ParallelStrategy):

    strategy_name = "fsdp_native"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        cpu_offload: Optional[CPUOffload] = None,
        backward_prefetch: Optional[BackwardPrefetch] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
    ) -> None:
        """Strategy for Fully Sharded Data Parallel provided by ``torch.distributed``.

        Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
        size while using efficient communication to reduce overhead. In practice, this means we can remain
        at parity with PyTorch DDP while scaling our model sizes dramatically. The technique is similar
        to ZeRO-Stage 3. For more information see
        `the PyTorch blog post <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_.

        .. warning:: ``DDPFullyShardedNativeStrategy`` is in beta and subject to change.

        Defaults have been set and options have been exposed, but they may require configuration
        depending on your memory/speed trade-offs. We suggest having a look at
        `the FSDP tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`_ for more information.

        Arguments:
            cpu_offload (Optional[CPUOffload]):
                CPU offloading config. Currently, only parameter and gradient CPU
                offload is supported. It can be enabled by passing
                ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
                currently implicitly enables gradient offloading to CPU as well, so that
                params and grads are on the same device for the optimizer to work with.
                This API is subject to change. Default is ``None``, in which case there
                is no offloading.
            backward_prefetch (Optional[BackwardPrefetch]):
                This is an experimental feature that is subject to change in the
                near future. It allows users to enable two different backward prefetching
                algorithms to help overlap backward communication and computation.
                The pros and cons of each algorithm are explained in the ``BackwardPrefetch`` class.
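
        Example (a minimal usage sketch; the 4-GPU setup and ``MyLightningModule`` are illustrative
        placeholders, not part of this class)::

            from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

            strategy = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
            trainer = pl.Trainer(accelerator="gpu", devices=4, strategy=strategy)
            trainer.fit(MyLightningModule())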
        """
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        self._process_group = None
        self.num_processes = (
            len(self.parallel_devices) if self.parallel_devices is not None else 0
        )
        self._has_loaded_state_dict: bool = False
        self._process_group_backend: Optional[str] = process_group_backend
        self.cpu_offload = cpu_offload
        self.backward_prefetch = backward_prefetch

    @property
    def root_device(self) -> torch.device:
        return self.parallel_devices[self.local_rank]

    @property
    def process_group(self):
        if self._process_group is None:
            # The strategy should have already initialized the process group in setup_environment()
            self._process_group = _get_default_group()
        return self._process_group

    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Set up optimizers only after the fully sharded model has been created
        return True

    @property
    def process_group_backend(self) -> Optional[str]:
        return self._process_group_backend

    def setup_environment(self) -> None:
        self.setup_distributed()
        super().setup_environment()

    def setup_distributed(self) -> None:
        if self.root_device.type != "cuda":
            raise MisconfigurationException(
                "You selected the `fsdp_native` strategy, but a GPU device is not available."
            )
        reset_seed()
        # set warning rank
        rank_zero_only.rank = self.global_rank

        self._process_group_backend = self._get_process_group_backend()
        init_dist_connection(self.cluster_environment, self._process_group_backend)

    def _get_process_group_backend(self) -> str:
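        # Resolution order: the explicit ``process_group_backend`` argument, then any environment
        # override read by ``_get_process_group_backend_from_env``, then the default backend for
        # ``root_device`` (e.g. "nccl" for CUDA devices).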
        return (
            self._process_group_backend
            or _get_process_group_backend_from_env()
            or get_default_process_group_backend_for_device(self.root_device)
        )

    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        if not self.cpu_offload:
            # When using CPU offload, FSDP manages the CUDA movement for us.
            # Note: this would be problematic for large models that cannot fit on a single GPU,
            # as FSDP's module.to(device) would first summon all parameters.
            self.model_to_device()

        self.barrier()
        self.setup_optimizers(trainer)
        optimizers_to_device(self.optimizers, self.root_device)
        self.setup_precision_plugin()

    def model_to_device(self) -> None:
        # ensure we update the device type in the lightning module
        log.info(
            f"{self.__class__.__name__}: moving model to device [{self.root_device}]..."
        )
        self.lightning_module.to(self.root_device)

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
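        """Provide a context in which ``wrap`` calls produce ``FullyShardedDataParallel`` modules.

        A minimal sketch of how a ``LightningModule`` could take advantage of this context (the layer
        size and attribute name are illustrative assumptions)::

            from torch.distributed.fsdp.wrap import wrap

            class MyModel(pl.LightningModule):
                def configure_sharded_model(self):
                    # modules wrapped here are sharded as they are created
                    self.block = wrap(torch.nn.Linear(32, 32))
        """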
        log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")

        with enable_wrap(
            wrapper_cls=FullyShardedDataParallel,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            backward_prefetch=self.backward_prefetch,
        ):
            yield

    def barrier(self, *args, **kwargs) -> None:
        if not distributed_available():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self._determine_device_ids())
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
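        """Broadcast a picklable ``obj`` from rank ``src`` to every other rank and return the
        received object; this wraps ``torch.distributed.broadcast_object_list``."""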
        obj = [obj]
        if self.global_rank != src:
            obj = [None]
        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
        return obj[0]

    def reduce(
        self,
        tensor,
        group: Optional[Any] = None,
        reduce_op: Union[ReduceOp, str] = "mean",
    ) -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            The reduced value. If the input was not a tensor, it is returned unchanged.
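
        Example (a brief sketch; assumes this strategy instance is reachable as ``trainer.strategy``)::

            # sum a per-process counter across all processes
            count = torch.tensor(1.0, device=trainer.strategy.root_device)
            total = trainer.strategy.reduce(count, reduce_op="sum")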
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        """Run before precision plugin executes backward."""
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)

    def _determine_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def teardown(self) -> None:
        log.info(f"{self.__class__.__name__}: tearing down strategy...")
        super().teardown()

        if self._layer_sync:
            self.model = self._layer_sync.revert(self.model)

        if self.root_device.type == "cuda":
            # GPU teardown
            if not os.environ.get("PL_SKIP_CPU_COPY_ON_DDP_TEARDOWN"):
                log.info(f"{self.__class__.__name__}: moving model to CPU...")
                self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()

        self._has_loaded_state_dict = False

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "fsdp_native",
            cls,
            description="Fully Sharded Data Parallel training from torch.distributed.",
        )
        strategy_registry.register(
            "fsdp_native_full_shard_offload",
            cls,
            description="Native FSDP with full sharding and CPU offloading.",
            cpu_offload=CPUOffload(offload_params=True),
        )
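        # Either registered alias can then be selected by name, e.g. (a usage sketch; the
        # accelerator/devices values are illustrative):
        #   pl.Trainer(strategy="fsdp_native_full_shard_offload", accelerator="gpu", devices=4)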