diff --git a/src/lightning_lite/CHANGELOG.md b/src/lightning_lite/CHANGELOG.md index a764b0e59f1c5..732e9452b08a7 100644 --- a/src/lightning_lite/CHANGELOG.md +++ b/src/lightning_lite/CHANGELOG.md @@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningLite.setup_module()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185)) +- Added support for Fully Sharded Data Parallel (FSDP) training in Lightning Lite ([#14967](https://github.com/Lightning-AI/lightning/issues/14967)) + + ### Changed - The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992)) diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py index c60bf3c3bd2a4..850b25d0a7721 100644 --- a/src/lightning_lite/connector.py +++ b/src/lightning_lite/connector.py @@ -40,6 +40,7 @@ TorchElasticEnvironment, ) from lightning_lite.plugins.precision.double import DoublePrecision +from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.strategies import ( DDPShardedStrategy, DDPSpawnShardedStrategy, @@ -53,6 +54,7 @@ XLAStrategy, ) from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES +from lightning_lite.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn from lightning_lite.utilities.device_parser import _determine_root_gpu_device from lightning_lite.utilities.imports import _IS_INTERACTIVE @@ -417,6 +419,13 @@ def _check_strategy_and_fallback(self) -> None: f"You selected `Lite(strategy='{strategy_flag}')` but process forking is not supported on this" f" platform. We recommend `Lite(strategy='ddp_spawn')` instead." ) + if ( + strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy) + ) and self._accelerator_flag not in ("cuda", "gpu"): + raise ValueError( + "You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`" + " to continue or select a different strategy."
+ ) if strategy_flag: self._strategy_flag = strategy_flag @@ -465,9 +474,11 @@ def _check_and_init_precision(self) -> Precision: if self._precision_input == 16 else "Using bfloat16 Automatic Mixed Precision (AMP)" ) - device = "cpu" if self._accelerator_flag == "cpu" else "cuda" - return NativeMixedPrecision(self._precision_input, device) + device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + + if isinstance(self.strategy, FSDPStrategy): + return FSDPPrecision(precision=self._precision_input, device=device) + return NativeMixedPrecision(precision=self._precision_input, device=device) raise RuntimeError("No precision set") diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 112fef1b775f4..5e41f15121acb 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -35,6 +35,7 @@ DDPShardedStrategy, DDPSpawnShardedStrategy, DeepSpeedStrategy, + FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy, @@ -593,14 +594,20 @@ def _prepare_run_method(self) -> None: # wrap the run method, so we can inject setup logic or spawn processes for the user setattr(self, "run", partial(self._run_impl, self.run)) - @staticmethod - def _validate_setup(module: nn.Module, optimizers: Sequence[Optimizer]) -> None: + def _validate_setup(self, module: nn.Module, optimizers: Sequence[Optimizer]) -> None: if isinstance(module, _LiteModule): raise ValueError("A model should be passed only once to the `setup` method.") if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): raise ValueError("An optimizer should be passed only once to the `setup` method.") + if isinstance(self._strategy, FSDPStrategy): + raise RuntimeError( + f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately." + " Create and set up the model first through `model = self.setup_module(model)`. Then create the" + " optimizer and set it up: `optimizer = self.setup_optimizers(optimizer)`." + ) + def _validate_setup_module(self, module: nn.Module) -> None: if isinstance(module, _LiteModule): raise ValueError("A model should be passed only once to the `setup_module` method.") diff --git a/src/lightning_lite/plugins/__init__.py b/src/lightning_lite/plugins/__init__.py index 785e5aa009d5b..d0416e70f9747 100644 --- a/src/lightning_lite/plugins/__init__.py +++ b/src/lightning_lite/plugins/__init__.py @@ -17,6 +17,7 @@ from lightning_lite.plugins.io.xla import XLACheckpointIO from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision from lightning_lite.plugins.precision.double import DoublePrecision +from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision from lightning_lite.plugins.precision.precision import Precision from lightning_lite.plugins.precision.tpu import TPUPrecision @@ -33,4 +34,5 @@ "NativeMixedPrecision", "TPUPrecision", "TPUBf16Precision", + "FSDPPrecision", ] diff --git a/src/lightning_lite/plugins/precision/__init__.py b/src/lightning_lite/plugins/precision/__init__.py index 412ef9274822c..c47ffeb3f9fc1 100644 --- a/src/lightning_lite/plugins/precision/__init__.py +++ b/src/lightning_lite/plugins/precision/__init__.py @@ -13,6 +13,7 @@ # limitations under the License.
from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision from lightning_lite.plugins.precision.double import DoublePrecision +from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision from lightning_lite.plugins.precision.precision import Precision from lightning_lite.plugins.precision.tpu import TPUPrecision @@ -25,4 +26,5 @@ "Precision", "TPUPrecision", "TPUBf16Precision", + "FSDPPrecision", ] diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py new file mode 100644 index 0000000000000..020369bcbc4cf --- /dev/null +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, TYPE_CHECKING + +import torch +from typing_extensions import Literal + +from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision +from lightning_lite.utilities.enums import PrecisionType +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 + +if TYPE_CHECKING: + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + + +class FSDPPrecision(NativeMixedPrecision): + """AMP for Fully Sharded Data Parallel training.""" + + def __init__( + self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None + ) -> None: + if not _TORCH_GREATER_EQUAL_1_12: + raise NotImplementedError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") + + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + + super().__init__( + precision=precision, + device=device, + scaler=(ShardedGradScaler() if scaler is None and precision == 16 else None), + ) + + @property + def mixed_precision_config(self) -> "MixedPrecision": + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + + if self.precision == PrecisionType.HALF: + dtype = torch.float16 + elif self.precision == PrecisionType.BFLOAT: + dtype = torch.bfloat16 + else: + raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.") + return MixedPrecision( + param_dtype=dtype, + reduce_dtype=dtype, + buffer_dtype=dtype, + ) diff --git a/src/lightning_lite/strategies/__init__.py b/src/lightning_lite/strategies/__init__.py index f9cf74e30e4c0..a8d235708b573 100644 --- a/src/lightning_lite/strategies/__init__.py +++ b/src/lightning_lite/strategies/__init__.py @@ -17,6 +17,7 @@ from lightning_lite.strategies.dp import DataParallelStrategy # noqa: F401 from lightning_lite.strategies.fairscale import DDPShardedStrategy # noqa: F401 from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy # noqa: F401 +from lightning_lite.strategies.fsdp import FSDPStrategy # noqa: F401 from lightning_lite.strategies.parallel import ParallelStrategy # noqa: F401 from 
lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry from lightning_lite.strategies.single_device import SingleDeviceStrategy # noqa: F401 diff --git a/src/lightning_lite/strategies/ddp.py b/src/lightning_lite/strategies/ddp.py index d970cc9d0bf10..c72dccd509916 100644 --- a/src/lightning_lite/strategies/ddp.py +++ b/src/lightning_lite/strategies/ddp.py @@ -92,8 +92,7 @@ def num_processes(self) -> int: @property def distributed_sampler_kwargs(self) -> Dict[str, Any]: - distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) - return distributed_sampler_kwargs + return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) @property def process_group_backend(self) -> Optional[str]: diff --git a/src/lightning_lite/strategies/ddp_spawn.py b/src/lightning_lite/strategies/ddp_spawn.py index 532c92ed8d837..9388a2722463d 100644 --- a/src/lightning_lite/strategies/ddp_spawn.py +++ b/src/lightning_lite/strategies/ddp_spawn.py @@ -99,8 +99,7 @@ def num_processes(self) -> int: @property def distributed_sampler_kwargs(self) -> Dict[str, int]: - distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) - return distributed_sampler_kwargs + return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) @property def process_group_backend(self) -> Optional[str]: diff --git a/src/lightning_lite/strategies/deepspeed.py b/src/lightning_lite/strategies/deepspeed.py index 57920aa8a9246..74dc73c210c08 100644 --- a/src/lightning_lite/strategies/deepspeed.py +++ b/src/lightning_lite/strategies/deepspeed.py @@ -297,8 +297,7 @@ def zero_stage_3(self) -> bool: @property def distributed_sampler_kwargs(self) -> Dict[str, int]: - distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) - return distributed_sampler_kwargs + return dict(num_replicas=self.world_size, rank=self.global_rank) @property def model(self) -> "deepspeed.DeepSpeedEngine": diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py new file mode 100644 index 0000000000000..8053992d18525 --- /dev/null +++ b/src/lightning_lite/strategies/fsdp.py @@ -0,0 +1,308 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from contextlib import contextmanager +from datetime import timedelta +from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from torch import Tensor +from torch.distributed import default_pg_timeout +from torch.nn import Module +from torch.optim import Optimizer + +from lightning_lite.accelerators import Accelerator +from lightning_lite.plugins import CheckpointIO, ClusterEnvironment, Precision +from lightning_lite.plugins.precision.fsdp import FSDPPrecision +from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher +from lightning_lite.strategies.parallel import ParallelStrategy +from lightning_lite.strategies.strategy import _BackwardSyncControl, _Sharded, TBroadcast +from lightning_lite.utilities.distributed import ( + _distributed_available, + _get_default_process_group_backend_for_device, + _init_dist_connection, + _sync_ddp_if_available, +) +from lightning_lite.utilities.distributed import group as _group +from lightning_lite.utilities.distributed import ReduceOp +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from lightning_lite.utilities.rank_zero import rank_zero_only +from lightning_lite.utilities.seed import reset_seed + +if TYPE_CHECKING: + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + FullyShardedDataParallel, + MixedPrecision, + ) + +_FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload") + + +class FSDPStrategy(ParallelStrategy, _Sharded): + r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed. + + .. warning:: ``FSDPStrategy`` is in BETA and subject to change. The interface can + bring breaking changes and new features with the next release of PyTorch. + + Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model + size, whilst using efficient communication to reduce overhead. In practice, this means we can remain + at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar + to ZeRO-Stage 3. + + For more information `check out `__. + + Defaults have been set and options have been exposed, but may require configuration + based on your level of memory/speed efficiency. We suggest having a look at + `this tutorial `__ for more information. + + Arguments: + cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It + can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently + implicitly enables gradient offloading to CPU in order for parameters and gradients to be on the same device + to work with the optimizer. This API is subject to change. Default is ``None`` in which case there + will be no offloading. + backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows + users to enable two different backward prefetching algorithms to help backward communication and + computation overlapping. The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``. + mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 + if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. + \**kwargs: Optional keyword arguments passed to the FSDP context manager which will configure the FSDP class + when wrapping modules.
+ """ + + def __init__( + self, + accelerator: Optional[Accelerator] = None, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision: Optional[Precision] = None, + process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, + cpu_offload: Optional["CPUOffload"] = None, + backward_prefetch: Optional["BackwardPrefetch"] = None, + mixed_precision: Optional["MixedPrecision"] = None, + **kwargs: Any, + ) -> None: + if not _TORCH_GREATER_EQUAL_1_12: + raise NotImplementedError("`FSDPStrategy` is supported from PyTorch v1.12.0 onwards.") + + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + precision=precision, + ) + self._num_nodes = 1 + self._process_group_backend: Optional[str] = process_group_backend + self._timeout: Optional[timedelta] = timeout + self._backward_sync_control = _FSDPBackwardSyncControl() + self._ddp_kwargs = kwargs + + self.cpu_offload = cpu_offload + self.backward_prefetch = backward_prefetch + self.mixed_precision = mixed_precision + + @property + def root_device(self) -> torch.device: + assert self.parallel_devices is not None + return self.parallel_devices[self.local_rank] + + @property + def is_distributed(self) -> bool: + return True + + @property + def num_nodes(self) -> int: + return self._num_nodes + + @num_nodes.setter + def num_nodes(self, num_nodes: int) -> None: + self._num_nodes = num_nodes + + @property + def num_processes(self) -> int: + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + + @property + def distributed_sampler_kwargs(self) -> Dict: + return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) + + @property + def process_group_backend(self) -> Optional[str]: + return self._process_group_backend + + @property + def mixed_precision_config(self) -> Optional["MixedPrecision"]: + if self.mixed_precision: + return self.mixed_precision + if isinstance(self.precision, FSDPPrecision): + return self.precision.mixed_precision_config + + def _configure_launcher(self) -> None: + assert self.cluster_environment is not None + if not self.cluster_environment.creates_processes_externally: + self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + + def setup_environment(self) -> None: + self._setup_distributed() + super().setup_environment() + + def setup_module_and_optimizers( + self, module: Module, optimizers: List[Optimizer] + ) -> Tuple[Module, List[Optimizer]]: + raise NotImplementedError( + f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)." + " Please do it in this order: Create the model, call `setup_module`, create the optimizer," + " call `setup_optimizer`." 
+ ) + + def setup_module(self, module: Module) -> "FullyShardedDataParallel": + """Wraps the model into a + :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + + if ( + any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()) + and "auto_wrap_policy" in self._ddp_kwargs + ): + # If model is already wrapped, we need to avoid sending the `auto_wrap_policy` + del self._ddp_kwargs["auto_wrap_policy"] + return FullyShardedDataParallel( + module=module, + cpu_offload=self.cpu_offload, + backward_prefetch=self.backward_prefetch, + mixed_precision=self.mixed_precision_config, + device_id=self.root_device.index, + **self._ddp_kwargs, + ) + + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Set up an optimizer for a model wrapped with FSDP. + + This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is verify + that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference to the + flattened parameters. + """ + from torch.distributed.fsdp import FlatParameter + + num_groups = len(optimizer.param_groups) + if num_groups > 1: + raise ValueError( + "An optimizer used with an FSDP model does not support multiple param groups." + f" Found {num_groups} parameter groups." + ) + + if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]): + return optimizer + + raise ValueError( + "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer" + " after setting up the model." + ) + + def module_to_device(self, module: Module) -> None: + pass + + @contextmanager + def module_sharded_context(self) -> Generator: + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + from torch.distributed.fsdp.wrap import enable_wrap + + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + cpu_offload=self.cpu_offload, + backward_prefetch=self.backward_prefetch, + mixed_precision=self.mixed_precision_config, + device_id=self.root_device.index, + **self._ddp_kwargs, + ): + yield + + def reduce( + self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + ) -> Tensor: + if isinstance(tensor, Tensor): + tensor = _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def barrier(self, *args: Any, **kwargs: Any) -> None: + if not _distributed_available(): + return + if torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=[self.root_device.index]) + else: + torch.distributed.barrier() + + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + obj = [obj] + if self.global_rank != src: + obj = [None] # type: ignore[list-item] + torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] + + @classmethod + def register_strategies(cls, strategy_registry: Dict) -> None: + if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available(): + return + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload + + strategy_registry.register( + "fsdp", + cls, + description="Fully Sharded Data Parallel training from torch.distributed.", + ) + strategy_registry.register( + "fsdp_full_shard_offload", + cls, + description="Native FSDP with Full Sharding and CPU Offloading", + cpu_offload=CPUOffload(offload_params=True), + 
) + + def _setup_distributed(self) -> None: + reset_seed() + self._set_world_ranks() + rank_zero_only.rank = self.global_rank + self._process_group_backend = self._get_process_group_backend() + assert self.cluster_environment is not None + _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + + def _get_process_group_backend(self) -> str: + return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) + + def _set_world_ranks(self) -> None: + if self.cluster_environment is None: + return + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() + + +class _FSDPBackwardSyncControl(_BackwardSyncControl): + @contextmanager + def no_backward_sync(self, module: Module) -> Generator: + """Blocks gradient synchronization inside the + :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper.""" + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + + if not isinstance(module, FullyShardedDataParallel): + raise TypeError( + "Blocking backward sync is only possible if the module passed to" + f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`." + f" Got: {module.__class__.__name__}." + ) + with module.no_sync(): + yield diff --git a/src/lightning_lite/strategies/parallel.py b/src/lightning_lite/strategies/parallel.py index 85243f7406cce..7cb25f54183e8 100644 --- a/src/lightning_lite/strategies/parallel.py +++ b/src/lightning_lite/strategies/parallel.py @@ -76,10 +76,10 @@ def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> No @property def distributed_sampler_kwargs(self) -> Dict[str, Any]: - distributed_sampler_kwargs = dict( - num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, rank=self.global_rank + return dict( + num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, + rank=self.global_rank, ) - return distributed_sampler_kwargs def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: """Perform a all_gather on all processes.""" diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index f4b91dbfbfb96..29a089a577a3f 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -132,8 +132,7 @@ def num_processes(self) -> int: @property def distributed_sampler_kwargs(self) -> Dict[str, Any]: - distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) - return distributed_sampler_kwargs + return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) @property def _is_single_process_single_device(self) -> bool: diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index e7df64c2acc10..465c65bfa7539 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -599,8 +599,7 @@ def _initialize_deepspeed_inference(self, model: Module) -> None: @property def distributed_sampler_kwargs(self) -> Dict[str, int]: - distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) - return distributed_sampler_kwargs + return 
dict(num_replicas=self.world_size, rank=self.global_rank) def setup_optimizers(self, trainer: "pl.Trainer") -> None: """Creates optimizers and schedulers. diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py index dc3f96eecf3fd..c48f37f71be0a 100644 --- a/src/pytorch_lightning/strategies/parallel.py +++ b/src/pytorch_lightning/strategies/parallel.py @@ -77,10 +77,10 @@ def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> No @property def distributed_sampler_kwargs(self) -> Dict[str, Any]: - distributed_sampler_kwargs = dict( - num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, rank=self.global_rank + return dict( + num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, + rank=self.global_rank, ) - return distributed_sampler_kwargs def reconciliate_processes(self, trace: str) -> None: """Function to re-conciliate processes on failure.""" diff --git a/tests/tests_lite/plugins/environments/test_slurm.py b/tests/tests_lite/plugins/environments/test_slurm.py index 768e1f468da99..c7e9cb7cf637f 100644 --- a/tests/tests_lite/plugins/environments/test_slurm.py +++ b/tests/tests_lite/plugins/environments/test_slurm.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import os +import shutil import sys from unittest import mock @@ -120,6 +121,7 @@ def test_detect(): @RunIf(skip_windows=True) +@pytest.mark.skipif(shutil.which("srun") is not None, reason="must run on a machine where srun is not available") def test_srun_available_and_not_used(monkeypatch): """Test that a warning is emitted if Lightning suspects the user forgot to run their script with `srun`.""" monkeypatch.setattr(sys, "argv", ["train.py", "--lr", "0.01"]) diff --git a/tests/tests_lite/plugins/precision/test_fsdp.py b/tests/tests_lite/plugins/precision/test_fsdp.py new file mode 100644 index 0000000000000..03a5fcfe33463 --- /dev/null +++ b/tests/tests_lite/plugins/precision/test_fsdp.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest import mock + +import pytest +import torch +from tests_lite.helpers.runif import RunIf + +from lightning_lite.plugins import FSDPPrecision + + +@mock.patch("lightning_lite.plugins.precision.fsdp._TORCH_GREATER_EQUAL_1_12", False) +def test_fsdp_precision_support(*_): + with pytest.raises(NotImplementedError, match="`FSDPPrecision` is supported from PyTorch v1.12.0"): + FSDPPrecision(precision=16, device="cuda") + + +@RunIf(min_torch="1.12", min_cuda_gpus=1) +@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) +def test_fsdp_precision_config(precision, expected): + plugin = FSDPPrecision(precision=precision, device="cuda") + config = plugin.mixed_precision_config + assert config.param_dtype == expected + assert config.buffer_dtype == expected + assert config.reduce_dtype == expected diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py new file mode 100644 index 0000000000000..8f609d53c253a --- /dev/null +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock +from unittest.mock import MagicMock, Mock + +import pytest +import torch +import torch.nn as nn +from tests_lite.helpers.runif import RunIf +from torch.optim import Adam + +from lightning_lite.strategies import FSDPStrategy +from lightning_lite.strategies.fsdp import _FSDPBackwardSyncControl +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 + +if _TORCH_GREATER_EQUAL_1_12: + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision + + +@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False) +def test_fsdp_support(*_): + with pytest.raises(NotImplementedError, match="`FSDPStrategy` is supported from PyTorch v1.12.0"): + FSDPStrategy() + + +@RunIf(min_torch="1.12") +def test_fsdp_custom_mixed_precision(*_): + """Test that passing a custom mixed precision config works.""" + config = MixedPrecision() + strategy = FSDPStrategy(mixed_precision=config) + assert strategy.mixed_precision_config == config + + +@RunIf(min_torch="1.12") +def test_fsdp_setup_optimizer_validation(): + """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters.""" + module = nn.Linear(2, 2) + strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")]) + + bad_optimizer = Adam([{"params": [module.weight]}, {"params": [module.bias], "lr": 1e-3}]) + with pytest.raises(ValueError, match="does not support multiple param groups"): + strategy.setup_optimizer(bad_optimizer) + + bad_optimizer = Adam(module.parameters()) + with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"): + strategy.setup_optimizer(bad_optimizer) + + +@RunIf(min_torch="1.12") +def test_fsdp_no_backward_sync(): + """Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in + 
FullyShardedDataParallel.""" + + strategy = FSDPStrategy() + assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl) + + with pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`" + ): + with strategy._backward_sync_control.no_backward_sync(Mock()): + pass + + module = MagicMock(spec=FullyShardedDataParallel) + with strategy._backward_sync_control.no_backward_sync(module): + pass + + module.no_sync.assert_called_once() diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py new file mode 100644 index 0000000000000..052133e265e4c --- /dev/null +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -0,0 +1,118 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile + +import pytest +import torch +from tests_lite.helpers.models import RandomDataset +from tests_lite.helpers.runif import RunIf +from torch.utils.data import DataLoader + +from lightning_lite import LightningLite +from lightning_lite.plugins import FSDPPrecision +from lightning_lite.strategies import FSDPStrategy +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 + +if _TORCH_GREATER_EQUAL_1_12: + from torch.distributed.fsdp import FullyShardedDataParallel + from torch.distributed.fsdp.wrap import wrap + + +def _get_model(manual_wrapping=False): + model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + if not manual_wrapping: + return model + + for i, layer in enumerate(model): + if i % 2 == 0: + model[i] = wrap(layer) + return model + + +def _step(lite, model, batch): + forward_module = model._forward_module + original_module = model.module + assert isinstance(forward_module, FullyShardedDataParallel) + assert isinstance(lite._precision, FSDPPrecision) + + precision = torch.float16 if lite._precision.precision == 16 else torch.bfloat16 + assert forward_module.mixed_precision.param_dtype == precision + assert forward_module.mixed_precision.reduce_dtype == precision + assert forward_module.mixed_precision.buffer_dtype == precision + + for layer_num in [0, 2]: + assert isinstance(original_module[layer_num], FullyShardedDataParallel) + assert original_module[layer_num].mixed_precision.param_dtype == precision + assert original_module[layer_num].mixed_precision.reduce_dtype == precision + assert original_module[layer_num].mixed_precision.buffer_dtype == precision + + output = model(batch) + loss = torch.nn.functional.mse_loss(output, torch.ones_like(output)) + return loss + + +def _assert_save_equality(lite, model, ckpt_path): + current_state_dict = lite._strategy.get_module_state_dict(model) + + checkpoint = lite.load(ckpt_path) + loaded_model = _get_model() + loaded_model.load_state_dict(checkpoint) + + # model parameters are identical after loading + for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()): + assert 
torch.equal(current_param.float().cpu(), loaded_param.cpu()) + + +def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool: + return unwrapped_params >= 2 + + +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") +@pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) +@pytest.mark.parametrize("manual_wrapping", [True, False]) +def test_fsdp_train_save_load(manual_wrapping, precision): + """Test FSDP training, saving and loading with different wrapping and precision settings.""" + strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy) + lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) + lite.launch() + + with lite.sharded_model(): + model = _get_model(manual_wrapping) + + dataloader = DataLoader(RandomDataset(32, 64)) + + # model needs to be set up first in FSDP + model = lite.setup_module(model) + + # get parameters on the wrapped model + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + # optimizer needs to be set up independently + optimizer = lite.setup_optimizers(optimizer) + + dataloader = lite.setup_dataloaders(dataloader) + model.train() + + data_iter = iter(dataloader) + batch = next(data_iter) + loss = _step(lite, model, batch) + lite.backward(loss) + optimizer.step() + optimizer.zero_grad() + + with tempfile.TemporaryFile() as ckpt_path: + ckpt_path = lite.broadcast(str(ckpt_path)) + lite._strategy.save_checkpoint(model.state_dict(), ckpt_path) + + _assert_save_equality(lite, model, ckpt_path) diff --git a/tests/tests_lite/strategies/test_registry.py b/tests/tests_lite/strategies/test_registry.py index 93c0071d9cd47..81a49eec08934 100644 --- a/tests/tests_lite/strategies/test_registry.py +++ b/tests/tests_lite/strategies/test_registry.py @@ -13,6 +13,7 @@ # limitations under the License. from lightning_lite.strategies import STRATEGY_REGISTRY +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 def test_strategy_registry_with_new_strategy(): @@ -41,7 +42,7 @@ def __init__(self, param1, param2): def test_available_strategies_in_registry(): - assert set(STRATEGY_REGISTRY.available_strategies()) == { + expected = { "ddp_sharded_find_unused_parameters_false", "ddp_sharded", "ddp_find_unused_parameters_false", @@ -66,3 +67,6 @@ def test_available_strategies_in_registry(): "xla", "dp", } + if _TORCH_GREATER_EQUAL_1_12: + expected |= {"fsdp", "fsdp_full_shard_offload"} + assert set(STRATEGY_REGISTRY.available_strategies()) == expected diff --git a/tests/tests_lite/test_connector.py b/tests/tests_lite/test_connector.py index d48a3cf1cac1d..072ecdfe99af3 100644 --- a/tests/tests_lite/test_connector.py +++ b/tests/tests_lite/test_connector.py @@ -884,3 +884,10 @@ def test_arguments_from_environment_collision(): ValueError, match=escape("Your code has `LightningLite(precision=64, ...)` but it conflicts") ): _Connector(precision=64) + + +@RunIf(min_torch="1.12") +def test_fsdp_unsupported_on_cpu(): + """Test that we raise an error if attempting to run FSDP without GPU.""" + with pytest.raises(ValueError, match="You selected the FSDP strategy but FSDP is only available on GPU"): + _Connector(strategy="fsdp")
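
Usage sketch (not part of the patch): the snippet below shows how the FSDP support added in this diff is meant to be driven from Lightning Lite, mirroring the integration test above. It assumes a machine with two CUDA GPUs and PyTorch 1.12 or newer; the model, dataset, and hyperparameters are illustrative placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning_lite import LightningLite
from lightning_lite.strategies import FSDPStrategy

# Select the strategy explicitly, or via the registered aliases "fsdp" / "fsdp_full_shard_offload".
lite = LightningLite(accelerator="cuda", devices=2, strategy=FSDPStrategy(), precision=16)
lite.launch()

# FSDP shards the parameters when the module is set up, so the module must be set up before the
# optimizer is created: `setup_module()` / `setup_optimizers()` replace the joint `setup()` call,
# which raises a RuntimeError for this strategy.
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
model = lite.setup_module(model)  # wrapped in FullyShardedDataParallel

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # parameters of the wrapped module
optimizer = lite.setup_optimizers(optimizer)

dataloader = lite.setup_dataloaders(DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8))

model.train()
for (batch,) in dataloader:
    optimizer.zero_grad()
    output = model(batch)
    loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
    lite.backward(loss)
    optimizer.step()

With `precision=16` the connector picks `FSDPPrecision`, which configures FSDP's `MixedPrecision` dtypes and a `ShardedGradScaler`; selecting the strategy without a GPU accelerator raises a ValueError, as covered by `test_fsdp_unsupported_on_cpu`.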