From 57774edfc957ab410405e2cf252745d33166a8fb Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 1 Oct 2022 11:30:40 +0200 Subject: [PATCH 01/84] wip --- src/lightning_lite/connector.py | 1 + src/lightning_lite/strategies/ddp.py | 3 +- src/lightning_lite/strategies/ddp_spawn.py | 3 +- src/lightning_lite/strategies/deepspeed.py | 3 +- src/lightning_lite/strategies/fsdp.py | 436 ++++++++++++++++++ src/lightning_lite/strategies/parallel.py | 6 +- src/pytorch_lightning/strategies/ddp.py | 3 +- src/pytorch_lightning/strategies/deepspeed.py | 3 +- src/pytorch_lightning/strategies/parallel.py | 6 +- 9 files changed, 448 insertions(+), 16 deletions(-) create mode 100644 src/lightning_lite/strategies/fsdp.py diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py index 3e9a7560d6472..8f63653b3586f 100644 --- a/src/lightning_lite/connector.py +++ b/src/lightning_lite/connector.py @@ -53,6 +53,7 @@ XLAStrategy, ) from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES +from lightning_lite.strategies.fsdp import FSDPStrategy from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn from lightning_lite.utilities.device_parser import determine_root_gpu_device from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE diff --git a/src/lightning_lite/strategies/ddp.py b/src/lightning_lite/strategies/ddp.py index bd229be91934b..020c13b7c6b12 100644 --- a/src/lightning_lite/strategies/ddp.py +++ b/src/lightning_lite/strategies/ddp.py @@ -85,8 +85,7 @@ def num_processes(self) -> int: @property def distributed_sampler_kwargs(self) -> Dict[str, Any]: - distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) - return distributed_sampler_kwargs + return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) @property def process_group_backend(self) -> Optional[str]: diff --git a/src/lightning_lite/strategies/ddp_spawn.py b/src/lightning_lite/strategies/ddp_spawn.py index def19d4ac0f24..44acb7e25a42a 100644 --- a/src/lightning_lite/strategies/ddp_spawn.py +++ b/src/lightning_lite/strategies/ddp_spawn.py @@ -92,8 +92,7 @@ def num_processes(self) -> int: @property def distributed_sampler_kwargs(self) -> Dict[str, int]: - distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) - return distributed_sampler_kwargs + return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) @property def process_group_backend(self) -> Optional[str]: diff --git a/src/lightning_lite/strategies/deepspeed.py b/src/lightning_lite/strategies/deepspeed.py index da532c5a567fd..10ca61b4ad31d 100644 --- a/src/lightning_lite/strategies/deepspeed.py +++ b/src/lightning_lite/strategies/deepspeed.py @@ -292,8 +292,7 @@ def zero_stage_3(self) -> bool: @property def distributed_sampler_kwargs(self) -> Dict[str, int]: - distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) - return distributed_sampler_kwargs + return dict(num_replicas=self.world_size, rank=self.global_rank) @property def model(self) -> "deepspeed.DeepSpeedEngine": diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py new file mode 100644 index 0000000000000..5b237647c22be --- /dev/null +++ b/src/lightning_lite/strategies/fsdp.py @@ -0,0 +1,436 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +from datetime import timedelta +from typing import Any, Dict, Generator, List, Optional, Union + +import torch +from torch import Tensor +from torch.distributed import default_pg_timeout +from torch.nn import Module + +import pytorch_lightning as pl +from lightning_lite.accelerators import Accelerator +from lightning_lite.plugins import CheckpointIO, ClusterEnvironment +from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device +from lightning_lite.utilities.distributed import group as _group +from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available +from lightning_lite.utilities.optimizer import optimizers_to_device +from lightning_lite.utilities.seed import reset_seed +from lightning_lite.plugins import Precision +from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin +from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher +from lightning_lite.strategies.parallel import ParallelStrategy +from pytorch_lightning.strategies.strategy import TBroadcast +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only +from pytorch_lightning.utilities.types import ProcessGroup, STEP_OUTPUT + +_distributed_available = torch.distributed.is_available() +_fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available +if _fsdp_available: + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + FullyShardedDataParallel, + MixedPrecision, + ) + from torch.distributed.fsdp.wrap import enable_wrap +else: + FullyShardedDataParallel = None # type: ignore[misc,assignment] + MixedPrecision = None # type: ignore[misc,assignment] + BackwardPrefetch = None # type: ignore[misc,assignment] + CPUOffload = None # type: ignore[misc,assignment] + +if _distributed_available: + from torch.distributed.distributed_c10d import _get_default_group + +log = logging.getLogger(__name__) + + +class FSDPStrategy(ParallelStrategy): + r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed. + + .. warning:: ``DDPFullyShardedNativeStrategy`` is in BETA and subject to change. The interface can + bring breaking changes and new features with the next release of PyTorch. + + Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model + size, whilst using efficient communication to reduce overhead. In practice, this means we can remain + at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar + to ZeRO-Stage 3. + + For more information `check out `__. + + Defaults have been set and options have been exposed, but may require configuration + based on your level of memory/speed efficiency. We suggest having a look at + `this tutorial `__ for more information. + + Arguments: + cpu_offload: + CPU offloading config. Currently, only parameter and gradient CPU + offload is supported. It can be enabled via passing in + ``cpu_offload=CPUOffload(offload_params=True)``. Note that this + currently implicitly enables gradient offloading to CPU in order for + params and grads to be on same device to work with optimizer. This + API is subject to change. Default is ``None`` in which case there + will be no offloading. + backward_prefetch: + This is an experimental feature that is subject to change in the + the near future. It allows users to enable two different backward_prefetch + algorithms to help backward communication and computation overlapping. + The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. + mixed_precision: + Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` + or BF16 if ``precision=bf16`` unless a config is passed in. + This is only available in PyTorch 1.12 and later. + \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules. + + """ + + strategy_name = "fsdp_native" + _registered_strategies: List[str] = [] + + def __init__( + self, + accelerator: Optional[Accelerator] = None, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[Precision] = None, + process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, + cpu_offload: Optional[CPUOffload] = None, + backward_prefetch: Optional[BackwardPrefetch] = None, + mixed_precision: Optional[MixedPrecision] = None, + **kwargs: Any, + ) -> None: + if not _TORCH_GREATER_EQUAL_1_12: + raise MisconfigurationException( + "`FSDPStrategy` is supported from PyTorch v1.12.0 onwards." + ) + + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) + self._num_nodes = 1 + self._process_group_backend: Optional[str] = process_group_backend + self._timeout: Optional[timedelta] = timeout + self._ddp_kwargs = kwargs + + self.cpu_offload = cpu_offload + self.backward_prefetch = backward_prefetch + self.mixed_precision = mixed_precision + + @property + def root_device(self) -> torch.device: + assert self.parallel_devices is not None + return self.parallel_devices[self.local_rank] + + @property + def is_distributed(self) -> bool: + return True + + @property + def num_nodes(self) -> int: + return self._num_nodes + + @num_nodes.setter + def num_nodes(self, num_nodes: int) -> None: + self._num_nodes = num_nodes + + @property + def num_processes(self) -> int: + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + + @property + def distributed_sampler_kwargs(self) -> Dict: + return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) + + @property + def process_group_backend(self) -> Optional[str]: + return self._process_group_backend + + # @property + # def process_group(self) -> Optional[ProcessGroup]: + # if self._process_group is None: + # # The strategy should have already initilized process group in setup_environment() + # self._process_group = _get_default_group() + # return self._process_group + + # @property + # def mixed_precision_config(self) -> Optional[MixedPrecision]: + # if self.mixed_precision: + # return self.mixed_precision + # plugin = self.precision_plugin + # if isinstance(plugin, FullyShardedNativeNativeMixedPrecisionPlugin): + # return plugin.mixed_precision_config + + def _configure_launcher(self) -> None: + assert self.cluster_environment is not None + if not self.cluster_environment.creates_processes_externally: + self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + + def setup_environment(self) -> None: + self._setup_distributed() + super().setup_environment() + + def setup_module(self, module: Module) -> FullyShardedDataParallel: + """Wraps the model into a + :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" + if ( + any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()) + and "auto_wrap_policy" in self._ddp_kwargs + ): + # If model is already wrapped, we need to avoid sending the `auto_wrap_policy` + del self._ddp_kwargs["auto_wrap_policy"] + return FullyShardedDataParallel( + module=module, + process_group=self.process_group, + cpu_offload=self.cpu_offload, + backward_prefetch=self.backward_prefetch, + mixed_precision=self.mixed_precision_config, + device_id=self.root_device.index, + **self._ddp_kwargs, + ) + + def module_to_device(self, module: Module) -> None: + pass + + + # + # def reduce( + # self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + # ) -> Tensor: + # """Reduces a tensor from several distributed processes to one aggregated tensor. + # + # Args: + # tensor: the tensor to sync and reduce + # group: the process group to gather results from. Defaults to all processes (world) + # reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + # Can also be a string 'sum' to calculate the sum during reduction. + # + # Return: + # reduced value, except when the input was not a tensor the output remains is unchanged + # """ + # if isinstance(tensor, Tensor): + # tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + # return tensor + # + # def barrier(self, *args: Any, **kwargs: Any) -> None: + # if not distributed_available(): + # return + # if torch.distributed.get_backend() == "nccl": + # torch.distributed.barrier(device_ids=self._determine_ddp_device_ids()) + # else: + # torch.distributed.barrier() + # + # def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + # obj = [obj] + # if self.global_rank != src: + # obj = [None] # type: ignore[list-item] + # torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + # return obj[0] + @classmethod + def register_strategies(cls, strategy_registry: Dict) -> None: + strategy_registry.register( + "fsdp", + cls, + description="Fully Sharded Data Parallel training from torch.distributed.", + ) + strategy_registry.register( + "fsdp_full_shard_offload", + cls, + description="Native FSDP with Full Sharding and CPU Offloading", + cpu_offload=CPUOffload(offload_params=True), + ) + + def _setup_distributed(self) -> None: + reset_seed() + self._set_world_ranks() + rank_zero_only.rank = self.global_rank + self._process_group_backend = self._get_process_group_backend() + assert self.cluster_environment is not None + init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + + def _get_process_group_backend(self) -> str: + return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device) + + def _set_world_ranks(self) -> None: + if self.cluster_environment is None: + return + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() + + # def _determine_ddp_device_ids(self) -> Optional[List[int]]: + # if self.root_device.type == "cpu": + # return None + # return [self.root_device.index] + + + +# --- + + + + + + + + + + + + + + + + + + + def setup(self, trainer: "pl.Trainer") -> None: + assert self.accelerator is not None + self.accelerator.setup(trainer) + # share ddp pids to all processes + self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) + + if trainer.state.fn == TrainerFn.FITTING and self._layer_sync: + assert self.model is not None + self.model = self._layer_sync.apply(self.model) + + # we set the device so that optimizers can be created with distributed comms. + assert self.lightning_module is not None + self.lightning_module._device = self.root_device + + assert isinstance(self.model, pl.LightningModule) + self.model = _LightningModuleWrapperBase(self.model) + if is_overridden("configure_sharded_model", self.lightning_module): + rank_zero_info( + "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers" + " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`." + ) + else: + self.model = self._setup_model(self.model) + self.barrier() + + self.setup_optimizers(trainer) + optimizers_to_device(self.optimizers, self.root_device) + + self.setup_precision_plugin() + + @contextlib.contextmanager + def model_sharded_context(self) -> Generator: + log.detail(f"{self.__class__.__name__}: entered model_sharded_context.") + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + process_group=self.process_group, + cpu_offload=self.cpu_offload, + backward_prefetch=self.backward_prefetch, + mixed_precision=self.mixed_precision_config, + device_id=self.root_device.index, + **self.kwargs, + ): + yield + + def barrier(self, name: Optional[str] = None) -> None: + if not _distributed_available: + return + if torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=self._determine_device_ids()) + else: + torch.distributed.barrier() + + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + obj = [obj] + if self.global_rank != src: + obj = [None] # type: ignore + torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] + + def reduce( + self, + tensor: Union[Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Tensor: + """Reduces a tensor from several distributed processes to one aggregated tensor. + + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to 'mean'/'avg'. + Can also be a string 'sum' to calculate the sum during reduction. + + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + # we don't need precision context since casting is done by FSDP + # read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html + assert self.model is not None + return self.model(*args, **kwargs) + + def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + assert self.model is not None + return self.model(*args, **kwargs) + + def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + assert self.model is not None + return self.model(*args, **kwargs) + + def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + assert self.model is not None + return self.model(*args, **kwargs) + + def _determine_device_ids(self) -> List[int]: + return [self.root_device.index] + + def teardown(self) -> None: + rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...") + + pl_module = self.lightning_module + if ( + pl_module is not None + # `self.lightning_module._trainer` can be None if teardown gets called on an exception before + # the trainer gets set on the LightningModule + and pl_module._trainer is not None + and pl_module._trainer.state.fn == TrainerFn.FITTING + and self._layer_sync + ): + assert self.model is not None + self.model = self._layer_sync.revert(self.model) + + assert self.cluster_environment is not None + assert self.accelerator is not None + self.cluster_environment.teardown() + self.precision_plugin.teardown() + self.accelerator.teardown() + + @classmethod + def get_registered_strategies(cls) -> List[str]: + return cls._registered_strategies + + diff --git a/src/lightning_lite/strategies/parallel.py b/src/lightning_lite/strategies/parallel.py index 2036c7943049b..19b1e631ca5e3 100644 --- a/src/lightning_lite/strategies/parallel.py +++ b/src/lightning_lite/strategies/parallel.py @@ -78,10 +78,10 @@ def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> No @property def distributed_sampler_kwargs(self) -> Dict[str, Any]: - distributed_sampler_kwargs = dict( - num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, rank=self.global_rank + return dict( + num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, + rank=self.global_rank, ) - return distributed_sampler_kwargs def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: """Perform a all_gather on all processes.""" diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index 5f28908341a23..05b7401954f7e 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -131,8 +131,7 @@ def num_processes(self) -> int: @property def distributed_sampler_kwargs(self) -> Dict[str, Any]: - distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) - return distributed_sampler_kwargs + return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) @property def _is_single_process_single_device(self) -> bool: diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 82f0aed9d1366..a31b31a593948 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -596,8 +596,7 @@ def _initialize_deepspeed_inference(self, model: Module) -> None: @property def distributed_sampler_kwargs(self) -> Dict[str, int]: - distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) - return distributed_sampler_kwargs + return dict(num_replicas=self.world_size, rank=self.global_rank) def setup_optimizers(self, trainer: "pl.Trainer") -> None: """Creates optimizers and schedulers. diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py index 9975b7464f393..4b6ba1faf33e9 100644 --- a/src/pytorch_lightning/strategies/parallel.py +++ b/src/pytorch_lightning/strategies/parallel.py @@ -77,10 +77,10 @@ def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> No @property def distributed_sampler_kwargs(self) -> Dict[str, Any]: - distributed_sampler_kwargs = dict( - num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, rank=self.global_rank + return dict( + num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, + rank=self.global_rank, ) - return distributed_sampler_kwargs def reconciliate_processes(self, trace: str) -> None: """Function to re-conciliate processes on failure.""" From 043783e5f3233a7e8e7399bf92273ddc5d143fc0 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 1 Oct 2022 12:00:11 +0200 Subject: [PATCH 02/84] wip precision --- src/lightning_lite/plugins/precision/fsdp.py | 55 +++++ src/lightning_lite/strategies/fsdp.py | 233 +++---------------- 2 files changed, 91 insertions(+), 197 deletions(-) create mode 100644 src/lightning_lite/plugins/precision/fsdp.py diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py new file mode 100644 index 0000000000000..e211934a1ca9b --- /dev/null +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +import torch + +from lightning_lite.utilities.enums import PrecisionType +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 + +if _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available(): + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +else: + MixedPrecision = None # type: ignore[misc,assignment] + ShardedGradScaler = None # type: ignore[misc,assignment] + + +class FSDPPrecision(NativeMixedPrecisionPlugin): + """AMP for Fully Sharded Data Parallel training.""" + + def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None: + if not _TORCH_GREATER_EQUAL_1_12: + raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") + super().__init__( + precision=precision, + device=device, + scaler=(ShardedGradScaler() if scaler is None and precision == 16 else None), + ) + + @property + def mixed_precision_config(self) -> Optional[MixedPrecision]: + if self.precision == PrecisionType.HALF: + dtype = torch.float16 + elif self.precision == PrecisionType.BFLOAT: + dtype = torch.bfloat16 + else: + raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.") + return MixedPrecision( + param_dtype=dtype, + reduce_dtype=dtype, + buffer_dtype=dtype, + ) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 5b237647c22be..bdca66ed3c690 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import logging +from contextlib import contextmanager from datetime import timedelta from typing import Any, Dict, Generator, List, Optional, Union @@ -24,7 +24,7 @@ import pytorch_lightning as pl from lightning_lite.accelerators import Accelerator from lightning_lite.plugins import CheckpointIO, ClusterEnvironment -from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device +from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device, distributed_available from lightning_lite.utilities.distributed import group as _group from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from lightning_lite.utilities.optimizer import optimizers_to_device @@ -169,13 +169,6 @@ def distributed_sampler_kwargs(self) -> Dict: def process_group_backend(self) -> Optional[str]: return self._process_group_backend - # @property - # def process_group(self) -> Optional[ProcessGroup]: - # if self._process_group is None: - # # The strategy should have already initilized process group in setup_environment() - # self._process_group = _get_default_group() - # return self._process_group - # @property # def mixed_precision_config(self) -> Optional[MixedPrecision]: # if self.mixed_precision: @@ -204,7 +197,6 @@ def setup_module(self, module: Module) -> FullyShardedDataParallel: del self._ddp_kwargs["auto_wrap_policy"] return FullyShardedDataParallel( module=module, - process_group=self.process_group, cpu_offload=self.cpu_offload, backward_prefetch=self.backward_prefetch, mixed_precision=self.mixed_precision_config, @@ -215,40 +207,41 @@ def setup_module(self, module: Module) -> FullyShardedDataParallel: def module_to_device(self, module: Module) -> None: pass + @contextmanager + def module_sharded_context(self) -> Generator: + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + # process_group=self.process_group, + cpu_offload=self.cpu_offload, + backward_prefetch=self.backward_prefetch, + mixed_precision=self.precision_plugin.mixed_precision_config, + device_id=self.root_device.index, + **self._ddp_kwargs, + ): + yield + + def reduce( + self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + ) -> Tensor: + if isinstance(tensor, Tensor): + tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def barrier(self, *args: Any, **kwargs: Any) -> None: + if not distributed_available(): + return + if torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=[self.root_device.index]) + else: + torch.distributed.barrier() + + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + obj = [obj] + if self.global_rank != src: + obj = [None] # type: ignore[list-item] + torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] - # - # def reduce( - # self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - # ) -> Tensor: - # """Reduces a tensor from several distributed processes to one aggregated tensor. - # - # Args: - # tensor: the tensor to sync and reduce - # group: the process group to gather results from. Defaults to all processes (world) - # reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - # Can also be a string 'sum' to calculate the sum during reduction. - # - # Return: - # reduced value, except when the input was not a tensor the output remains is unchanged - # """ - # if isinstance(tensor, Tensor): - # tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - # return tensor - # - # def barrier(self, *args: Any, **kwargs: Any) -> None: - # if not distributed_available(): - # return - # if torch.distributed.get_backend() == "nccl": - # torch.distributed.barrier(device_ids=self._determine_ddp_device_ids()) - # else: - # torch.distributed.barrier() - # - # def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - # obj = [obj] - # if self.global_rank != src: - # obj = [None] # type: ignore[list-item] - # torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) - # return obj[0] @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: strategy_registry.register( @@ -280,157 +273,3 @@ def _set_world_ranks(self) -> None: self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) rank_zero_only.rank = self.cluster_environment.global_rank() - - # def _determine_ddp_device_ids(self) -> Optional[List[int]]: - # if self.root_device.type == "cpu": - # return None - # return [self.root_device.index] - - - -# --- - - - - - - - - - - - - - - - - - - - def setup(self, trainer: "pl.Trainer") -> None: - assert self.accelerator is not None - self.accelerator.setup(trainer) - # share ddp pids to all processes - self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) - - if trainer.state.fn == TrainerFn.FITTING and self._layer_sync: - assert self.model is not None - self.model = self._layer_sync.apply(self.model) - - # we set the device so that optimizers can be created with distributed comms. - assert self.lightning_module is not None - self.lightning_module._device = self.root_device - - assert isinstance(self.model, pl.LightningModule) - self.model = _LightningModuleWrapperBase(self.model) - if is_overridden("configure_sharded_model", self.lightning_module): - rank_zero_info( - "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers" - " are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`." - ) - else: - self.model = self._setup_model(self.model) - self.barrier() - - self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) - - self.setup_precision_plugin() - - @contextlib.contextmanager - def model_sharded_context(self) -> Generator: - log.detail(f"{self.__class__.__name__}: entered model_sharded_context.") - with enable_wrap( - wrapper_cls=FullyShardedDataParallel, - process_group=self.process_group, - cpu_offload=self.cpu_offload, - backward_prefetch=self.backward_prefetch, - mixed_precision=self.mixed_precision_config, - device_id=self.root_device.index, - **self.kwargs, - ): - yield - - def barrier(self, name: Optional[str] = None) -> None: - if not _distributed_available: - return - if torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=self._determine_device_ids()) - else: - torch.distributed.barrier() - - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - obj = [obj] - if self.global_rank != src: - obj = [None] # type: ignore - torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) - return obj[0] - - def reduce( - self, - tensor: Union[Tensor, Any], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", - ) -> Tensor: - """Reduces a tensor from several distributed processes to one aggregated tensor. - - Args: - tensor: the tensor to sync and reduce - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to 'mean'/'avg'. - Can also be a string 'sum' to calculate the sum during reduction. - - Return: - reduced value, except when the input was not a tensor the output remains is unchanged - """ - if isinstance(tensor, Tensor): - tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - return tensor - - def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - # we don't need precision context since casting is done by FSDP - # read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html - assert self.model is not None - return self.model(*args, **kwargs) - - def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - assert self.model is not None - return self.model(*args, **kwargs) - - def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - assert self.model is not None - return self.model(*args, **kwargs) - - def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - assert self.model is not None - return self.model(*args, **kwargs) - - def _determine_device_ids(self) -> List[int]: - return [self.root_device.index] - - def teardown(self) -> None: - rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...") - - pl_module = self.lightning_module - if ( - pl_module is not None - # `self.lightning_module._trainer` can be None if teardown gets called on an exception before - # the trainer gets set on the LightningModule - and pl_module._trainer is not None - and pl_module._trainer.state.fn == TrainerFn.FITTING - and self._layer_sync - ): - assert self.model is not None - self.model = self._layer_sync.revert(self.model) - - assert self.cluster_environment is not None - assert self.accelerator is not None - self.cluster_environment.teardown() - self.precision_plugin.teardown() - self.accelerator.teardown() - - @classmethod - def get_registered_strategies(cls) -> List[str]: - return cls._registered_strategies - - From 0caf973e1e32dbca77696cfec3faf0fe2338eb6c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 1 Oct 2022 14:57:18 +0200 Subject: [PATCH 03/84] fsdp --- src/lightning_lite/plugins/precision/fsdp.py | 23 ++++---- src/lightning_lite/strategies/fsdp.py | 62 ++++++++------------ 2 files changed, 35 insertions(+), 50 deletions(-) diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py index e211934a1ca9b..626dec876b714 100644 --- a/src/lightning_lite/plugins/precision/fsdp.py +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -11,29 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Optional, TYPE_CHECKING, Literal import torch from lightning_lite.utilities.enums import PrecisionType -from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from lightning_lite.plugins.precision import NativeMixedPrecision +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 -if _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available(): +if TYPE_CHECKING: from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -else: - MixedPrecision = None # type: ignore[misc,assignment] - ShardedGradScaler = None # type: ignore[misc,assignment] -class FSDPPrecision(NativeMixedPrecisionPlugin): +class FSDPPrecision(NativeMixedPrecision): """AMP for Fully Sharded Data Parallel training.""" - def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None: + def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") + + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + super().__init__( precision=precision, device=device, @@ -41,7 +40,9 @@ def __init__(self, precision: Union[str, int], device: str, scaler: Optional[Sha ) @property - def mixed_precision_config(self) -> Optional[MixedPrecision]: + def mixed_precision_config(self) -> MixedPrecision: + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + if self.precision == PrecisionType.HALF: dtype = torch.float16 elif self.precision == PrecisionType.BFLOAT: diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index bdca66ed3c690..591df147d4e60 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -11,39 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Union, TYPE_CHECKING import torch from torch import Tensor from torch.distributed import default_pg_timeout from torch.nn import Module -import pytorch_lightning as pl from lightning_lite.accelerators import Accelerator from lightning_lite.plugins import CheckpointIO, ClusterEnvironment +from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device, distributed_available from lightning_lite.utilities.distributed import group as _group from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available -from lightning_lite.utilities.optimizer import optimizers_to_device from lightning_lite.utilities.seed import reset_seed from lightning_lite.plugins import Precision -from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin -from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher +from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning_lite.strategies.parallel import ParallelStrategy -from pytorch_lightning.strategies.strategy import TBroadcast -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only -from pytorch_lightning.utilities.types import ProcessGroup, STEP_OUTPUT - -_distributed_available = torch.distributed.is_available() -_fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available -if _fsdp_available: +from lightning_lite.strategies.strategy import TBroadcast +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from lightning_lite.utilities.rank_zero import rank_zero_only + +if TYPE_CHECKING: from torch.distributed.fsdp.fully_sharded_data_parallel import ( BackwardPrefetch, CPUOffload, @@ -51,16 +42,6 @@ MixedPrecision, ) from torch.distributed.fsdp.wrap import enable_wrap -else: - FullyShardedDataParallel = None # type: ignore[misc,assignment] - MixedPrecision = None # type: ignore[misc,assignment] - BackwardPrefetch = None # type: ignore[misc,assignment] - CPUOffload = None # type: ignore[misc,assignment] - -if _distributed_available: - from torch.distributed.distributed_c10d import _get_default_group - -log = logging.getLogger(__name__) class FSDPStrategy(ParallelStrategy): @@ -120,9 +101,7 @@ def __init__( **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: - raise MisconfigurationException( - "`FSDPStrategy` is supported from PyTorch v1.12.0 onwards." - ) + raise RuntimeError("`FSDPStrategy` is supported from PyTorch v1.12.0 onwards.") super().__init__( accelerator=accelerator, @@ -169,13 +148,13 @@ def distributed_sampler_kwargs(self) -> Dict: def process_group_backend(self) -> Optional[str]: return self._process_group_backend - # @property - # def mixed_precision_config(self) -> Optional[MixedPrecision]: - # if self.mixed_precision: - # return self.mixed_precision - # plugin = self.precision_plugin - # if isinstance(plugin, FullyShardedNativeNativeMixedPrecisionPlugin): - # return plugin.mixed_precision_config + @property + def mixed_precision_config(self) -> Optional[MixedPrecision]: + if self.mixed_precision: + return self.mixed_precision + plugin = self.precision_plugin + if isinstance(plugin, FSDPPrecision): + return plugin.mixed_precision_config def _configure_launcher(self) -> None: assert self.cluster_environment is not None @@ -189,6 +168,7 @@ def setup_environment(self) -> None: def setup_module(self, module: Module) -> FullyShardedDataParallel: """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel if ( any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()) and "auto_wrap_policy" in self._ddp_kwargs @@ -209,12 +189,14 @@ def module_to_device(self, module: Module) -> None: @contextmanager def module_sharded_context(self) -> Generator: + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + from torch.distributed.fsdp.wrap import enable_wrap + with enable_wrap( wrapper_cls=FullyShardedDataParallel, - # process_group=self.process_group, cpu_offload=self.cpu_offload, backward_prefetch=self.backward_prefetch, - mixed_precision=self.precision_plugin.mixed_precision_config, + mixed_precision=self.mixed_precision_config, device_id=self.root_device.index, **self._ddp_kwargs, ): @@ -244,6 +226,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload + strategy_registry.register( "fsdp", cls, From a2130b95469b9bb412dc3409a74e93fc10db8e97 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 1 Oct 2022 15:17:24 +0200 Subject: [PATCH 04/84] fsdp support in lite --- src/lightning_lite/connector.py | 16 +++++-- src/lightning_lite/plugins/precision/fsdp.py | 4 +- src/lightning_lite/strategies/fsdp.py | 50 ++++++++------------ 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py index 8f63653b3586f..c1372230e87d0 100644 --- a/src/lightning_lite/connector.py +++ b/src/lightning_lite/connector.py @@ -40,6 +40,7 @@ TorchElasticEnvironment, ) from lightning_lite.plugins.precision.double import DoublePrecision +from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.strategies import ( DDPShardedStrategy, DDPSpawnShardedStrategy, @@ -53,7 +54,7 @@ XLAStrategy, ) from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES -from lightning_lite.strategies.fsdp import FSDPStrategy +from lightning_lite.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn from lightning_lite.utilities.device_parser import determine_root_gpu_device from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE @@ -409,6 +410,13 @@ def _check_strategy_and_fallback(self) -> None: f"You selected `Lite(strategy='{strategy_flag}')` but process forking is not supported on this" f" platform. We recommed `Lite(strategy='ddp_spawn')` instead." ) + if ( + strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy) + ) and self._accelerator_flag not in ("cuda", "gpu"): + raise ValueError( + f"You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`" + " to continue or select a different strategy." + ) if strategy_flag: self._strategy_flag = strategy_flag @@ -457,9 +465,11 @@ def _check_and_init_precision(self) -> Precision: if self._precision_flag == 16 else "Using bfloat16 Automatic Mixed Precision (AMP)" ) - device = "cpu" if self._accelerator_flag == "cpu" else "cuda" - return NativeMixedPrecision(self._precision_flag, device) + + if isinstance(self.strategy, FSDPStrategy): + return FSDPPrecision(precision=self._precision_flag, device=device) + return NativeMixedPrecision(precision=self._precision_flag, device=device) raise RuntimeError("No precision set") diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py index 626dec876b714..ecc938d912122 100644 --- a/src/lightning_lite/plugins/precision/fsdp.py +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, TYPE_CHECKING, Literal +from typing import Literal, Optional, TYPE_CHECKING import torch -from lightning_lite.utilities.enums import PrecisionType from lightning_lite.plugins.precision import NativeMixedPrecision +from lightning_lite.utilities.enums import PrecisionType from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 if TYPE_CHECKING: diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 591df147d4e60..444d7fa34eafa 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, List, Optional, Union, TYPE_CHECKING +from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING, Union import torch from torch import Tensor @@ -21,18 +21,17 @@ from torch.nn import Module from lightning_lite.accelerators import Accelerator -from lightning_lite.plugins import CheckpointIO, ClusterEnvironment +from lightning_lite.plugins import CheckpointIO, ClusterEnvironment, Precision from lightning_lite.plugins.precision.fsdp import FSDPPrecision -from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device, distributed_available -from lightning_lite.utilities.distributed import group as _group -from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available -from lightning_lite.utilities.seed import reset_seed -from lightning_lite.plugins import Precision from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning_lite.strategies.parallel import ParallelStrategy from lightning_lite.strategies.strategy import TBroadcast +from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device +from lightning_lite.utilities.distributed import group as _group +from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from lightning_lite.utilities.rank_zero import rank_zero_only +from lightning_lite.utilities.seed import reset_seed if TYPE_CHECKING: from torch.distributed.fsdp.fully_sharded_data_parallel import ( @@ -43,11 +42,13 @@ ) from torch.distributed.fsdp.wrap import enable_wrap +_FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload") + class FSDPStrategy(ParallelStrategy): r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed. - .. warning:: ``DDPFullyShardedNativeStrategy`` is in BETA and subject to change. The interface can + .. warning:: ``FSDPStrategy`` is in BETA and subject to change. The interface can bring breaking changes and new features with the next release of PyTorch. Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model @@ -62,30 +63,20 @@ class FSDPStrategy(ParallelStrategy): `this tutorial `__ for more information. Arguments: - cpu_offload: - CPU offloading config. Currently, only parameter and gradient CPU - offload is supported. It can be enabled via passing in - ``cpu_offload=CPUOffload(offload_params=True)``. Note that this - currently implicitly enables gradient offloading to CPU in order for - params and grads to be on same device to work with optimizer. This - API is subject to change. Default is ``None`` in which case there + cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It + can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently + implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device + to work with the optimizer. This API is subject to change. Default is ``None`` in which case there will be no offloading. - backward_prefetch: - This is an experimental feature that is subject to change in the - the near future. It allows users to enable two different backward_prefetch - algorithms to help backward communication and computation overlapping. - The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. - mixed_precision: - Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` - or BF16 if ``precision=bf16`` unless a config is passed in. - This is only available in PyTorch 1.12 and later. - \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules. - + backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows + users to enable two different backward prefetching algorithms to help backward communication and + computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. + mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 + if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. + \**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class + when wrapping modules. """ - strategy_name = "fsdp_native" - _registered_strategies: List[str] = [] - def __init__( self, accelerator: Optional[Accelerator] = None, @@ -169,6 +160,7 @@ def setup_module(self, module: Module) -> FullyShardedDataParallel: """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + if ( any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()) and "auto_wrap_policy" in self._ddp_kwargs From 80d24fec91b83dd564b43d8d3d0c803cd36765d3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 1 Oct 2022 17:11:58 +0200 Subject: [PATCH 05/84] typing fixes --- src/lightning_lite/connector.py | 2 +- src/lightning_lite/plugins/precision/fsdp.py | 2 +- src/lightning_lite/strategies/fsdp.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py index c1372230e87d0..824e590c14868 100644 --- a/src/lightning_lite/connector.py +++ b/src/lightning_lite/connector.py @@ -414,7 +414,7 @@ def _check_strategy_and_fallback(self) -> None: strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy) ) and self._accelerator_flag not in ("cuda", "gpu"): raise ValueError( - f"You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`" + "You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`" " to continue or select a different strategy." ) if strategy_flag: diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py index ecc938d912122..4ef4bbfe168cf 100644 --- a/src/lightning_lite/plugins/precision/fsdp.py +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -27,7 +27,7 @@ class FSDPPrecision(NativeMixedPrecision): """AMP for Fully Sharded Data Parallel training.""" - def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None) -> None: + def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 444d7fa34eafa..d4e5d52cb4441 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -40,7 +40,7 @@ FullyShardedDataParallel, MixedPrecision, ) - from torch.distributed.fsdp.wrap import enable_wrap + from torch.distributed.fsdp.wrap import enable_wrap # noqa: F401 _FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload") @@ -86,9 +86,9 @@ def __init__( precision_plugin: Optional[Precision] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, - cpu_offload: Optional[CPUOffload] = None, - backward_prefetch: Optional[BackwardPrefetch] = None, - mixed_precision: Optional[MixedPrecision] = None, + cpu_offload: Optional["CPUOffload"] = None, + backward_prefetch: Optional["BackwardPrefetch"] = None, + mixed_precision: Optional["MixedPrecision"] = None, **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: @@ -156,7 +156,7 @@ def setup_environment(self) -> None: self._setup_distributed() super().setup_environment() - def setup_module(self, module: Module) -> FullyShardedDataParallel: + def setup_module(self, module: Module) -> "FullyShardedDataParallel": """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel From cc65718d947f7746970108bdd72a6fc637395eac Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 1 Oct 2022 17:16:01 +0200 Subject: [PATCH 06/84] imports --- src/lightning_lite/plugins/__init__.py | 1 + src/lightning_lite/plugins/precision/__init__.py | 1 + src/lightning_lite/plugins/precision/fsdp.py | 4 +++- src/lightning_lite/strategies/__init__.py | 1 + 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/lightning_lite/plugins/__init__.py b/src/lightning_lite/plugins/__init__.py index 54aa3a4e4e113..0d166904491be 100644 --- a/src/lightning_lite/plugins/__init__.py +++ b/src/lightning_lite/plugins/__init__.py @@ -18,6 +18,7 @@ from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision from lightning_lite.plugins.precision.double import DoublePrecision +from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision from lightning_lite.plugins.precision.precision import Precision from lightning_lite.plugins.precision.tpu import TPUPrecision diff --git a/src/lightning_lite/plugins/precision/__init__.py b/src/lightning_lite/plugins/precision/__init__.py index 412ef9274822c..c390edd8e36f2 100644 --- a/src/lightning_lite/plugins/precision/__init__.py +++ b/src/lightning_lite/plugins/precision/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision from lightning_lite.plugins.precision.double import DoublePrecision +from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision from lightning_lite.plugins.precision.precision import Precision from lightning_lite.plugins.precision.tpu import TPUPrecision diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py index 4ef4bbfe168cf..45f38838774a8 100644 --- a/src/lightning_lite/plugins/precision/fsdp.py +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -27,7 +27,9 @@ class FSDPPrecision(NativeMixedPrecision): """AMP for Fully Sharded Data Parallel training.""" - def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None) -> None: + def __init__( + self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None + ) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") diff --git a/src/lightning_lite/strategies/__init__.py b/src/lightning_lite/strategies/__init__.py index f9cf74e30e4c0..a8d235708b573 100644 --- a/src/lightning_lite/strategies/__init__.py +++ b/src/lightning_lite/strategies/__init__.py @@ -17,6 +17,7 @@ from lightning_lite.strategies.dp import DataParallelStrategy # noqa: F401 from lightning_lite.strategies.fairscale import DDPShardedStrategy # noqa: F401 from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy # noqa: F401 +from lightning_lite.strategies.fsdp import FSDPStrategy # noqa: F401 from lightning_lite.strategies.parallel import ParallelStrategy # noqa: F401 from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry from lightning_lite.strategies.single_device import SingleDeviceStrategy # noqa: F401 From 9d91f8645297831edf678ce7d43bfc545ac17687 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 1 Oct 2022 17:18:09 +0200 Subject: [PATCH 07/84] import fixes --- src/lightning_lite/plugins/precision/fsdp.py | 4 ++-- src/lightning_lite/strategies/fsdp.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py index 45f38838774a8..c508a2944c015 100644 --- a/src/lightning_lite/plugins/precision/fsdp.py +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -15,7 +15,7 @@ import torch -from lightning_lite.plugins.precision import NativeMixedPrecision +from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision from lightning_lite.utilities.enums import PrecisionType from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 @@ -42,7 +42,7 @@ def __init__( ) @property - def mixed_precision_config(self) -> MixedPrecision: + def mixed_precision_config(self) -> "MixedPrecision": from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision if self.precision == PrecisionType.HALF: diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index d4e5d52cb4441..c914b13790d87 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -140,7 +140,7 @@ def process_group_backend(self) -> Optional[str]: return self._process_group_backend @property - def mixed_precision_config(self) -> Optional[MixedPrecision]: + def mixed_precision_config(self) -> Optional["MixedPrecision"]: if self.mixed_precision: return self.mixed_precision plugin = self.precision_plugin From b535621fcd8ecb98197f3766ab30f45dd7aafafc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 1 Oct 2022 17:20:33 +0200 Subject: [PATCH 08/84] fix test --- tests/tests_lite/strategies/test_registry.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_lite/strategies/test_registry.py b/tests/tests_lite/strategies/test_registry.py index 93c0071d9cd47..36bd11b45b339 100644 --- a/tests/tests_lite/strategies/test_registry.py +++ b/tests/tests_lite/strategies/test_registry.py @@ -65,4 +65,6 @@ def test_available_strategies_in_registry(): "tpu_spawn", "xla", "dp", + "fsdp", + "fsdp_full_shard_offload", } From 8e85f697afbd0697c584a82eb1026c12b559f275 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 2 Oct 2022 10:45:30 +0200 Subject: [PATCH 09/84] more tests --- src/lightning_lite/plugins/precision/fsdp.py | 2 +- src/lightning_lite/strategies/fsdp.py | 2 +- .../tests_lite/plugins/precision/test_fsdp.py | 36 +++++++++++ tests/tests_lite/strategies/test_fsdp.py | 60 +++++++++++++++++++ tests/tests_lite/test_connector.py | 9 ++- 5 files changed, 106 insertions(+), 3 deletions(-) create mode 100644 tests/tests_lite/plugins/precision/test_fsdp.py create mode 100644 tests/tests_lite/strategies/test_fsdp.py diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py index c508a2944c015..e8aafea518b12 100644 --- a/src/lightning_lite/plugins/precision/fsdp.py +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -31,7 +31,7 @@ def __init__( self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None ) -> None: if not _TORCH_GREATER_EQUAL_1_12: - raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") + raise NotImplementedError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index c914b13790d87..85ba4d4f0edb6 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -92,7 +92,7 @@ def __init__( **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: - raise RuntimeError("`FSDPStrategy` is supported from PyTorch v1.12.0 onwards.") + raise NotImplementedError("`FSDPStrategy` is supported from PyTorch v1.12.0 onwards.") super().__init__( accelerator=accelerator, diff --git a/tests/tests_lite/plugins/precision/test_fsdp.py b/tests/tests_lite/plugins/precision/test_fsdp.py new file mode 100644 index 0000000000000..03a5fcfe33463 --- /dev/null +++ b/tests/tests_lite/plugins/precision/test_fsdp.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +import pytest +import torch +from tests_lite.helpers.runif import RunIf + +from lightning_lite.plugins import FSDPPrecision + + +@mock.patch("lightning_lite.plugins.precision.fsdp._TORCH_GREATER_EQUAL_1_12", False) +def test_fsdp_precision_support(*_): + with pytest.raises(NotImplementedError, match="`FSDPPrecision` is supported from PyTorch v1.12.0"): + FSDPPrecision(precision=16, device="cuda") + + +@RunIf(min_torch="1.12", min_cuda_gpus=1) +@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) +def test_fsdp_precision_config(precision, expected): + plugin = FSDPPrecision(precision=precision, device="cuda") + config = plugin.mixed_precision_config + assert config.param_dtype == expected + assert config.buffer_dtype == expected + assert config.reduce_dtype == expected diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py new file mode 100644 index 0000000000000..f90f7bf7b8c1c --- /dev/null +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -0,0 +1,60 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from typing import Any, Dict, Optional +from unittest import mock + +import pytest +import torch +import torch.nn as nn +from lightning_lite.plugins.precision.fsdp import FSDPPrecision +from lightning_lite.strategies import FSDPStrategy +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from tests_lite.helpers.runif import RunIf + +if _TORCH_GREATER_EQUAL_1_12: + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp.wrap import wrap + + +@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False) +def test_fsdp_support(*_): + with pytest.raises(NotImplementedError, match="`FSDPStrategy` is supported from PyTorch v1.12.0"): + FSDPStrategy() + + +@RunIf(min_torch="1.12") +# @mock.patch("lightning_lite.strategies.fsdp.FullyShardedDataParallel") +def test_fsdp_custom_mixed_precision(*_): + """Test that passing a custom mixed precision config works.""" + config = MixedPrecision() + strategy = FSDPStrategy(mixed_precision=config) + assert strategy.mixed_precision_config == config + # + # + # wrapped_module = strategy.setup_module(nn.Linear(3, 3)) + # assert wrapped_module.mixed_precision == config + + +def custom_auto_wrap_policy( + module, + recurse, + unwrapped_params: int, + min_num_params: int = int(1e8), +) -> bool: + return unwrapped_params >= 2 + + diff --git a/tests/tests_lite/test_connector.py b/tests/tests_lite/test_connector.py index 73a22fc473c3f..66a23b2ca21fd 100644 --- a/tests/tests_lite/test_connector.py +++ b/tests/tests_lite/test_connector.py @@ -42,7 +42,7 @@ DDPSpawnStrategy, DDPStrategy, DeepSpeedStrategy, - SingleDeviceStrategy, + SingleDeviceStrategy, FSDPStrategy, ) from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES from lightning_lite.utilities.exceptions import MisconfigurationException @@ -746,3 +746,10 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin plugins=plugin, ) assert isinstance(trainer.precision_plugin, plugin_cls) + + +@RunIf(min_torch="1.12") +def test_fsdp_unsupported_on_cpu(): + """Test that we raise an error if attempting to run FSDP without GPU.""" + with pytest.raises(ValueError, match="You selected the FSDP strategy but FSDP is only available on GPU"): + _Connector(strategy="fsdp") From 5385a19a74ace4f8d2da8000ba52905ff74bf718 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 2 Oct 2022 10:45:37 +0200 Subject: [PATCH 10/84] integration tests --- .../strategies/test_fsdp_integration.py | 126 ++++++++++++++++++ .../test_ddp_fully_sharded_native.py | 4 + 2 files changed, 130 insertions(+) create mode 100644 tests/tests_lite/strategies/test_fsdp_integration.py diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py new file mode 100644 index 0000000000000..e1fab8d38ddae --- /dev/null +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -0,0 +1,126 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile + +import pytest +import torch +from tests_lite.helpers.models import BoringLite +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.fsdp.wrap import wrap + +from lightning_lite.plugins import FSDPPrecision +from tests.tests_lite.helpers.runif import RunIf + + +class FSDPLite(BoringLite): + def get_model(self): + model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + for i, layer in enumerate(model): + if i % 2 == 0: + model[i] = wrap(layer) + + def step(self, model, batch): + forward_module = model._forward_module + original_module = model.module + assert isinstance(forward_module, FullyShardedDataParallel) + assert isinstance(self._precision_plugin, FSDPPrecision) + # the root module should not be resharding + assert forward_module.reshard_after_forward is False + + precision = torch.float16 if self._precision_plugin.precision == 16 else torch.bfloat16 + assert forward_module.mixed_precision.param_dtype == precision + assert forward_module.mixed_precision.reduce_dtype == precision + assert forward_module.mixed_precision.buffer_dtype == precision + + for layer_num in [0, 2]: + assert isinstance(original_module[layer_num], FullyShardedDataParallel) + # The nested layers should have `reshard_after_forward` set to True + assert original_module[layer_num].reshard_after_forward + + assert original_module[layer_num].mixed_precision.param_dtype == precision + assert original_module[layer_num].mixed_precision.reduce_dtype == precision + assert original_module[layer_num].mixed_precision.buffer_dtype == precision + + return super().step(model, batch) + + def run(self): + super().run() + with tempfile.TemporaryFile() as ckpt_path: + ckpt_path = self.broadcast(str(ckpt_path)) + + + checkpoint = dict( + model=self.model.state_dict(), + optimizer = self.optimizer.state_dict() + ) + + self._strategy.save_checkpoint(checkpoint, ckpt_path) + + _assert_save_equality(self, ckpt_path) + +# +# +# class TestFSDPModelAutoWrapped(BoringModel): +# def __init__(self): +# super().__init__() +# self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) +# +# def configure_optimizers(self): +# return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) +# +# def on_train_batch_end(self, outputs, batch, batch_idx) -> None: +# self._assert_layer_fsdp_instance() +# +# def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None: +# self._assert_layer_fsdp_instance() +# +# def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None: +# self._assert_layer_fsdp_instance() +# +# def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None: +# self._assert_layer_fsdp_instance() +# +# def _assert_layer_fsdp_instance(self) -> None: +# assert isinstance(self.layer, torch.nn.Sequential) +# assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin) +# +# precision = torch.float16 if self.precision == 16 else torch.bfloat16 +# for layer_num in [0, 2]: +# assert isinstance(self.layer[layer_num], FullyShardedDataParallel) +# # Assert that the nested layers are set reshard_after_forward to True +# assert self.layer[layer_num].reshard_after_forward +# +# assert self.layer[layer_num].mixed_precision.param_dtype == precision +# assert self.layer[layer_num].mixed_precision.reduce_dtype == precision +# assert self.layer[layer_num].mixed_precision.buffer_dtype == precision + + + +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") +@pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) +def test_fsdp_train_save_load(precision): + """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" + FSDPLite(accelerator="cuda", strategy="fsdp", devices=2, precision=precision).run() + + +def _assert_save_equality(lite, ckpt_path): + model_state_dict = lite._strategy.get_module_state_dict() + + if lite.is_global_zero: + checkpoint = lite.load(ckpt_path) + saved_model = lite.get_model().load_state_dict(checkpoint["state_dict"]) + + # model parameters are identical after loading + for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): + assert torch.equal(ddp_param.float().cpu(), shard_param) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index be8bced2cbf5f..96efbbd522c3c 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -27,6 +27,7 @@ def custom_auto_wrap_policy( return unwrapped_params >= 2 +# lite: adopted @RunIf(min_torch="1.12") def test_invalid_on_cpu(tmpdir): """Test to ensure that we raise Misconfiguration for Native FSDP on CPU.""" @@ -40,6 +41,7 @@ def test_invalid_on_cpu(tmpdir): trainer.strategy.setup_environment() +# lite: adopted @RunIf(min_torch="1.12", min_cuda_gpus=1) @pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) def test_precision_plugin_config(precision, expected): @@ -50,6 +52,7 @@ def test_precision_plugin_config(precision, expected): assert config.reduce_dtype == expected +# lite: adopted @RunIf(min_torch="1.12") def test_fsdp_custom_mixed_precision(tmpdir): """Test to ensure that passing a custom mixed precision config works.""" @@ -154,6 +157,7 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer[layer_num].mixed_precision.buffer_dtype == precision +# lite: unimplemented @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run.""" From de24f124e4bcb5da6fef721c751e6584ba6880a3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 2 Oct 2022 13:53:49 +0200 Subject: [PATCH 11/84] debug debug debug debug debug debug debug debug debug debug debug debug debug debug clean up debug debug debug debug debug debug debug changelog reset x --- src/lightning_lite/plugins/__init__.py | 1 + .../plugins/precision/__init__.py | 1 + src/pytorch_lightning/CHANGELOG.md | 3 +- tests/tests_lite/strategies/test_fsdp.py | 13 +-- .../strategies/test_fsdp_integration.py | 89 ++++++------------- tests/tests_lite/test_connector.py | 2 +- .../test_ddp_fully_sharded_native.py | 2 + 7 files changed, 37 insertions(+), 74 deletions(-) diff --git a/src/lightning_lite/plugins/__init__.py b/src/lightning_lite/plugins/__init__.py index 0d166904491be..ede84b67efcf4 100644 --- a/src/lightning_lite/plugins/__init__.py +++ b/src/lightning_lite/plugins/__init__.py @@ -35,4 +35,5 @@ "Precision", "TPUPrecision", "TPUBf16Precision", + "FSDPPrecision", ] diff --git a/src/lightning_lite/plugins/precision/__init__.py b/src/lightning_lite/plugins/precision/__init__.py index c390edd8e36f2..c47ffeb3f9fc1 100644 --- a/src/lightning_lite/plugins/precision/__init__.py +++ b/src/lightning_lite/plugins/precision/__init__.py @@ -26,4 +26,5 @@ "Precision", "TPUPrecision", "TPUBf16Precision", + "FSDPPrecision", ] diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 5acbca7439a9b..b2c06353a20dd 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -57,10 +57,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Introduce `ckpt_path="hpc"` keyword for checkpoint loading ([#14911](https://github.,com/Lightning-AI/lightning/pull/14911)) - - Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709)) +- Added support for Fully Sharded Data Parallel (FSDP) training in Lightning Lite ([#14967](https://github.com/Lightning-AI/lightning/issues/14967)) + ### Changed diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index f90f7bf7b8c1c..6eac845d834d0 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -13,21 +13,16 @@ # limitations under the License. -import os -from typing import Any, Dict, Optional from unittest import mock import pytest -import torch -import torch.nn as nn -from lightning_lite.plugins.precision.fsdp import FSDPPrecision +from tests_lite.helpers.runif import RunIf + from lightning_lite.strategies import FSDPStrategy from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 -from tests_lite.helpers.runif import RunIf if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision - from torch.distributed.fsdp.wrap import wrap + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision @mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False) @@ -56,5 +51,3 @@ def custom_auto_wrap_policy( min_num_params: int = int(1e8), ) -> bool: return unwrapped_params >= 2 - - diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index e1fab8d38ddae..d983664f433b0 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -16,27 +16,32 @@ import pytest import torch from tests_lite.helpers.models import BoringLite +from tests_lite.helpers.runif import RunIf from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp.wrap import wrap from lightning_lite.plugins import FSDPPrecision -from tests.tests_lite.helpers.runif import RunIf class FSDPLite(BoringLite): + manual_wrapping = False + def get_model(self): model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + if not self.manual_wrapping: + return model + for i, layer in enumerate(model): if i % 2 == 0: model[i] = wrap(layer) + model = wrap(model) + return model def step(self, model, batch): forward_module = model._forward_module original_module = model.module assert isinstance(forward_module, FullyShardedDataParallel) assert isinstance(self._precision_plugin, FSDPPrecision) - # the root module should not be resharding - assert forward_module.reshard_after_forward is False precision = torch.float16 if self._precision_plugin.precision == 16 else torch.bfloat16 assert forward_module.mixed_precision.param_dtype == precision @@ -44,9 +49,10 @@ def step(self, model, batch): assert forward_module.mixed_precision.buffer_dtype == precision for layer_num in [0, 2]: - assert isinstance(original_module[layer_num], FullyShardedDataParallel) - # The nested layers should have `reshard_after_forward` set to True - assert original_module[layer_num].reshard_after_forward + if self.manual_wrapping: + assert isinstance(original_module[layer_num], FullyShardedDataParallel) + else: + assert isinstance(forward_module[layer_num], FullyShardedDataParallel) assert original_module[layer_num].mixed_precision.param_dtype == precision assert original_module[layer_num].mixed_precision.reduce_dtype == precision @@ -58,69 +64,28 @@ def run(self): super().run() with tempfile.TemporaryFile() as ckpt_path: ckpt_path = self.broadcast(str(ckpt_path)) + self._strategy.save_checkpoint(self.model.state_dict(), ckpt_path) + self._assert_save_equality(ckpt_path) - checkpoint = dict( - model=self.model.state_dict(), - optimizer = self.optimizer.state_dict() - ) - - self._strategy.save_checkpoint(checkpoint, ckpt_path) + def _assert_save_equality(self, ckpt_path): + current_state_dict = self._strategy.get_module_state_dict(self.model) - _assert_save_equality(self, ckpt_path) - -# -# -# class TestFSDPModelAutoWrapped(BoringModel): -# def __init__(self): -# super().__init__() -# self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) -# -# def configure_optimizers(self): -# return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) -# -# def on_train_batch_end(self, outputs, batch, batch_idx) -> None: -# self._assert_layer_fsdp_instance() -# -# def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None: -# self._assert_layer_fsdp_instance() -# -# def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None: -# self._assert_layer_fsdp_instance() -# -# def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None: -# self._assert_layer_fsdp_instance() -# -# def _assert_layer_fsdp_instance(self) -> None: -# assert isinstance(self.layer, torch.nn.Sequential) -# assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin) -# -# precision = torch.float16 if self.precision == 16 else torch.bfloat16 -# for layer_num in [0, 2]: -# assert isinstance(self.layer[layer_num], FullyShardedDataParallel) -# # Assert that the nested layers are set reshard_after_forward to True -# assert self.layer[layer_num].reshard_after_forward -# -# assert self.layer[layer_num].mixed_precision.param_dtype == precision -# assert self.layer[layer_num].mixed_precision.reduce_dtype == precision -# assert self.layer[layer_num].mixed_precision.buffer_dtype == precision + checkpoint = self.load(ckpt_path) + loaded_model = self.get_model() + loaded_model.load_state_dict(checkpoint) + # model parameters are identical after loading + for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()): + assert torch.equal(current_param.float().cpu(), loaded_param.cpu()) @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) -def test_fsdp_train_save_load(precision): +@pytest.mark.parametrize("manual_wrapping", [True, False]) +def test_fsdp_train_save_load(manual_wrapping, precision): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" - FSDPLite(accelerator="cuda", strategy="fsdp", devices=2, precision=precision).run() + lite = FSDPLite(accelerator="cuda", strategy="fsdp", devices=2, precision=precision) + lite.manual_wrapping = manual_wrapping + lite.run() - -def _assert_save_equality(lite, ckpt_path): - model_state_dict = lite._strategy.get_module_state_dict() - - if lite.is_global_zero: - checkpoint = lite.load(ckpt_path) - saved_model = lite.get_model().load_state_dict(checkpoint["state_dict"]) - - # model parameters are identical after loading - for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): - assert torch.equal(ddp_param.float().cpu(), shard_param) diff --git a/tests/tests_lite/test_connector.py b/tests/tests_lite/test_connector.py index 66a23b2ca21fd..ca2502e94a4a3 100644 --- a/tests/tests_lite/test_connector.py +++ b/tests/tests_lite/test_connector.py @@ -42,7 +42,7 @@ DDPSpawnStrategy, DDPStrategy, DeepSpeedStrategy, - SingleDeviceStrategy, FSDPStrategy, + SingleDeviceStrategy, ) from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES from lightning_lite.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index 96efbbd522c3c..2dcefb2b44ba7 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -175,6 +175,7 @@ def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) +# lite: adopted @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision): @@ -186,6 +187,7 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision): _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) +# lite: adopted @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize( "model, strategy", From 9051a138bf9a4264d5b4bee20ea6a2ddcd37ba2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 2 Oct 2022 11:56:00 +0000 Subject: [PATCH 12/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_lite/strategies/test_fsdp_integration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index d983664f433b0..e875b105e0630 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -88,4 +88,3 @@ def test_fsdp_train_save_load(manual_wrapping, precision): lite = FSDPLite(accelerator="cuda", strategy="fsdp", devices=2, precision=precision) lite.manual_wrapping = manual_wrapping lite.run() - From 782bd0ac3680ce7fcaeaca4dd35babeac709e0be Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 3 Oct 2022 00:05:03 +0200 Subject: [PATCH 13/84] fix autowrap policy --- tests/tests_lite/strategies/test_fsdp.py | 8 -------- .../tests_lite/strategies/test_fsdp_integration.py | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 6eac845d834d0..91c7cd8aa50a7 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -43,11 +43,3 @@ def test_fsdp_custom_mixed_precision(*_): # wrapped_module = strategy.setup_module(nn.Linear(3, 3)) # assert wrapped_module.mixed_precision == config - -def custom_auto_wrap_policy( - module, - recurse, - unwrapped_params: int, - min_num_params: int = int(1e8), -) -> bool: - return unwrapped_params >= 2 diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index d983664f433b0..ff85497b55bf2 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -21,6 +21,7 @@ from torch.distributed.fsdp.wrap import wrap from lightning_lite.plugins import FSDPPrecision +from lightning_lite.strategies import FSDPStrategy class FSDPLite(BoringLite): @@ -80,12 +81,21 @@ def _assert_save_equality(self, ckpt_path): assert torch.equal(current_param.float().cpu(), loaded_param.cpu()) +def custom_auto_wrap_policy( + module, + recurse, + unwrapped_params: int, + min_num_params: int = int(1e8), +) -> bool: + return unwrapped_params >= 2 + + @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) @pytest.mark.parametrize("manual_wrapping", [True, False]) def test_fsdp_train_save_load(manual_wrapping, precision): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" - lite = FSDPLite(accelerator="cuda", strategy="fsdp", devices=2, precision=precision) + strategy = FSDPStrategy(auto_wrap_policy=custom_auto_wrap_policy) + lite = FSDPLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) lite.manual_wrapping = manual_wrapping lite.run() - From 34251a2fd5761c79dcfd49a410f1cb8baecfeb8e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 2 Oct 2022 22:07:39 +0000 Subject: [PATCH 14/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_lite/strategies/test_fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 91c7cd8aa50a7..17d3be06cbcd5 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -42,4 +42,3 @@ def test_fsdp_custom_mixed_precision(*_): # # wrapped_module = strategy.setup_module(nn.Linear(3, 3)) # assert wrapped_module.mixed_precision == config - From c8eb2b1398f14a7ff0ed5630ad4578de295a6485 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 3 Oct 2022 00:08:29 +0200 Subject: [PATCH 15/84] debug --- tests/tests_lite/strategies/test_fsdp_integration.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index ff85497b55bf2..2d248e22db0b0 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -81,12 +81,7 @@ def _assert_save_equality(self, ckpt_path): assert torch.equal(current_param.float().cpu(), loaded_param.cpu()) -def custom_auto_wrap_policy( - module, - recurse, - unwrapped_params: int, - min_num_params: int = int(1e8), -) -> bool: +def custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool: return unwrapped_params >= 2 @@ -95,7 +90,7 @@ def custom_auto_wrap_policy( @pytest.mark.parametrize("manual_wrapping", [True, False]) def test_fsdp_train_save_load(manual_wrapping, precision): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" - strategy = FSDPStrategy(auto_wrap_policy=custom_auto_wrap_policy) + strategy = FSDPStrategy() if manual_wrapping else FSDPStrategy(auto_wrap_policy=custom_auto_wrap_policy) lite = FSDPLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) lite.manual_wrapping = manual_wrapping lite.run() From c9dd26e801bc86a7888735794f4a0e2b7b9c3d78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 3 Oct 2022 00:13:16 +0200 Subject: [PATCH 16/84] debug --- tests/tests_lite/strategies/test_fsdp_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 2d248e22db0b0..c28645c7c5c5f 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -53,7 +53,7 @@ def step(self, model, batch): if self.manual_wrapping: assert isinstance(original_module[layer_num], FullyShardedDataParallel) else: - assert isinstance(forward_module[layer_num], FullyShardedDataParallel) + assert isinstance(original_module[layer_num], FullyShardedDataParallel) assert original_module[layer_num].mixed_precision.param_dtype == precision assert original_module[layer_num].mixed_precision.reduce_dtype == precision From 4832c9ed5100d3138cbc2257ee13356f261733bf Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 4 Oct 2022 00:02:55 +0200 Subject: [PATCH 17/84] simplify --- tests/tests_lite/strategies/test_fsdp_integration.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index c28645c7c5c5f..b977b8735ed14 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -50,11 +50,7 @@ def step(self, model, batch): assert forward_module.mixed_precision.buffer_dtype == precision for layer_num in [0, 2]: - if self.manual_wrapping: - assert isinstance(original_module[layer_num], FullyShardedDataParallel) - else: - assert isinstance(original_module[layer_num], FullyShardedDataParallel) - + assert isinstance(original_module[layer_num], FullyShardedDataParallel) assert original_module[layer_num].mixed_precision.param_dtype == precision assert original_module[layer_num].mixed_precision.reduce_dtype == precision assert original_module[layer_num].mixed_precision.buffer_dtype == precision @@ -89,7 +85,7 @@ def custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_para @pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) @pytest.mark.parametrize("manual_wrapping", [True, False]) def test_fsdp_train_save_load(manual_wrapping, precision): - """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" + """Test FSDP training, saving and loading with different wrapping and precision settings.""" strategy = FSDPStrategy() if manual_wrapping else FSDPStrategy(auto_wrap_policy=custom_auto_wrap_policy) lite = FSDPLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) lite.manual_wrapping = manual_wrapping From 70437da06809ae5e3196304df8a50c543427ba1b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 19 Oct 2022 08:02:10 +0200 Subject: [PATCH 18/84] support individual setup of model and optimizer --- src/lightning_lite/lite.py | 88 ++++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 19 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 17b051c7becf0..d6fdc472bfe26 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -135,7 +135,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any: def setup( self, - model: nn.Module, + model: Optional[nn.Module] = None, *optimizers: Optimizer, move_to_device: bool = True, ) -> Any: # no specific return because the way we want our API to look does not play well with mypy @@ -151,26 +151,26 @@ def setup( The tuple of the wrapped model and list of optimizers, in the same order they were passed in. """ self._validate_setup(model, optimizers) - original_model = model - - model = self._precision.convert_module(model) - - if move_to_device: - model = self._move_model_to_device(model=model, optimizers=list(optimizers)) - # Let accelerator/plugin wrap and connect the models and optimizers - model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers)) - model = _LiteModule(model, self._precision, original_module=original_model) - - # Update the _DeviceDtypeModuleMixin's device parameter - model.to(self.device if move_to_device else next(model.parameters()).device) - - optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] - self._models_setup += 1 + if model and not optimizers: + # set up a model without optimizers (e.g., for inference) + model = self._setup_model(model, move_to_device=move_to_device) + elif not model and optimizers: + # set up one or more optimizers separately from the model; some strategies don't support that + optimizers = self._setup_optimizers(*optimizers) + elif model and optimizers: + # set up model and optimizers jointly; some strategies require this + model, optimizers = self._setup_model_and_optimizers(model, *optimizers, move_to_device=move_to_device) + + outputs = [] + if model: + outputs.append(model) if optimizers: # join both types in a list for API convenience - return [model] + optimizers # type: ignore - return model + outputs.extend(optimizers) # type: ignore + if len(outputs) == 1: + return outputs[0] + return outputs def setup_dataloaders( self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True @@ -396,6 +396,53 @@ def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> An ), _replace_dunder_methods(BatchSampler): return run_method(*args, **kwargs) + def _setup_model_and_optimizers( + self, + model: nn.Module, + *optimizers: Optimizer, + move_to_device: bool = True, + ) -> Tuple[_LiteModule, List[_LiteOptimizer]]: + original_model = model + + model = self._precision.convert_module(model) + + if move_to_device: + model = self._move_model_to_device(model=model, optimizers=list(optimizers)) + + # Let strategy wrap and connect the models and optimizers + model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers)) + model = _LiteModule(model, self._precision, original_module=original_model) + + # Update the _DeviceDtypeModuleMixin's device parameter + model.to(self.device if move_to_device else next(model.parameters()).device) + + optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] + self._models_setup += 1 + return model, optimizers + + def _setup_model(self, model: nn.Module, move_to_device: bool = True) -> _LiteModule: + original_model = model + + model = self._precision.convert_module(model) + + if move_to_device: + model = self._move_model_to_device(model=model, optimizers=[]) + + # Let strategy wrap and connect the model alone + model = self._strategy.setup_module(model) + model = _LiteModule(model, self._precision, original_module=original_model) + + # Update the _DeviceDtypeModuleMixin's device parameter + model.to(self.device if move_to_device else next(model.parameters()).device) + + self._models_setup += 1 + return model + + def _setup_optimizers(self, *optimizers: Optimizer) -> List[_LiteOptimizer]: + optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers] + optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] + return optimizers + def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: initial_device = next(model.parameters()).device if any(param.device != initial_device for param in model.parameters()): @@ -436,13 +483,16 @@ def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> Distribut return DistributedSamplerWrapper(dataloader.sampler, **kwargs) @staticmethod - def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: + def _validate_setup(model: Optional[nn.Module], optimizers: Sequence[Optimizer]) -> None: if isinstance(model, _LiteModule): raise ValueError("A model should be passed only once to the `setup` method.") if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): raise ValueError("An optimizer should be passed only once to the `setup` method.") + if model is None and not optimizers: + raise ValueError("`setup` requires at least a model or an optimizer.") + @staticmethod def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None: if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders): From 50981a3126a97d60939dac5612319541f45d5289 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 19 Oct 2022 08:19:09 +0200 Subject: [PATCH 19/84] error messaging --- src/lightning_lite/strategies/deepspeed.py | 17 +++++++++++------ src/lightning_lite/strategies/fairscale.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/lightning_lite/strategies/deepspeed.py b/src/lightning_lite/strategies/deepspeed.py index bb8a4790af275..20ec1981bd1f0 100644 --- a/src/lightning_lite/strategies/deepspeed.py +++ b/src/lightning_lite/strategies/deepspeed.py @@ -300,7 +300,7 @@ def model(self) -> "deepspeed.DeepSpeedEngine": return self._deepspeed_engine def setup_module_and_optimizers( - self, model: Module, optimizers: List[Optimizer] + self, module: Module, optimizers: List[Optimizer] ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]: """Setup a model and multiple optimizers together. @@ -316,10 +316,17 @@ def setup_module_and_optimizers( f" Got {len(optimizers)} optimizers instead." ) - self._deepspeed_engine, optimizer = self._setup_module_and_optimizer(model, optimizers[0]) + self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0]) self._set_deepspeed_activation_checkpointing() return self._deepspeed_engine, [optimizer] + def setup_module(self, module: Module) -> "deepspeed.DeepSpeedEngine": + self._deepspeed_engine, _ = self._initialize_engine(module) + return self._deepspeed_engine + + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + raise RuntimeError("not supported") # TODO: proper error message + @contextlib.contextmanager def module_sharded_context(self) -> Generator[None, None, None]: # Current limitation in Lite: The config needs to be fully determined at the time of calling the @@ -396,11 +403,10 @@ def register_strategies(cls, strategy_registry: Dict) -> None: offload_optimizer_device="nvme", ) - def _setup_module_and_optimizer( + def _initialize_engine( self, model: Module, - optimizer: Optional[Optimizer], - lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None, + optimizer: Optional[Optimizer] = None, ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. @@ -413,7 +419,6 @@ def _setup_module_and_optimizer( model=model, model_parameters=model_parameters, optimizer=optimizer, - lr_scheduler=lr_scheduler, dist_init_required=False, ) return deepspeed_engine, deepspeed_optimizer diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py index 994c215cd1a92..4a739a7d14604 100644 --- a/src/lightning_lite/strategies/fairscale.py +++ b/src/lightning_lite/strategies/fairscale.py @@ -87,6 +87,12 @@ def setup_module_and_optimizers( model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs) return model, optimizers + def setup_module(self, module: Module) -> Module: + raise RuntimeError("not supported") # TODO: proper error message + + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + raise RuntimeError("not supported") # TODO: proper error message + @contextmanager def block_backward_sync(self, module: Module) -> Generator: """Blocks syncing gradients behaviour on backwards pass. @@ -163,6 +169,12 @@ def setup_module_and_optimizers( model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs) return model, optimizers + def setup_module(self, module: Module) -> Module: + raise RuntimeError("not supported") # TODO: proper error message + + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + raise RuntimeError("not supported") # TODO: proper error message + @contextmanager def block_backward_sync(self, module: Module) -> Generator: """Blocks syncing gradients behaviour on backwards pass. From 4eadd2486edd84faa3aa88a1bfbca3faf9a7763f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 19 Oct 2022 12:45:48 +0200 Subject: [PATCH 20/84] test --- src/lightning_lite/strategies/fsdp.py | 12 ++++- .../strategies/test_fsdp_integration.py | 52 ++++++++++++++----- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 85ba4d4f0edb6..1d5989635f0a5 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -13,12 +13,13 @@ # limitations under the License. from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING, Union, Tuple import torch from torch import Tensor from torch.distributed import default_pg_timeout from torch.nn import Module +from torch.optim import Optimizer from lightning_lite.accelerators import Accelerator from lightning_lite.plugins import CheckpointIO, ClusterEnvironment, Precision @@ -156,6 +157,11 @@ def setup_environment(self) -> None: self._setup_distributed() super().setup_environment() + def setup_module_and_optimizers( + self, module: Module, optimizers: List[Optimizer] + ) -> Tuple[Module, List[Optimizer]]: + raise RuntimeError("not supported") # TODO: proper error msg + def setup_module(self, module: Module) -> "FullyShardedDataParallel": """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" @@ -176,6 +182,10 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": **self._ddp_kwargs, ) + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + # TODO: some validation here + return optimizer + def module_to_device(self, module: Module) -> None: pass diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index b977b8735ed14..37591dca75bfb 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -15,18 +15,50 @@ import pytest import torch -from tests_lite.helpers.models import BoringLite from tests_lite.helpers.runif import RunIf from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp.wrap import wrap +from torch.utils.data import DataLoader +from lightning_lite import LightningLite from lightning_lite.plugins import FSDPPrecision from lightning_lite.strategies import FSDPStrategy +from tests_lite.helpers.models import RandomDataset -class FSDPLite(BoringLite): +class FSDPLite(LightningLite): manual_wrapping = False + def run(self): + model = self.get_model() + + dataloader = DataLoader(RandomDataset(32, 64)) + + # model needs to be set up first in FSDP + model = self.setup(model) + + # get parameters on the wrapped model + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + # optimizer nees to be set up independently + optimizer = self.setup(None, optimizer) + + dataloader = self.setup_dataloaders(dataloader) + model.train() + + data_iter = iter(dataloader) + batch = next(data_iter) + loss = self.step(model, batch) + self.backward(loss) + optimizer.step() + optimizer.zero_grad() + + with tempfile.TemporaryFile() as ckpt_path: + ckpt_path = self.broadcast(str(ckpt_path)) + self._strategy.save_checkpoint(model.state_dict(), ckpt_path) + + self._assert_save_equality(model, ckpt_path) + def get_model(self): model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) if not self.manual_wrapping: @@ -55,18 +87,12 @@ def step(self, model, batch): assert original_module[layer_num].mixed_precision.reduce_dtype == precision assert original_module[layer_num].mixed_precision.buffer_dtype == precision - return super().step(model, batch) - - def run(self): - super().run() - with tempfile.TemporaryFile() as ckpt_path: - ckpt_path = self.broadcast(str(ckpt_path)) - self._strategy.save_checkpoint(self.model.state_dict(), ckpt_path) - - self._assert_save_equality(ckpt_path) + output = model(batch) + loss = torch.nn.functional.mse_loss(output, torch.ones_like(output)) + return loss - def _assert_save_equality(self, ckpt_path): - current_state_dict = self._strategy.get_module_state_dict(self.model) + def _assert_save_equality(self, model, ckpt_path): + current_state_dict = self._strategy.get_module_state_dict(model) checkpoint = self.load(ckpt_path) loaded_model = self.get_model() From dec4f9c86683c2ae3c2ef5bb7ef431a5fb362ad6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 19 Oct 2022 12:49:37 +0200 Subject: [PATCH 21/84] debug --- src/lightning_lite/strategies/fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 1d5989635f0a5..c193d7fc64a41 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -84,7 +84,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[Precision] = None, + precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, cpu_offload: Optional["CPUOffload"] = None, @@ -100,7 +100,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, - precision_plugin=precision_plugin, + precision=precision, ) self._num_nodes = 1 self._process_group_backend: Optional[str] = process_group_backend From d5f1c9ee31ac8d675a3b70553ac12518b22b5ad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 19 Oct 2022 12:50:36 +0200 Subject: [PATCH 22/84] debug --- src/lightning_lite/strategies/fsdp.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index c193d7fc64a41..ae8a7e88e4f65 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -144,9 +144,8 @@ def process_group_backend(self) -> Optional[str]: def mixed_precision_config(self) -> Optional["MixedPrecision"]: if self.mixed_precision: return self.mixed_precision - plugin = self.precision_plugin - if isinstance(plugin, FSDPPrecision): - return plugin.mixed_precision_config + if isinstance(self.precision, FSDPPrecision): + return self.precision.mixed_precision_config def _configure_launcher(self) -> None: assert self.cluster_environment is not None From 559187b68e2c0ce055d26797a4a629d9d95e4713 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 19 Oct 2022 12:52:24 +0200 Subject: [PATCH 23/84] debug --- tests/tests_lite/strategies/test_fsdp_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 37591dca75bfb..867217762329c 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -74,9 +74,9 @@ def step(self, model, batch): forward_module = model._forward_module original_module = model.module assert isinstance(forward_module, FullyShardedDataParallel) - assert isinstance(self._precision_plugin, FSDPPrecision) + assert isinstance(self._precision, FSDPPrecision) - precision = torch.float16 if self._precision_plugin.precision == 16 else torch.bfloat16 + precision = torch.float16 if self._precision.precision == 16 else torch.bfloat16 assert forward_module.mixed_precision.param_dtype == precision assert forward_module.mixed_precision.reduce_dtype == precision assert forward_module.mixed_precision.buffer_dtype == precision From e286dd9aa852d294dd20ee173dd31712cc733c59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 19 Oct 2022 12:54:16 +0200 Subject: [PATCH 24/84] debug --- src/lightning_lite/connector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py index 861a4e5d527c7..113f0f2fcdeda 100644 --- a/src/lightning_lite/connector.py +++ b/src/lightning_lite/connector.py @@ -464,11 +464,10 @@ def _check_and_init_precision(self) -> Precision: else "Using bfloat16 Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" - return NativeMixedPrecision(self._precision_input, device) if isinstance(self.strategy, FSDPStrategy): - return FSDPPrecision(precision=self._precision_flag, device=device) - return NativeMixedPrecision(precision=self._precision_flag, device=device) + return FSDPPrecision(precision=self._precision_input, device=device) + return NativeMixedPrecision(precision=self._precision_input, device=device) raise RuntimeError("No precision set") From b4613ecf10dd40778c50a82630bafd80bffe1244 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 20 Oct 2022 21:52:17 +0200 Subject: [PATCH 25/84] wip --- src/lightning_lite/lite.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index b01b248b9067c..78a7c1f12b7d9 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -135,7 +135,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any: def setup( self, - model: Optional[nn.Module] = None, + model: nn.Module, *optimizers: Optimizer, move_to_device: bool = True, ) -> Any: # no specific return because the way we want our API to look does not play well with mypy @@ -148,7 +148,7 @@ def setup( and alternatively use :meth:`to_device` manually. Returns: - The tuple of the wrapped model and list of optimizers, in the same order they were passed in. + The tuple containing wrapped model and the optimizers, in the same order they were passed in. """ self._validate_setup(model, optimizers) @@ -529,16 +529,13 @@ def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> Distribut return DistributedSamplerWrapper(dataloader.sampler, **kwargs) @staticmethod - def _validate_setup(model: Optional[nn.Module], optimizers: Sequence[Optimizer]) -> None: + def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: if isinstance(model, _LiteModule): raise ValueError("A model should be passed only once to the `setup` method.") if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): raise ValueError("An optimizer should be passed only once to the `setup` method.") - if model is None and not optimizers: - raise ValueError("`setup` requires at least a model or an optimizer.") - @staticmethod def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None: if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders): From 281e26ccfd98d10723f395037c3bdd3583f249da Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 20 Oct 2022 22:03:07 +0200 Subject: [PATCH 26/84] wip --- src/lightning_lite/lite.py | 85 ++++++++--------------- src/lightning_lite/strategies/strategy.py | 2 +- tests/tests_lite/test_lite.py | 1 + 3 files changed, 30 insertions(+), 58 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 78a7c1f12b7d9..ee3b86cbc374a 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -154,23 +154,41 @@ def setup( if model and not optimizers: # set up a model without optimizers (e.g., for inference) - model = self._setup_model(model, move_to_device=move_to_device) + model = self.setup_model(model, move_to_device=move_to_device) elif not model and optimizers: # set up one or more optimizers separately from the model; some strategies don't support that - optimizers = self._setup_optimizers(*optimizers) + optimizers = self.setup_optimizers(*optimizers) elif model and optimizers: # set up model and optimizers jointly; some strategies require this model, optimizers = self._setup_model_and_optimizers(model, *optimizers, move_to_device=move_to_device) - outputs = [] - if model: - outputs.append(model) if optimizers: - # join both types in a list for API convenience - outputs.extend(optimizers) # type: ignore - if len(outputs) == 1: - return outputs[0] - return outputs + # join both types in a tuple for API convenience + return model, *optimizers + return model + + def setup_model(self, model: nn.Module, move_to_device: bool = True) -> _LiteModule: + original_model = model + + model = self._precision.convert_module(model) + + if move_to_device: + model = self._move_model_to_device(model=model, optimizers=[]) + + # Let strategy wrap and connect the model alone + model = self._strategy.setup_module(model) + model = _LiteModule(model, self._precision, original_module=original_model) + + # Update the _DeviceDtypeModuleMixin's device parameter + model.to(self.device if move_to_device else next(model.parameters()).device) + + self._models_setup += 1 + return model + + def setup_optimizers(self, *optimizers: Optimizer) -> Tuple[_LiteOptimizer]: + optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers] + optimizers = tuple(_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers) + return optimizers def setup_dataloaders( self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True @@ -442,53 +460,6 @@ def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> An ), _replace_dunder_methods(BatchSampler): return run_method(*args, **kwargs) - def _setup_model_and_optimizers( - self, - model: nn.Module, - *optimizers: Optimizer, - move_to_device: bool = True, - ) -> Tuple[_LiteModule, List[_LiteOptimizer]]: - original_model = model - - model = self._precision.convert_module(model) - - if move_to_device: - model = self._move_model_to_device(model=model, optimizers=list(optimizers)) - - # Let strategy wrap and connect the models and optimizers - model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers)) - model = _LiteModule(model, self._precision, original_module=original_model) - - # Update the _DeviceDtypeModuleMixin's device parameter - model.to(self.device if move_to_device else next(model.parameters()).device) - - optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] - self._models_setup += 1 - return model, optimizers - - def _setup_model(self, model: nn.Module, move_to_device: bool = True) -> _LiteModule: - original_model = model - - model = self._precision.convert_module(model) - - if move_to_device: - model = self._move_model_to_device(model=model, optimizers=[]) - - # Let strategy wrap and connect the model alone - model = self._strategy.setup_module(model) - model = _LiteModule(model, self._precision, original_module=original_model) - - # Update the _DeviceDtypeModuleMixin's device parameter - model.to(self.device if move_to_device else next(model.parameters()).device) - - self._models_setup += 1 - return model - - def _setup_optimizers(self, *optimizers: Optimizer) -> List[_LiteOptimizer]: - optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers] - optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] - return optimizers - def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: initial_device = next(model.parameters()).device if any(param.device != initial_device for param in model.parameters()): diff --git a/src/lightning_lite/strategies/strategy.py b/src/lightning_lite/strategies/strategy.py index 0cb7c870f81d6..de54252bac1d6 100644 --- a/src/lightning_lite/strategies/strategy.py +++ b/src/lightning_lite/strategies/strategy.py @@ -118,7 +118,7 @@ def setup_module_and_optimizers( """Set up a model and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. The default implementation will - call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs. + call :meth:`setup_model` and :meth:`_setup_optimizer` on the inputs. """ module = self.setup_module(module) optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers] diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 05937cc062b78..908dbbb55946b 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -184,6 +184,7 @@ def test_setup_twice_fails(): lite.setup(model, lite_optimizer) +# TODO: extend this test with the other setup calls def test_setup_tracks_num_models(): """Test that setup() tracks how many times it has setup a model.""" lite = EmptyLite() From 9d6971beae6477c0f942b93026cf358a74ece4a1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 20 Oct 2022 22:12:22 +0200 Subject: [PATCH 27/84] update structure --- src/lightning_lite/lite.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index ee3b86cbc374a..241d6a2ad8674 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -151,16 +151,23 @@ def setup( The tuple containing wrapped model and the optimizers, in the same order they were passed in. """ self._validate_setup(model, optimizers) + original_model = model + + model = self._precision.convert_module(model) + + if move_to_device: + model = self._move_model_to_device(model=model, optimizers=list(optimizers)) + + # Let accelerator/plugin wrap and connect the models and optimizers + model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers)) + model = _LiteModule(model, self._precision, original_module=original_model) + + # Update the _DeviceDtypeModuleMixin's device parameter + model.to(self.device if move_to_device else next(model.parameters()).device) + + optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] - if model and not optimizers: - # set up a model without optimizers (e.g., for inference) - model = self.setup_model(model, move_to_device=move_to_device) - elif not model and optimizers: - # set up one or more optimizers separately from the model; some strategies don't support that - optimizers = self.setup_optimizers(*optimizers) - elif model and optimizers: - # set up model and optimizers jointly; some strategies require this - model, optimizers = self._setup_model_and_optimizers(model, *optimizers, move_to_device=move_to_device) + self._models_setup += 1 if optimizers: # join both types in a tuple for API convenience @@ -185,10 +192,10 @@ def setup_model(self, model: nn.Module, move_to_device: bool = True) -> _LiteMod self._models_setup += 1 return model - def setup_optimizers(self, *optimizers: Optimizer) -> Tuple[_LiteOptimizer]: + def setup_optimizers(self, *optimizers: Optimizer) -> Union[_LiteOptimizer, Tuple[_LiteOptimizer, ...]]: optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers] - optimizers = tuple(_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers) - return optimizers + optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] + return optimizers[0] if len(optimizers) == 1 else tuple(optimizers) def setup_dataloaders( self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True From 20bc93ba27c41f694db17f0edeca0536ab660eff Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 21 Oct 2022 02:35:53 +0200 Subject: [PATCH 28/84] tests --- src/lightning_lite/lite.py | 51 ++++++++++++++++++++++- tests/tests_lite/test_lite.py | 76 +++++++++++++++++++++++++++++------ 2 files changed, 113 insertions(+), 14 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 241d6a2ad8674..168b2dfb436a5 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -159,7 +159,11 @@ def setup( model = self._move_model_to_device(model=model, optimizers=list(optimizers)) # Let accelerator/plugin wrap and connect the models and optimizers - model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers)) + if optimizers: + model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers)) + else: + model = self._strategy.setup_module(model) + model = _LiteModule(model, self._precision, original_module=original_model) # Update the _DeviceDtypeModuleMixin's device parameter @@ -175,6 +179,21 @@ def setup( return model def setup_model(self, model: nn.Module, move_to_device: bool = True) -> _LiteModule: + """Set up a model for accelerated training or inference. + + This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain + strategies like `FSDP` that require setting up the module before the optimizer can be created and set up. + See also :meth:`setup_optimizers`. + + Args: + model: A model to set up + move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False`` + and alternatively use :meth:`to_device` manually. + + Returns: + The wrapped model. + """ + self._validate_setup_model(model) original_model = model model = self._precision.convert_module(model) @@ -193,6 +212,18 @@ def setup_model(self, model: nn.Module, move_to_device: bool = True) -> _LiteMod return model def setup_optimizers(self, *optimizers: Optimizer) -> Union[_LiteOptimizer, Tuple[_LiteOptimizer, ...]]: + """Set up one or more optimizers for accelerated training. + + Some strategies do not allow setting up model and optimizer independently. For them, you should call + ``.setup(model, optimizer, ...)`` instead to jointly set them up. + + Args: + *optimizers: One or more optmizers to set up. + + Returns: + The wrapped model. + """ + self._validate_setup_optimizers(optimizers) optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers] optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] return optimizers[0] if len(optimizers) == 1 else tuple(optimizers) @@ -514,10 +545,26 @@ def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): raise ValueError("An optimizer should be passed only once to the `setup` method.") + @staticmethod + def _validate_setup_model(model: nn.Module) -> None: + if isinstance(model, _LiteModule): + raise ValueError("A model should be passed only once to the `setup_model` method.") + + @staticmethod + def _validate_setup_optimizers(optimizers: Sequence[Optimizer]) -> None: + if not optimizers: + raise ValueError("`setup_optimizers` requires at least one optimizer as input.") + + if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): + raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.") + @staticmethod def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None: + if not dataloaders: + raise ValueError("`setup_dataloaders` requires at least one dataloader as input.") + if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders): - raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method") + raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method.") if any(not isinstance(dl, DataLoader) for dl in dataloaders): raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 908dbbb55946b..4bf8799adef35 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -70,11 +70,13 @@ def run(self, *args, **kwargs): @mock.patch("lightning_lite.strategies.ddp.DistributedDataParallel") -def test_setup_model(ddp_mock): +@pytest.mark.parametrize("setup_method", ["setup", "setup_model"]) +def test_setup_model(ddp_mock, setup_method): """Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model.""" lite = EmptyLite(accelerator="cpu", strategy="ddp", devices=2) model = nn.Linear(1, 2) - lite_model = lite.setup(model) + setup_method = getattr(lite, setup_method) + lite_model = setup_method(model) ddp_mock.assert_called_with(module=model, device_ids=ANY) assert lite_model.module == model assert lite_model.weight is model.weight @@ -93,7 +95,8 @@ def test_setup_model(ddp_mock): ], ) @pytest.mark.parametrize("move_to_device", [True, False]) -def test_setup_model_move_to_device(move_to_device, accelerator, initial_device, target_device): +@pytest.mark.parametrize("setup_method", ["setup", "setup_model"]) +def test_setup_model_move_to_device(setup_method, move_to_device, accelerator, initial_device, target_device): """Test that `move_to_device` leads to parameters being moved to the correct device and that the device attributes on the wrapper are updated.""" initial_device = torch.device(initial_device) @@ -103,7 +106,8 @@ def test_setup_model_move_to_device(move_to_device, accelerator, initial_device, lite = EmptyLite(accelerator=accelerator, devices=1) model = nn.Linear(1, 2) model.to(initial_device) - lite_model = lite.setup(model, move_to_device=move_to_device) + setup_method = getattr(lite, setup_method) + lite_model = setup_method(model, move_to_device=move_to_device) # all parameters on the expected device assert all(param.device == expected_device for param in model.parameters()) @@ -115,7 +119,8 @@ def test_setup_model_move_to_device(move_to_device, accelerator, initial_device, @RunIf(min_cuda_gpus=1) @pytest.mark.parametrize("move_to_device", [True, False]) -def test_setup_model_parameters_on_different_devices(move_to_device): +@pytest.mark.parametrize("setup_method", ["setup", "setup_model"]) +def test_setup_model_parameters_on_different_devices(setup_method, move_to_device): """Test that a warning is emitted when model parameters are on a different device prior to calling `setup()`.""" device0 = torch.device("cpu") @@ -127,9 +132,11 @@ def test_setup_model_parameters_on_different_devices(move_to_device): module1 = nn.Linear(1, 2).to(device1) model = nn.Sequential(module0, module1) + setup_method = getattr(lite, setup_method) + if move_to_device: with pytest.warns(PossibleUserWarning, match="has parameters on different devices"): - lite_model = lite.setup(model, move_to_device=move_to_device) + lite_model = setup_method(model, move_to_device=move_to_device) # both have the same device now assert lite_model.device == device1 @@ -137,11 +144,11 @@ def test_setup_model_parameters_on_different_devices(move_to_device): assert module1.weight.device == module1.bias.device == device1 else: with no_warning_call(expected_warning=PossibleUserWarning, match="has parameters on different devices"): - lite.setup(model, move_to_device=move_to_device) + setup_method(model, move_to_device=move_to_device) -def test_setup_optimizers(): - """Test that setup_optimizers can handle no optimizers, one optimizer, or multiple optimizers.""" +def test_setup_model_and_optimizers(): + """Test that `setup()` can handle no optimizers, one optimizer, or multiple optimizers.""" lite = EmptyLite() model = nn.Linear(1, 2) optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1) @@ -169,8 +176,28 @@ def test_setup_optimizers(): assert lite_optimizer1.optimizer is optimizer1 +def test_setup_optimizers(): + """Test that `setup_optimizers()` can handle one or more optimizers.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1) + optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1) + + # single optimizer + lite_optimizer = lite.setup_optimizers(optimizer0) + assert isinstance(lite_optimizer, _LiteOptimizer) + assert lite_optimizer.optimizer is optimizer0 + + # multiple optimizers + lite_optimizer0, lite_optimizer1 = lite.setup_optimizers(optimizer0, optimizer1) + assert isinstance(lite_optimizer0, _LiteOptimizer) + assert isinstance(lite_optimizer1, _LiteOptimizer) + assert lite_optimizer0.optimizer is optimizer0 + assert lite_optimizer1.optimizer is optimizer1 + + def test_setup_twice_fails(): - """Test that calling setup with a model or optimizer that is already wrapped fails.""" + """Test that calling `setup` with a model or optimizer that is already wrapped fails.""" lite = EmptyLite() model = nn.Linear(1, 2) optimizer = torch.optim.Adam(model.parameters()) @@ -184,7 +211,27 @@ def test_setup_twice_fails(): lite.setup(model, lite_optimizer) -# TODO: extend this test with the other setup calls +def test_setup_model_twice_fails(): + """Test that calling `setup_model` with a model that is already wrapped fails.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + + lite_model = lite.setup_model(model) + with pytest.raises(ValueError, match="A model should be passed only once to the"): + lite.setup_model(lite_model) + + +def test_setup_optimizers_twice_fails(): + """Test that calling `setup_model` with a model that is already wrapped fails.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer = torch.optim.Adam(model.parameters()) + + lite_optimizer = lite.setup_optimizers(optimizer) + with pytest.raises(ValueError, match="An optimizer should be passed only once to"): + lite.setup_optimizers(lite_optimizer) + + def test_setup_tracks_num_models(): """Test that setup() tracks how many times it has setup a model.""" lite = EmptyLite() @@ -198,10 +245,15 @@ def test_setup_tracks_num_models(): lite.setup(model, optimizer) assert lite._models_setup == 2 + lite.setup_model(model) + assert lite._models_setup == 3 + -def test_setup_dataloaders_unsupported_type(): +def test_setup_dataloaders_unsupported_input(): """Test that the setup_dataloaders method fails when provided with non-DataLoader objects.""" lite = EmptyLite() + with pytest.raises(ValueError, match="`setup_dataloaders` requires at least one dataloader"): + lite.setup_dataloaders() with pytest.raises(TypeError, match="Only PyTorch DataLoader are currently supported"): lite.setup_dataloaders(range(2)) # type: ignore From 7dc0aa5615ce625714025118ef4219d8e7603777 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 22 Oct 2022 15:49:49 +0200 Subject: [PATCH 29/84] error messages --- src/lightning_lite/lite.py | 30 ++++++++++++++++++---- src/lightning_lite/strategies/deepspeed.py | 2 +- src/lightning_lite/strategies/fairscale.py | 4 +-- src/lightning_lite/strategies/strategy.py | 6 +++++ 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 168b2dfb436a5..631411f845ef2 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -29,7 +29,14 @@ from lightning_lite.plugins import Precision # avoid circular imports: # isort: split from lightning_lite.accelerators.accelerator import Accelerator from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT -from lightning_lite.strategies import DeepSpeedStrategy, SingleDeviceStrategy, Strategy, XLAStrategy +from lightning_lite.strategies import ( + DDPShardedStrategy, + DDPSpawnShardedStrategy, + DeepSpeedStrategy, + SingleDeviceStrategy, + Strategy, + XLAStrategy, +) from lightning_lite.strategies.strategy import TBroadcast from lightning_lite.utilities import move_data_to_device from lightning_lite.utilities.apply_func import convert_to_tensors @@ -545,13 +552,26 @@ def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): raise ValueError("An optimizer should be passed only once to the `setup` method.") - @staticmethod - def _validate_setup_model(model: nn.Module) -> None: + # TODO(lite): Add validation for FSDP here + + def _validate_setup_model(self, model: nn.Module) -> None: if isinstance(model, _LiteModule): raise ValueError("A model should be passed only once to the `setup_model` method.") - @staticmethod - def _validate_setup_optimizers(optimizers: Sequence[Optimizer]) -> None: + if isinstance(self._strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)): + raise RuntimeError( + f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly" + " through `.setup(model, optimizer, ...)`. For inference, choose a different strategy, for example" + " `ddp`." + ) + + def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None: + if isinstance(self._strategy, (DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy)): + raise RuntimeError( + f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly" + " through `.setup(model, optimizer, ...)`." + ) + if not optimizers: raise ValueError("`setup_optimizers` requires at least one optimizer as input.") diff --git a/src/lightning_lite/strategies/deepspeed.py b/src/lightning_lite/strategies/deepspeed.py index 83407f593a418..4e15a359d5490 100644 --- a/src/lightning_lite/strategies/deepspeed.py +++ b/src/lightning_lite/strategies/deepspeed.py @@ -326,7 +326,7 @@ def setup_module(self, module: Module) -> "deepspeed.DeepSpeedEngine": return self._deepspeed_engine def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - raise RuntimeError("not supported") # TODO: proper error message + raise NotImplementedError(self._err_msg_joint_setup_required()) @contextlib.contextmanager def module_sharded_context(self) -> Generator[None, None, None]: diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py index a7120293eb6f0..81d40e47221b4 100644 --- a/src/lightning_lite/strategies/fairscale.py +++ b/src/lightning_lite/strategies/fairscale.py @@ -90,10 +90,10 @@ def setup_module_and_optimizers( return model, optimizers def setup_module(self, module: Module) -> Module: - raise RuntimeError("not supported") # TODO: proper error message + raise NotImplementedError(self._err_msg_joint_setup_required()) def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - raise RuntimeError("not supported") # TODO: proper error message + raise NotImplementedError(self._err_msg_joint_setup_required()) @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: diff --git a/src/lightning_lite/strategies/strategy.py b/src/lightning_lite/strategies/strategy.py index de54252bac1d6..ed428a2e8953d 100644 --- a/src/lightning_lite/strategies/strategy.py +++ b/src/lightning_lite/strategies/strategy.py @@ -298,6 +298,12 @@ def teardown(self) -> None: def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None: pass + def _err_msg_joint_setup_required(self) -> str: + return ( + f"The `{type(self).__name__}` does not support setting up the module and optimizer(s) independently." + " Please call `setup_module_and_optimizers(model, [optimizer, ...])` to jointly set them up." + ) + class _BackwardSyncControl(ABC): """Interface for any :class:`Strategy` that wants to offer a functionality to enable or disable gradient From 0190b504dc70d992d1316599710676f9428ea137 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 22 Oct 2022 17:36:40 +0200 Subject: [PATCH 30/84] test errors --- tests/tests_lite/test_lite.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 4bf8799adef35..5d89abb14a5da 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from re import escape from unittest import mock from unittest.mock import ANY, MagicMock, Mock, PropertyMock @@ -26,7 +27,15 @@ from lightning_lite.lite import LightningLite from lightning_lite.plugins import Precision -from lightning_lite.strategies import ParallelStrategy, SingleDeviceStrategy, Strategy +from lightning_lite.strategies import ( + DDPShardedStrategy, + DDPSpawnShardedStrategy, + DeepSpeedStrategy, + ParallelStrategy, + SingleDeviceStrategy, + Strategy, + XLAStrategy, +) from lightning_lite.utilities import _StrategyType from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.seed import pl_worker_init_function @@ -232,6 +241,27 @@ def test_setup_optimizers_twice_fails(): lite.setup_optimizers(lite_optimizer) +@pytest.mark.parametrize("strategy_cls", [DDPShardedStrategy, DDPSpawnShardedStrategy]) +def test_setup_model_not_supported(strategy_cls): + """Test that `setup_model` validates the strategy supports setting up model and optimizers independently.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + lite._strategy = Mock(spec=strategy_cls) + with pytest.raises(RuntimeError, match=escape("requires the model and optimizer(s) to be set up jointly through")): + lite.setup_model(model) + + +@pytest.mark.parametrize("strategy_cls", [DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy]) +def test_setup_optimizers_not_supported(strategy_cls): + """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers independently.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer = torch.optim.Adam(model.parameters()) + lite._strategy = Mock(spec=strategy_cls) + with pytest.raises(RuntimeError, match=escape("requires the model and optimizer(s) to be set up jointly through")): + lite.setup_optimizers(optimizer) + + def test_setup_tracks_num_models(): """Test that setup() tracks how many times it has setup a model.""" lite = EmptyLite() From 1e3e7b19373af2f6ab9dbb83eea3c68071073833 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Oct 2022 15:38:25 +0000 Subject: [PATCH 31/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_lite/test_lite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 5d89abb14a5da..4bf71fd65acee 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -253,7 +253,8 @@ def test_setup_model_not_supported(strategy_cls): @pytest.mark.parametrize("strategy_cls", [DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy]) def test_setup_optimizers_not_supported(strategy_cls): - """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers independently.""" + """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers + independently.""" lite = EmptyLite() model = nn.Linear(1, 2) optimizer = torch.optim.Adam(model.parameters()) From 472b60529c20bfd6f9a09275ee39265a3f1ad3a8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 22 Oct 2022 17:46:48 +0200 Subject: [PATCH 32/84] mypy --- src/lightning_lite/lite.py | 4 +++- src/lightning_lite/strategies/deepspeed.py | 2 +- src/lightning_lite/strategies/fairscale.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 631411f845ef2..614ce3c18350c 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -167,7 +167,9 @@ def setup( # Let accelerator/plugin wrap and connect the models and optimizers if optimizers: - model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers)) + model, optimizers = self._strategy.setup_module_and_optimizers( # type: ignore[assignment] + model, list(optimizers) + ) else: model = self._strategy.setup_module(model) diff --git a/src/lightning_lite/strategies/deepspeed.py b/src/lightning_lite/strategies/deepspeed.py index 4e15a359d5490..69e86d7180bd4 100644 --- a/src/lightning_lite/strategies/deepspeed.py +++ b/src/lightning_lite/strategies/deepspeed.py @@ -34,7 +34,7 @@ from lightning_lite.utilities.enums import AMPType, PrecisionType from lightning_lite.utilities.rank_zero import rank_zero_info from lightning_lite.utilities.seed import reset_seed -from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau +from lightning_lite.utilities.types import _PATH _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") if _DEEPSPEED_AVAILABLE: diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py index 81d40e47221b4..ef6fd5d929e9d 100644 --- a/src/lightning_lite/strategies/fairscale.py +++ b/src/lightning_lite/strategies/fairscale.py @@ -18,6 +18,7 @@ import torch from lightning_utilities.core.imports import module_available from torch.nn import Module +from torch.nn.parallel import DistributedDataParallel from torch.optim import Optimizer from lightning_lite.accelerators import Accelerator @@ -89,7 +90,7 @@ def setup_module_and_optimizers( model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs) return model, optimizers - def setup_module(self, module: Module) -> Module: + def setup_module(self, module: Module) -> DistributedDataParallel: raise NotImplementedError(self._err_msg_joint_setup_required()) def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: From 3630d0599c8d00254043370978fee38d9a490302 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 22 Oct 2022 18:15:40 +0200 Subject: [PATCH 33/84] add changelog --- src/pytorch_lightning/CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 60228d76c5a7f..eef7a65645332 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +Added `LightningLite.setup_model()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185)) + ### Changed @@ -102,8 +104,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Deprecated the `pytorch_lightning.utilities.device_parser.parse_tpu_cores` in favor of `lightning_lite.accelerators.tpu.parse_tpu_cores` * Deprecated the `pytorch_lightning.utilities.device_parser.parse_hpus` in favor of `pytorch_lightning.accelerators.hpu.parse_hpus` - Deprecated duplicate `SaveConfigCallback` parameters in `LightningCLI.__init__`: `save_config_kwargs`, `save_config_overwrite` and `save_config_multifile`. New `save_config_kwargs` parameter should be used instead ([#14998](https://github.com/Lightning-AI/lightning/pull/14998)) - - - Deprecated `TrainerFn.TUNING`, `RunningStage.TUNING` and `trainer.tuning` property ([#15100](https://github.com/Lightning-AI/lightning/pull/15100) From 8d035fab81eaf96e3fe259b619680e09f04044d4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 17:39:21 +0200 Subject: [PATCH 34/84] messaging --- src/lightning_lite/strategies/fsdp.py | 10 ++++++++-- tests/tests_lite/strategies/test_fsdp_integration.py | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index ae8a7e88e4f65..12be98f16c62f 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -159,7 +159,11 @@ def setup_environment(self) -> None: def setup_module_and_optimizers( self, module: Module, optimizers: List[Optimizer] ) -> Tuple[Module, List[Optimizer]]: - raise RuntimeError("not supported") # TODO: proper error msg + raise NotImplementedError( + f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)." + " Please do it in this order: Create the model, call `setup_module`, create the optimizer," + " call `setup_optimizer`." + ) def setup_module(self, module: Module) -> "FullyShardedDataParallel": """Wraps the model into a @@ -182,7 +186,9 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": ) def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - # TODO: some validation here + # TODO: add validation + print("in setup optimizer") + print(optimizer.param_groups) return optimizer def module_to_device(self, module: Module) -> None: diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 867217762329c..4ac24eea65fb2 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -35,13 +35,13 @@ def run(self): dataloader = DataLoader(RandomDataset(32, 64)) # model needs to be set up first in FSDP - model = self.setup(model) + model = self.setup_model(model) # get parameters on the wrapped model optimizer = torch.optim.SGD(model.parameters(), lr=0.1) # optimizer nees to be set up independently - optimizer = self.setup(None, optimizer) + optimizer = self.setup_optimizers(optimizer) dataloader = self.setup_dataloaders(dataloader) model.train() From fadb2b6291b2bc7480bf9686745f5dc85ccdb83f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 17:41:18 +0200 Subject: [PATCH 35/84] debug --- tests/tests_lite/strategies/test_fsdp_integration.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 4ac24eea65fb2..1428c67ab0798 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -35,13 +35,15 @@ def run(self): dataloader = DataLoader(RandomDataset(32, 64)) # model needs to be set up first in FSDP - model = self.setup_model(model) + # model = self.setup_model(model) # get parameters on the wrapped model optimizer = torch.optim.SGD(model.parameters(), lr=0.1) # optimizer nees to be set up independently - optimizer = self.setup_optimizers(optimizer) + # optimizer = self.setup_optimizers(optimizer) + + model, optimizer = self.setup(model, optimizer) dataloader = self.setup_dataloaders(dataloader) model.train() From 39e7c096decf2280d863e0e292fa05b59e4aecc2 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 17:46:11 +0200 Subject: [PATCH 36/84] debug --- src/lightning_lite/lite.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 614ce3c18350c..19a9e8f6c491f 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -51,6 +51,8 @@ from lightning_lite.utilities.warnings import PossibleUserWarning from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from src.lightning_lite.strategies.fsdp import FSDPStrategy + class LightningLite(ABC): """Lite accelerates your PyTorch training or inference code with minimal changes required. @@ -546,15 +548,19 @@ def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> Distribut kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) return DistributedSamplerWrapper(dataloader.sampler, **kwargs) - @staticmethod - def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: + def _validate_setup(self, model: nn.Module, optimizers: Sequence[Optimizer]) -> None: if isinstance(model, _LiteModule): raise ValueError("A model should be passed only once to the `setup` method.") if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): raise ValueError("An optimizer should be passed only once to the `setup` method.") - # TODO(lite): Add validation for FSDP here + if isinstance(self._strategy, FSDPStrategy): + raise RuntimeError( + f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately." + " Create and set up the model first through `model = self.setup_model(model)`. Then create the" + " optimizer and set it up: `optimizer = self.setup_optimizer(optimizer)`." + ) def _validate_setup_model(self, model: nn.Module) -> None: if isinstance(model, _LiteModule): From 7bc9421847be103afaf69d2d52eaf2a1bdf875ce Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 17:50:33 +0200 Subject: [PATCH 37/84] fix --- src/lightning_lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 19a9e8f6c491f..7cb285957ae2c 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -51,7 +51,7 @@ from lightning_lite.utilities.warnings import PossibleUserWarning from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer -from src.lightning_lite.strategies.fsdp import FSDPStrategy +from lightning_lite.strategies.fsdp import FSDPStrategy class LightningLite(ABC): From 9d086f2d2079eb0caa5a28568a84f15fa560516c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 18:05:14 +0200 Subject: [PATCH 38/84] udpate test --- tests/tests_lite/strategies/test_fsdp_integration.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 1428c67ab0798..4ac24eea65fb2 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -35,15 +35,13 @@ def run(self): dataloader = DataLoader(RandomDataset(32, 64)) # model needs to be set up first in FSDP - # model = self.setup_model(model) + model = self.setup_model(model) # get parameters on the wrapped model optimizer = torch.optim.SGD(model.parameters(), lr=0.1) # optimizer nees to be set up independently - # optimizer = self.setup_optimizers(optimizer) - - model, optimizer = self.setup(model, optimizer) + optimizer = self.setup_optimizers(optimizer) dataloader = self.setup_dataloaders(dataloader) model.train() From ac29054100436b93ce4a2da4022225566a791f6e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 18:08:58 +0200 Subject: [PATCH 39/84] remove done todo --- src/lightning_lite/lite.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 614ce3c18350c..e75438bfed93b 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -554,8 +554,6 @@ def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): raise ValueError("An optimizer should be passed only once to the `setup` method.") - # TODO(lite): Add validation for FSDP here - def _validate_setup_model(self, model: nn.Module) -> None: if isinstance(model, _LiteModule): raise ValueError("A model should be passed only once to the `setup_model` method.") From 5e1f433687fea480c9abf75a213d674538cd3a52 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 18:12:12 +0200 Subject: [PATCH 40/84] missing err message --- src/lightning_lite/strategies/fairscale.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py index ef6fd5d929e9d..7ade43566a862 100644 --- a/src/lightning_lite/strategies/fairscale.py +++ b/src/lightning_lite/strategies/fairscale.py @@ -160,11 +160,11 @@ def setup_module_and_optimizers( model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs) return model, optimizers - def setup_module(self, module: Module) -> Module: - raise RuntimeError("not supported") # TODO: proper error message + def setup_module(self, module: Module) -> DistributedDataParallel: + raise NotImplementedError(self._err_msg_joint_setup_required()) def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - raise RuntimeError("not supported") # TODO: proper error message + raise NotImplementedError(self._err_msg_joint_setup_required()) @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: From a900207192b8073b5149566febdd31344fd3d8e5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 18:57:53 +0200 Subject: [PATCH 41/84] tests --- tests/tests_lite/strategies/test_deepspeed.py | 35 +++++++++++++++++++ tests/tests_lite/strategies/test_fairscale.py | 16 +++++++++ 2 files changed, 51 insertions(+) diff --git a/tests/tests_lite/strategies/test_deepspeed.py b/tests/tests_lite/strategies/test_deepspeed.py index bea2096013f6d..f1cac3fd4e1e7 100644 --- a/tests/tests_lite/strategies/test_deepspeed.py +++ b/tests/tests_lite/strategies/test_deepspeed.py @@ -13,8 +13,12 @@ # limitations under the License. import json import os +from re import escape +from unittest import mock +from unittest.mock import ANY, Mock import pytest +import torch from tests_lite.helpers.runif import RunIf from lightning_lite.accelerators import CPUAccelerator @@ -116,3 +120,34 @@ def test_deepspeed_config_zero_offload(deepspeed_zero_config): deepspeed_zero_config["zero_optimization"]["offload_optimizer"] = False strategy = DeepSpeedStrategy(config=deepspeed_zero_config) assert strategy.config["zero_optimization"]["offload_optimizer"] is False + + +@RunIf(deepspeed=True) +@mock.patch("lightning_lite.strategies.deepspeed.deepspeed.initialize") +def test_deepspeed_setup_module(init_mock): + """Test that the DeepSpeed strategy can set up the model for inference (no optimizer required).""" + model = Mock() + model.parameters.return_value = [] + strategy = DeepSpeedStrategy() + strategy.parallel_devices = [torch.device("cuda", 1)] + init_mock.return_value = [Mock()] * 4 # mock to make tuple unpacking work + + strategy.setup_module(model) + init_mock.assert_called_with( + args=ANY, + config=strategy.config, + model=model, + model_parameters=ANY, + optimizer=None, + dist_init_required=False, + ) + + +@RunIf(deepspeed=True) +def test_deepspeed_requires_joint_setup(): + """Test that the DeepSpeed strategy does not support setting up model and optimizer independently.""" + strategy = DeepSpeedStrategy() + with pytest.raises( + NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently") + ): + strategy.setup_optimizer(Mock()) diff --git a/tests/tests_lite/strategies/test_fairscale.py b/tests/tests_lite/strategies/test_fairscale.py index 31857e0da1bea..4029402e19e61 100644 --- a/tests/tests_lite/strategies/test_fairscale.py +++ b/tests/tests_lite/strategies/test_fairscale.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from re import escape from unittest import mock from unittest.mock import MagicMock, Mock @@ -89,3 +90,18 @@ def test_fairscale_no_backward_sync(cls): pass module.no_sync.assert_called_once() + + +@pytest.mark.parametrize("cls", [DDPShardedStrategy, DDPSpawnShardedStrategy]) +def test_fairscale_requires_joint_setup(cls): + """Test that the fairscale sharded strategy does not support setting up model and optimizer independently.""" + strategy = cls() + with pytest.raises( + NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently") + ): + strategy.setup_module(Mock()) + + with pytest.raises( + NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently") + ): + strategy.setup_optimizer(Mock()) From a41eef947494eebfd3a1acee1742543547b71510 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 18:59:54 +0200 Subject: [PATCH 42/84] flake --- src/lightning_lite/utilities/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning_lite/utilities/data.py b/src/lightning_lite/utilities/data.py index afa0e988ca766..9aace928563fb 100644 --- a/src/lightning_lite/utilities/data.py +++ b/src/lightning_lite/utilities/data.py @@ -327,8 +327,8 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable: - """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and - :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" + """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` + and :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" @functools.wraps(method) def wrapper(obj: Any, *args: Any) -> None: From 9cd50b499b951fdfa8ed713364dd9edbe4812e09 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Oct 2022 17:05:53 +0000 Subject: [PATCH 43/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/utilities/data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lightning_lite/utilities/data.py b/src/lightning_lite/utilities/data.py index 9aace928563fb..021de9a0f1e79 100644 --- a/src/lightning_lite/utilities/data.py +++ b/src/lightning_lite/utilities/data.py @@ -327,8 +327,9 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable: - """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` - and :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" + """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently + :class:`~torch.utils.data.DataLoader` and :class:`~torch.utils.data.BatchSampler`) in order to enable re- + instantiation of custom subclasses.""" @functools.wraps(method) def wrapper(obj: Any, *args: Any) -> None: From 564ab17054593587dffc8165732098501d5c79c1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 19:31:12 +0200 Subject: [PATCH 44/84] docstrings --- src/lightning_lite/strategies/deepspeed.py | 12 ++++++++++-- src/lightning_lite/strategies/fairscale.py | 16 ++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/lightning_lite/strategies/deepspeed.py b/src/lightning_lite/strategies/deepspeed.py index 69e86d7180bd4..87bb8f9715272 100644 --- a/src/lightning_lite/strategies/deepspeed.py +++ b/src/lightning_lite/strategies/deepspeed.py @@ -303,9 +303,9 @@ def model(self) -> "deepspeed.DeepSpeedEngine": def setup_module_and_optimizers( self, module: Module, optimizers: List[Optimizer] ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]: - """Setup a model and multiple optimizers together. + """Set up a model and multiple optimizers together. - Currently only a single optimizer is supported. + Currently, only a single optimizer is supported. Return: The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single @@ -322,10 +322,18 @@ def setup_module_and_optimizers( return self._deepspeed_engine, [optimizer] def setup_module(self, module: Module) -> "deepspeed.DeepSpeedEngine": + """Set up a module for inference (no optimizers). + + For training, see :meth:`setup_module_and_optimizers`. + """ self._deepspeed_engine, _ = self._initialize_engine(module) return self._deepspeed_engine def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Optimizers can only be set up jointly with the model in this strategy. + + Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together. + """ raise NotImplementedError(self._err_msg_joint_setup_required()) @contextlib.contextmanager diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py index 7ade43566a862..adae3b72b9dc5 100644 --- a/src/lightning_lite/strategies/fairscale.py +++ b/src/lightning_lite/strategies/fairscale.py @@ -91,9 +91,17 @@ def setup_module_and_optimizers( return model, optimizers def setup_module(self, module: Module) -> DistributedDataParallel: + """Setting up the module without optimizers in this strategy is not supported. + + Please use :meth:`setup_module_and_optimizers` instead. + """ raise NotImplementedError(self._err_msg_joint_setup_required()) def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Optimizers can only be set up jointly with the model in this strategy. + + Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together. + """ raise NotImplementedError(self._err_msg_joint_setup_required()) @classmethod @@ -161,9 +169,17 @@ def setup_module_and_optimizers( return model, optimizers def setup_module(self, module: Module) -> DistributedDataParallel: + """Setting up the module without optimizers in this strategy is not supported. + + Please use :meth:`setup_module_and_optimizers` instead. + """ raise NotImplementedError(self._err_msg_joint_setup_required()) def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Optimizers can only be set up jointly with the model in this strategy. + + Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together. + """ raise NotImplementedError(self._err_msg_joint_setup_required()) @classmethod From 6afc0aee5a5f22353265a3971f315e61584dd878 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 20:15:26 +0200 Subject: [PATCH 45/84] doc fix --- src/lightning_lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index e75438bfed93b..25e02b8225306 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -230,7 +230,7 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_LiteOptimizer, Tupl *optimizers: One or more optmizers to set up. Returns: - The wrapped model. + The wrapped optimizer(s). """ self._validate_setup_optimizers(optimizers) optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers] From 8b999b120bfe320535c3d4cf2c9aaa5e72434116 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Oct 2022 20:36:06 +0200 Subject: [PATCH 46/84] support python < 3.10 --- src/lightning_lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 25e02b8225306..d085b8ccffcea 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -184,7 +184,7 @@ def setup( if optimizers: # join both types in a tuple for API convenience - return model, *optimizers + return tuple((model, *optimizers)) return model def setup_model(self, model: nn.Module, move_to_device: bool = True) -> _LiteModule: From 3036840d091bfd64530348fb49b686f875db0462 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 20:14:31 +0200 Subject: [PATCH 47/84] validation --- src/lightning_lite/strategies/fsdp.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 85ba4d4f0edb6..c87365ae64c8a 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -19,6 +19,7 @@ from torch import Tensor from torch.distributed import default_pg_timeout from torch.nn import Module +from torch.optim import Optimizer from lightning_lite.accelerators import Accelerator from lightning_lite.plugins import CheckpointIO, ClusterEnvironment, Precision @@ -176,6 +177,16 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": **self._ddp_kwargs, ) + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + from torch.distributed.fsdp import FlatParameter + + if len(optimizer.param_groups) > 1: + raise ValueError("Optimizers used with FSDP do not support multiple param groups.") + + if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0].values()): + return optimizer + raise ValueError("The optimizer does not seem to reference any flat FSDP parameters.") + def module_to_device(self, module: Module) -> None: pass From 4bb5d5690c242258078eef676a43a0d125cd0511 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Oct 2022 20:18:46 +0200 Subject: [PATCH 48/84] debug --- tests/tests_lite/strategies/test_fsdp_integration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 4ac24eea65fb2..806827968e893 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -39,6 +39,7 @@ def run(self): # get parameters on the wrapped model optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + print([type(p) for p in optimizer.param_groups[0].values()]) # optimizer nees to be set up independently optimizer = self.setup_optimizers(optimizer) From 7f944cc48e089a1aca0cbc4952b2d770cf5ba9b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Oct 2022 20:19:54 +0200 Subject: [PATCH 49/84] debug --- tests/tests_lite/strategies/test_fsdp_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 806827968e893..6ddb07f2c8268 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -39,7 +39,7 @@ def run(self): # get parameters on the wrapped model optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - print([type(p) for p in optimizer.param_groups[0].values()]) + print([p for p in optimizer.param_groups[0].values()]) # optimizer nees to be set up independently optimizer = self.setup_optimizers(optimizer) From 070b828df3c36bdd47cde0cd797cece7802a4d74 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 20:29:21 +0200 Subject: [PATCH 50/84] update --- src/lightning_lite/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index f0b0da9d0c0cf..16587eaac0aa8 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -191,7 +191,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: if len(optimizer.param_groups) > 1: raise ValueError("Optimizers used with FSDP do not support multiple param groups.") - if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0].values()): + if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]): return optimizer raise ValueError("The optimizer does not seem to reference any flat FSDP parameters.") From a63b3c9b56e0ea429f39ae4a1b8b027207855f02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Oct 2022 20:29:28 +0200 Subject: [PATCH 51/84] debug --- repro.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 repro.py diff --git a/repro.py b/repro.py new file mode 100644 index 0000000000000..ebeb48bcbc031 --- /dev/null +++ b/repro.py @@ -0,0 +1,9 @@ +import torch +import torch.nn as nn + +from torch.optim import Adam + +model = nn.Linear(2, 2) +optimizer = Adam(model.parameters()) + +print(optimizer.param_groups[0]["params"]) \ No newline at end of file From 3e20de2b8bd905f1c869ec86677be279b88823a9 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 20:37:10 +0200 Subject: [PATCH 52/84] validate --- src/lightning_lite/strategies/fsdp.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 16587eaac0aa8..0ac3367948962 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -186,14 +186,27 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": ) def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Set up an optimizer for a model wrapped with FSDP. + + This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is + verify that the optimizer was created after the model was wrapped with `:meth:setup_module` with a reference + to the flattened parameters. + """ from torch.distributed.fsdp import FlatParameter - if len(optimizer.param_groups) > 1: - raise ValueError("Optimizers used with FSDP do not support multiple param groups.") + num_groups = len(optimizer.param_groups) + if num_groups > 1: + raise ValueError( + "An optimizer used with an FSDP model does not support multiple param groups." + f" Found {num_groups} parameter groups." + ) if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]): return optimizer - raise ValueError("The optimizer does not seem to reference any flat FSDP parameters.") + raise ValueError( + "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer" + " after setting up the model." + ) def module_to_device(self, module: Module) -> None: pass From cf9b92fbd5ed918192a102deec61caddcf7dc50d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 20:37:56 +0200 Subject: [PATCH 53/84] revert --- repro.py | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 repro.py diff --git a/repro.py b/repro.py deleted file mode 100644 index ebeb48bcbc031..0000000000000 --- a/repro.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -import torch.nn as nn - -from torch.optim import Adam - -model = nn.Linear(2, 2) -optimizer = Adam(model.parameters()) - -print(optimizer.param_groups[0]["params"]) \ No newline at end of file From fc34be391e81907ee3c9467863b7ba7c98546a27 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 20:41:58 +0200 Subject: [PATCH 54/84] x --- src/lightning_lite/strategies/fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 0ac3367948962..18fe637c03678 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -203,6 +203,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]): return optimizer + raise ValueError( "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer" " after setting up the model." From 3088273e2ed2da25f44d735a8187f6cad8b5a57a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Oct 2022 18:44:54 +0000 Subject: [PATCH 55/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/lite.py | 3 +-- src/lightning_lite/strategies/fsdp.py | 8 ++++---- tests/tests_lite/strategies/test_fsdp_integration.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index aa3c0247b8514..8bddcd64bd49c 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -37,6 +37,7 @@ Strategy, XLAStrategy, ) +from lightning_lite.strategies.fsdp import FSDPStrategy from lightning_lite.strategies.strategy import TBroadcast from lightning_lite.utilities import move_data_to_device from lightning_lite.utilities.apply_func import convert_to_tensors @@ -51,8 +52,6 @@ from lightning_lite.utilities.warnings import PossibleUserWarning from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer -from lightning_lite.strategies.fsdp import FSDPStrategy - class LightningLite(ABC): """Lite accelerates your PyTorch training or inference code with minimal changes required. diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 18fe637c03678..984b15997d52c 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING, Union, Tuple +from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union import torch from torch import Tensor @@ -188,9 +188,9 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: """Set up an optimizer for a model wrapped with FSDP. - This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is - verify that the optimizer was created after the model was wrapped with `:meth:setup_module` with a reference - to the flattened parameters. + This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is verify + that the optimizer was created after the model was wrapped with `:meth:setup_module` with a reference to the + flattened parameters. """ from torch.distributed.fsdp import FlatParameter diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 6ddb07f2c8268..30713258ffb21 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -15,6 +15,7 @@ import pytest import torch +from tests_lite.helpers.models import RandomDataset from tests_lite.helpers.runif import RunIf from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp.wrap import wrap @@ -23,7 +24,6 @@ from lightning_lite import LightningLite from lightning_lite.plugins import FSDPPrecision from lightning_lite.strategies import FSDPStrategy -from tests_lite.helpers.models import RandomDataset class FSDPLite(LightningLite): From d49ac0a1134fadb9774d4745ff8c5c6c2da5e48f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Oct 2022 20:53:06 +0200 Subject: [PATCH 56/84] debug --- src/lightning_lite/strategies/fsdp.py | 12 +++++----- tests/tests_lite/strategies/test_fsdp.py | 3 +++ .../strategies/test_fsdp_integration.py | 22 +++++++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 18fe637c03678..b103330abb767 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -27,9 +27,9 @@ from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning_lite.strategies.parallel import ParallelStrategy from lightning_lite.strategies.strategy import TBroadcast -from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device +from lightning_lite.utilities.distributed import _distributed_available, _get_default_process_group_backend_for_device from lightning_lite.utilities.distributed import group as _group -from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available +from lightning_lite.utilities.distributed import _init_dist_connection, ReduceOp, _sync_ddp_if_available from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from lightning_lite.utilities.rank_zero import rank_zero_only from lightning_lite.utilities.seed import reset_seed @@ -231,11 +231,11 @@ def reduce( self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" ) -> Tensor: if isinstance(tensor, Tensor): - tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + tensor = _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor def barrier(self, *args: Any, **kwargs: Any) -> None: - if not distributed_available(): + if not _distributed_available(): return if torch.distributed.get_backend() == "nccl": torch.distributed.barrier(device_ids=[self.root_device.index]) @@ -271,10 +271,10 @@ def _setup_distributed(self) -> None: rank_zero_only.rank = self.global_rank self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None - init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) def _get_process_group_backend(self) -> str: - return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device) + return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) def _set_world_ranks(self) -> None: if self.cluster_environment is None: diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 17d3be06cbcd5..9f06f03a0d7ba 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -16,7 +16,10 @@ from unittest import mock import pytest +import torch +import torch.nn as nn from tests_lite.helpers.runif import RunIf +from torch.optim import Adam from lightning_lite.strategies import FSDPStrategy from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 6ddb07f2c8268..3c7f2af1b290a 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -16,8 +16,10 @@ import pytest import torch from tests_lite.helpers.runif import RunIf +from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp.wrap import wrap +from torch.optim import Adam from torch.utils.data import DataLoader from lightning_lite import LightningLite @@ -117,3 +119,23 @@ def test_fsdp_train_save_load(manual_wrapping, precision): lite = FSDPLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) lite.manual_wrapping = manual_wrapping lite.run() + + +class SetupOptimizerLite(LightningLite): + + def run(self): + module = nn.Linear(2, 2) + bad_optimizer = Adam(module.parameters()) + wrapped_module = self.setup_model(module) + good_optimizer = Adam(wrapped_module.parameters()) + + with pytest.raises(ValueError, match="sdf"): + self.setup_optimizers(bad_optimizer) + + assert self.setup_optimizers(good_optimizer) == good_optimizer + + +@RunIf(standalone=True, min_cuda_gpus=2) +def test_fsdp_setup_optimizer(): + lite = SetupOptimizerLite(accelerator="cuda", strategy="fsdp", devices=2) + lite.run() From 469cc2de9516d036dcf282685ecf588c902c4ffc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Oct 2022 18:54:56 +0000 Subject: [PATCH 57/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/strategies/fsdp.py | 9 +++++++-- tests/tests_lite/strategies/test_fsdp_integration.py | 1 - 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 5982fe7aef673..a6499f1c18217 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -27,9 +27,14 @@ from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning_lite.strategies.parallel import ParallelStrategy from lightning_lite.strategies.strategy import TBroadcast -from lightning_lite.utilities.distributed import _distributed_available, _get_default_process_group_backend_for_device +from lightning_lite.utilities.distributed import ( + _distributed_available, + _get_default_process_group_backend_for_device, + _init_dist_connection, + _sync_ddp_if_available, +) from lightning_lite.utilities.distributed import group as _group -from lightning_lite.utilities.distributed import _init_dist_connection, ReduceOp, _sync_ddp_if_available +from lightning_lite.utilities.distributed import ReduceOp from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from lightning_lite.utilities.rank_zero import rank_zero_only from lightning_lite.utilities.seed import reset_seed diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index c0e819161a2e7..97b666ce142ee 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -122,7 +122,6 @@ def test_fsdp_train_save_load(manual_wrapping, precision): class SetupOptimizerLite(LightningLite): - def run(self): module = nn.Linear(2, 2) bad_optimizer = Adam(module.parameters()) From c1387167bef41bc544b3fa9e6b96b5fa6a53f061 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Oct 2022 20:55:44 +0200 Subject: [PATCH 58/84] debug --- tests/tests_lite/strategies/test_fsdp_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index c0e819161a2e7..5c3bb790e8a4a 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -129,7 +129,7 @@ def run(self): wrapped_module = self.setup_model(module) good_optimizer = Adam(wrapped_module.parameters()) - with pytest.raises(ValueError, match="sdf"): + with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"): self.setup_optimizers(bad_optimizer) assert self.setup_optimizers(good_optimizer) == good_optimizer From c8f5a671db3ecc950f2572fdca1f5776c37319b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Oct 2022 20:57:22 +0200 Subject: [PATCH 59/84] debug --- tests/tests_lite/strategies/test_fsdp_integration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 328de694e7913..38030df9ac378 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -131,7 +131,8 @@ def run(self): with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"): self.setup_optimizers(bad_optimizer) - assert self.setup_optimizers(good_optimizer) == good_optimizer + lite_optimizer = self.setup_optimizers(good_optimizer) + assert lite_optimizer.optimizer == good_optimizer @RunIf(standalone=True, min_cuda_gpus=2) From 304bac0304d1a90d4ff468e239bda79c624df2a5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 21:07:56 +0200 Subject: [PATCH 60/84] simplify --- tests/tests_lite/strategies/test_fsdp.py | 14 +++++++++++ .../strategies/test_fsdp_integration.py | 23 ------------------- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 9f06f03a0d7ba..5efce187572f9 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -45,3 +45,17 @@ def test_fsdp_custom_mixed_precision(*_): # # wrapped_module = strategy.setup_module(nn.Linear(3, 3)) # assert wrapped_module.mixed_precision == config + + +def test_fsdp_setup_optimizer_validation(): + """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters.""" + module = nn.Linear(2, 2) + strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")]) + + bad_optimizer = Adam([{'params': [module.weight]}, {'params': [module.bias], 'lr': 1e-3}]) + with pytest.raises(ValueError, match="does not support multiple param groups"): + strategy.setup_optimizer(bad_optimizer) + + bad_optimizer = Adam(module.parameters()) + with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"): + strategy.setup_optimizer(bad_optimizer) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 38030df9ac378..3345efce60a43 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -17,10 +17,8 @@ import torch from tests_lite.helpers.models import RandomDataset from tests_lite.helpers.runif import RunIf -from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp.wrap import wrap -from torch.optim import Adam from torch.utils.data import DataLoader from lightning_lite import LightningLite @@ -41,7 +39,6 @@ def run(self): # get parameters on the wrapped model optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - print([p for p in optimizer.param_groups[0].values()]) # optimizer nees to be set up independently optimizer = self.setup_optimizers(optimizer) @@ -119,23 +116,3 @@ def test_fsdp_train_save_load(manual_wrapping, precision): lite = FSDPLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) lite.manual_wrapping = manual_wrapping lite.run() - - -class SetupOptimizerLite(LightningLite): - def run(self): - module = nn.Linear(2, 2) - bad_optimizer = Adam(module.parameters()) - wrapped_module = self.setup_model(module) - good_optimizer = Adam(wrapped_module.parameters()) - - with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"): - self.setup_optimizers(bad_optimizer) - - lite_optimizer = self.setup_optimizers(good_optimizer) - assert lite_optimizer.optimizer == good_optimizer - - -@RunIf(standalone=True, min_cuda_gpus=2) -def test_fsdp_setup_optimizer(): - lite = SetupOptimizerLite(accelerator="cuda", strategy="fsdp", devices=2) - lite.run() From 28e64b7249e1dcf12f8484c5e277aed1efbf2c0f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Oct 2022 19:09:38 +0000 Subject: [PATCH 61/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_lite/strategies/test_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 5efce187572f9..f7313ff30d2f1 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -52,7 +52,7 @@ def test_fsdp_setup_optimizer_validation(): module = nn.Linear(2, 2) strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")]) - bad_optimizer = Adam([{'params': [module.weight]}, {'params': [module.bias], 'lr': 1e-3}]) + bad_optimizer = Adam([{"params": [module.weight]}, {"params": [module.bias], "lr": 1e-3}]) with pytest.raises(ValueError, match="does not support multiple param groups"): strategy.setup_optimizer(bad_optimizer) From ef7fb0eee173f14b5f87269fbd4e5c2943d90596 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 21:10:09 +0200 Subject: [PATCH 62/84] typo --- src/lightning_lite/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index a6499f1c18217..ed91392c0d65d 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -194,7 +194,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: """Set up an optimizer for a model wrapped with FSDP. This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is verify - that the optimizer was created after the model was wrapped with `:meth:setup_module` with a reference to the + that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference to the flattened parameters. """ from torch.distributed.fsdp import FlatParameter From 9bf3531e06df5d45d5943b161738caa32fcfeb2f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Nov 2022 13:49:37 +0000 Subject: [PATCH 63/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index b6382d8515083..02ec9c57f0edd 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -35,10 +35,10 @@ DDPShardedStrategy, DDPSpawnShardedStrategy, DeepSpeedStrategy, + FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy, - FSDPStrategy, ) from lightning_lite.strategies.strategy import _Sharded, TBroadcast from lightning_lite.utilities import move_data_to_device From 0750bfe4c39c35614f33115ded2c2da9f06cc3d1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 13 Nov 2022 01:27:04 +0100 Subject: [PATCH 64/84] changelog --- src/lightning_lite/CHANGELOG.md | 3 +++ src/pytorch_lightning/CHANGELOG.md | 4 ---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lightning_lite/CHANGELOG.md b/src/lightning_lite/CHANGELOG.md index a764b0e59f1c5..732e9452b08a7 100644 --- a/src/lightning_lite/CHANGELOG.md +++ b/src/lightning_lite/CHANGELOG.md @@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningLite.setup_module()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185)) +- Added support for Fully Sharded Data Parallel (FSDP) training in Lightning Lite ([#14967](https://github.com/Lightning-AI/lightning/issues/14967)) + + ### Changed - The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992)) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 303bc29197916..976ea6d056d7f 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -10,7 +10,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added support for DDP with `LRFinder` ([#15304](https://github.com/Lightning-AI/lightning/pull/15304)) -Added `LightningLite.setup_model()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185)) - Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) @@ -25,9 +24,6 @@ Added `LightningLite.setup_model()` and `LightningLite.setup_optimizers()` to su - Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301)) -- Added support for Fully Sharded Data Parallel (FSDP) training in Lightning Lite ([#14967](https://github.com/Lightning-AI/lightning/issues/14967)) - - ### Changed - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347)) From b96e1401814106ca625c0c1a56bf655ad66e4d64 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 13 Nov 2022 01:47:30 +0100 Subject: [PATCH 65/84] fix setup_module call --- tests/tests_lite/strategies/test_fsdp_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 3345efce60a43..7026adfff2306 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -35,7 +35,7 @@ def run(self): dataloader = DataLoader(RandomDataset(32, 64)) # model needs to be set up first in FSDP - model = self.setup_model(model) + model = self.setup_module(model) # get parameters on the wrapped model optimizer = torch.optim.SGD(model.parameters(), lr=0.1) From 2b3f6184e2e3dd71123322087e5339a2e93c0df3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 13 Nov 2022 02:00:48 +0100 Subject: [PATCH 66/84] fix --- src/lightning_lite/lite.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 02ec9c57f0edd..50ad43a721ec3 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -594,8 +594,7 @@ def _prepare_run_method(self) -> None: # wrap the run method, so we can inject setup logic or spawn processes for the user setattr(self, "run", partial(self._run_impl, self.run)) - @staticmethod - def _validate_setup(module: nn.Module, optimizers: Sequence[Optimizer]) -> None: + def _validate_setup(self, module: nn.Module, optimizers: Sequence[Optimizer]) -> None: if isinstance(module, _LiteModule): raise ValueError("A model should be passed only once to the `setup` method.") From a1eed5af8f0ad9233697b5568f5f7999ff3bfd95 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 13 Nov 2022 02:05:19 +0100 Subject: [PATCH 67/84] fix test --- tests/tests_lite/plugins/environments/test_slurm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_lite/plugins/environments/test_slurm.py b/tests/tests_lite/plugins/environments/test_slurm.py index 768e1f468da99..c7e9cb7cf637f 100644 --- a/tests/tests_lite/plugins/environments/test_slurm.py +++ b/tests/tests_lite/plugins/environments/test_slurm.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import os +import shutil import sys from unittest import mock @@ -120,6 +121,7 @@ def test_detect(): @RunIf(skip_windows=True) +@pytest.mark.skipif(shutil.which("srun") is not None, reason="must run on a machine where srun is not available") def test_srun_available_and_not_used(monkeypatch): """Test that a warning is emitted if Lightning suspects the user forgot to run their script with `srun`.""" monkeypatch.setattr(sys, "argv", ["train.py", "--lr", "0.01"]) From 612dc3d9b80eaebb0815a0b80dcc994187ae2b5f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 13 Nov 2022 02:08:52 +0100 Subject: [PATCH 68/84] update --- tests/tests_lite/strategies/test_fsdp_integration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 7026adfff2306..929b0bffbf97c 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -30,7 +30,8 @@ class FSDPLite(LightningLite): manual_wrapping = False def run(self): - model = self.get_model() + with self.sharded_model(): + model = self.get_model() dataloader = DataLoader(RandomDataset(32, 64)) From 190c5d2c9d593d78866d0f4021dcc114b937d33c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 20:50:13 +0100 Subject: [PATCH 69/84] update --- src/lightning_lite/strategies/fsdp.py | 4 +- tests/tests_lite/strategies/test_fsdp.py | 6 - .../strategies/test_fsdp_integration.py | 137 +++++++++--------- .../test_ddp_fully_sharded_native.py | 5 + 4 files changed, 75 insertions(+), 77 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index ed91392c0d65d..0dfaea31cec3d 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -26,7 +26,7 @@ from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning_lite.strategies.parallel import ParallelStrategy -from lightning_lite.strategies.strategy import TBroadcast +from lightning_lite.strategies.strategy import _Sharded, TBroadcast from lightning_lite.utilities.distributed import ( _distributed_available, _get_default_process_group_backend_for_device, @@ -51,7 +51,7 @@ _FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload") -class FSDPStrategy(ParallelStrategy): +class FSDPStrategy(ParallelStrategy, _Sharded): r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed. .. warning:: ``FSDPStrategy`` is in BETA and subject to change. The interface can diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index f7313ff30d2f1..b864c7e66c4f8 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from unittest import mock import pytest @@ -35,16 +34,11 @@ def test_fsdp_support(*_): @RunIf(min_torch="1.12") -# @mock.patch("lightning_lite.strategies.fsdp.FullyShardedDataParallel") def test_fsdp_custom_mixed_precision(*_): """Test that passing a custom mixed precision config works.""" config = MixedPrecision() strategy = FSDPStrategy(mixed_precision=config) assert strategy.mixed_precision_config == config - # - # - # wrapped_module = strategy.setup_module(nn.Linear(3, 3)) - # assert wrapped_module.mixed_precision == config def test_fsdp_setup_optimizer_validation(): diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 929b0bffbf97c..90a7f9bc1f669 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -26,94 +26,93 @@ from lightning_lite.strategies import FSDPStrategy -class FSDPLite(LightningLite): - manual_wrapping = False +def _get_model(manual_wrapping=False): + model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + if not manual_wrapping: + return model - def run(self): - with self.sharded_model(): - model = self.get_model() + for i, layer in enumerate(model): + if i % 2 == 0: + model[i] = wrap(layer) + model = wrap(model) + return model - dataloader = DataLoader(RandomDataset(32, 64)) - # model needs to be set up first in FSDP - model = self.setup_module(model) +def _step(lite, model, batch): + forward_module = model._forward_module + original_module = model.module + assert isinstance(forward_module, FullyShardedDataParallel) + assert isinstance(lite._precision, FSDPPrecision) - # get parameters on the wrapped model - optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + precision = torch.float16 if lite._precision.precision == 16 else torch.bfloat16 + assert forward_module.mixed_precision.param_dtype == precision + assert forward_module.mixed_precision.reduce_dtype == precision + assert forward_module.mixed_precision.buffer_dtype == precision - # optimizer nees to be set up independently - optimizer = self.setup_optimizers(optimizer) + for layer_num in [0, 2]: + assert isinstance(original_module[layer_num], FullyShardedDataParallel) + assert original_module[layer_num].mixed_precision.param_dtype == precision + assert original_module[layer_num].mixed_precision.reduce_dtype == precision + assert original_module[layer_num].mixed_precision.buffer_dtype == precision - dataloader = self.setup_dataloaders(dataloader) - model.train() + output = model(batch) + loss = torch.nn.functional.mse_loss(output, torch.ones_like(output)) + return loss - data_iter = iter(dataloader) - batch = next(data_iter) - loss = self.step(model, batch) - self.backward(loss) - optimizer.step() - optimizer.zero_grad() - with tempfile.TemporaryFile() as ckpt_path: - ckpt_path = self.broadcast(str(ckpt_path)) - self._strategy.save_checkpoint(model.state_dict(), ckpt_path) +def _assert_save_equality(lite, model, ckpt_path): + current_state_dict = lite._strategy.get_module_state_dict(model) - self._assert_save_equality(model, ckpt_path) + checkpoint = lite.load(ckpt_path) + loaded_model = _get_model() + loaded_model.load_state_dict(checkpoint) - def get_model(self): - model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) - if not self.manual_wrapping: - return model + # model parameters are identical after loading + for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()): + assert torch.equal(current_param.float().cpu(), loaded_param.cpu()) + + +def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool: + return unwrapped_params >= 2 - for i, layer in enumerate(model): - if i % 2 == 0: - model[i] = wrap(layer) - model = wrap(model) - return model - def step(self, model, batch): - forward_module = model._forward_module - original_module = model.module - assert isinstance(forward_module, FullyShardedDataParallel) - assert isinstance(self._precision, FSDPPrecision) +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") +@pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) +@pytest.mark.parametrize("manual_wrapping", [True, False]) +def test_fsdp_train_save_load(manual_wrapping, precision): + """Test FSDP training, saving and loading with different wrapping and precision settings.""" + strategy = FSDPStrategy() if manual_wrapping else FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy) + lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) + lite.launch() - precision = torch.float16 if self._precision.precision == 16 else torch.bfloat16 - assert forward_module.mixed_precision.param_dtype == precision - assert forward_module.mixed_precision.reduce_dtype == precision - assert forward_module.mixed_precision.buffer_dtype == precision + lite.manual_wrapping = manual_wrapping - for layer_num in [0, 2]: - assert isinstance(original_module[layer_num], FullyShardedDataParallel) - assert original_module[layer_num].mixed_precision.param_dtype == precision - assert original_module[layer_num].mixed_precision.reduce_dtype == precision - assert original_module[layer_num].mixed_precision.buffer_dtype == precision + with lite.sharded_model(): + model = _get_model() - output = model(batch) - loss = torch.nn.functional.mse_loss(output, torch.ones_like(output)) - return loss + dataloader = DataLoader(RandomDataset(32, 64)) - def _assert_save_equality(self, model, ckpt_path): - current_state_dict = self._strategy.get_module_state_dict(model) + # model needs to be set up first in FSDP + model = lite.setup_module(model) - checkpoint = self.load(ckpt_path) - loaded_model = self.get_model() - loaded_model.load_state_dict(checkpoint) + # get parameters on the wrapped model + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - # model parameters are identical after loading - for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()): - assert torch.equal(current_param.float().cpu(), loaded_param.cpu()) + # optimizer nees to be set up independently + optimizer = lite.setup_optimizers(optimizer) + dataloader = lite.setup_dataloaders(dataloader) + model.train() -def custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool: - return unwrapped_params >= 2 + data_iter = iter(dataloader) + batch = next(data_iter) + loss = _step(model, batch) + lite.backward(loss) + optimizer.step() + optimizer.zero_grad() + with tempfile.TemporaryFile() as ckpt_path: + ckpt_path = lite.broadcast(str(ckpt_path)) + lite._strategy.save_checkpoint(model.state_dict(), ckpt_path) -@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") -@pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) -@pytest.mark.parametrize("manual_wrapping", [True, False]) -def test_fsdp_train_save_load(manual_wrapping, precision): - """Test FSDP training, saving and loading with different wrapping and precision settings.""" - strategy = FSDPStrategy() if manual_wrapping else FSDPStrategy(auto_wrap_policy=custom_auto_wrap_policy) - lite = FSDPLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) - lite.manual_wrapping = manual_wrapping - lite.run() + _assert_save_equality(lite, model, ckpt_path) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index ff962b24f480c..eb0dd19d08390 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -148,6 +148,7 @@ def custom_auto_wrap_policy( return unwrapped_params >= 2 +# lite: adopted @RunIf(min_torch="1.12") def test_invalid_on_cpu(tmpdir): """Test to ensure that we raise Misconfiguration for Native FSDP on CPU.""" @@ -161,6 +162,7 @@ def test_invalid_on_cpu(tmpdir): trainer.strategy.setup_environment() +# lite: adopted @RunIf(min_torch="1.12", min_cuda_gpus=1) @pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) def test_precision_plugin_config(precision, expected): @@ -171,6 +173,7 @@ def test_precision_plugin_config(precision, expected): assert config.reduce_dtype == expected +# lite: adopted @RunIf(min_torch="1.12") def test_fsdp_custom_mixed_precision(tmpdir): """Test to ensure that passing a custom mixed precision config works.""" @@ -179,6 +182,7 @@ def test_fsdp_custom_mixed_precision(tmpdir): assert strategy.mixed_precision_config == config +# lite: skipped @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run.""" @@ -241,6 +245,7 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, stra _run_multiple_stages(trainer, model) +# lite: adopted @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") def test_invalid_parameters_in_optimizer(tmpdir): trainer = Trainer(strategy="fsdp_native", accelerator="cuda", devices=1) From 3c67190338c116017ec69fb133932e5df3746c50 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 20:51:34 +0100 Subject: [PATCH 70/84] fix --- tests/tests_lite/strategies/test_fsdp_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 90a7f9bc1f669..f5e2c115ccee1 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -106,7 +106,7 @@ def test_fsdp_train_save_load(manual_wrapping, precision): data_iter = iter(dataloader) batch = next(data_iter) - loss = _step(model, batch) + loss = _step(lite, model, batch) lite.backward(loss) optimizer.step() optimizer.zero_grad() From 239a6749b8cf21271972e74364dac8b19b3c25a7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 20:53:08 +0100 Subject: [PATCH 71/84] fix duplicate import --- src/lightning_lite/strategies/fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 0dfaea31cec3d..92550607ecdd3 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -46,7 +46,6 @@ FullyShardedDataParallel, MixedPrecision, ) - from torch.distributed.fsdp.wrap import enable_wrap # noqa: F401 _FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload") From ad770e6e4a7237595c445c75bb5d7c3f65f534df Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 20:59:50 +0100 Subject: [PATCH 72/84] add no_backward_sync for FSDP --- src/lightning_lite/strategies/fsdp.py | 20 ++++++++++++++++++- tests/tests_lite/strategies/test_fsdp.py | 25 +++++++++++++++++++++++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 92550607ecdd3..b2ff5e0d25ff0 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -26,7 +26,7 @@ from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning_lite.strategies.parallel import ParallelStrategy -from lightning_lite.strategies.strategy import _Sharded, TBroadcast +from lightning_lite.strategies.strategy import _BackwardSyncControl, _Sharded, TBroadcast from lightning_lite.utilities.distributed import ( _distributed_available, _get_default_process_group_backend_for_device, @@ -109,6 +109,7 @@ def __init__( self._num_nodes = 1 self._process_group_backend: Optional[str] = process_group_backend self._timeout: Optional[timedelta] = timeout + self._backward_sync_control = _FSDPBackwardSyncControl() self._ddp_kwargs = kwargs self.cpu_offload = cpu_offload @@ -286,3 +287,20 @@ def _set_world_ranks(self) -> None: self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) rank_zero_only.rank = self.cluster_environment.global_rank() + + +class _FSDPBackwardSyncControl(_BackwardSyncControl): + @contextmanager + def no_backward_sync(self, module: Module) -> Generator: + """Blocks gradient synchronization inside the + :class:`~torch.nn.parallel.distributed.DistributedDataParallel` wrapper.""" + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + + if not isinstance(module, FullyShardedDataParallel): + raise TypeError( + "Blocking backward sync is only possible if the module passed to" + f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`." + f" Got: {module.__class__.__name__}." + ) + with module.no_sync(): # type: ignore[operator] + yield diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index b864c7e66c4f8..0ffc90c276dec 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -13,6 +13,7 @@ # limitations under the License. from unittest import mock +from unittest.mock import MagicMock, Mock import pytest import torch @@ -21,10 +22,11 @@ from torch.optim import Adam from lightning_lite.strategies import FSDPStrategy +from lightning_lite.strategies.fsdp import _FSDPBackwardSyncControl from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision @mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False) @@ -53,3 +55,24 @@ def test_fsdp_setup_optimizer_validation(): bad_optimizer = Adam(module.parameters()) with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"): strategy.setup_optimizer(bad_optimizer) + + +@RunIf(min_torch="1.12") +def test_fsdp_no_backward_sync(): + """Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in + FullyShardedDataParallel.""" + + strategy = FSDPStrategy() + assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl) + + with pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`" + ): + with strategy._backward_sync_control.no_backward_sync(Mock()): + pass + + module = MagicMock(spec=FullyShardedDataParallel) + with strategy._backward_sync_control.no_backward_sync(module): + pass + + module.no_sync.assert_called_once() From 1c19b9631874590e4a7957e8fd8ed3fbebf92c26 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:15:27 +0100 Subject: [PATCH 73/84] fix --- tests/tests_lite/strategies/test_fsdp_integration.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index f5e2c115ccee1..8e5f29303cf10 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -40,7 +40,6 @@ def _get_model(manual_wrapping=False): def _step(lite, model, batch): forward_module = model._forward_module - original_module = model.module assert isinstance(forward_module, FullyShardedDataParallel) assert isinstance(lite._precision, FSDPPrecision) @@ -50,10 +49,10 @@ def _step(lite, model, batch): assert forward_module.mixed_precision.buffer_dtype == precision for layer_num in [0, 2]: - assert isinstance(original_module[layer_num], FullyShardedDataParallel) - assert original_module[layer_num].mixed_precision.param_dtype == precision - assert original_module[layer_num].mixed_precision.reduce_dtype == precision - assert original_module[layer_num].mixed_precision.buffer_dtype == precision + assert isinstance(forward_module[layer_num], FullyShardedDataParallel) + assert forward_module[layer_num].mixed_precision.param_dtype == precision + assert forward_module[layer_num].mixed_precision.reduce_dtype == precision + assert forward_module[layer_num].mixed_precision.buffer_dtype == precision output = model(batch) loss = torch.nn.functional.mse_loss(output, torch.ones_like(output)) From 1d2fa568106bb065c6a2cc0954a6e543ebad5164 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:18:05 +0100 Subject: [PATCH 74/84] fix literal import --- src/lightning_lite/plugins/precision/fsdp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py index e8aafea518b12..020369bcbc4cf 100644 --- a/src/lightning_lite/plugins/precision/fsdp.py +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING import torch +from typing_extensions import Literal from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision from lightning_lite.utilities.enums import PrecisionType From fac33ac20d29a1f676bd264a4fcb9a2da7450547 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:21:21 +0100 Subject: [PATCH 75/84] fix --- src/lightning_lite/strategies/fsdp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index b2ff5e0d25ff0..8f678b5174842 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -256,6 +256,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: + if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available(): + return from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload strategy_registry.register( From 56f2109c2dcb5543e46ce754ea9d1076f70163fc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:23:52 +0100 Subject: [PATCH 76/84] manual wrap --- tests/tests_lite/strategies/test_fsdp_integration.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 8e5f29303cf10..dea777c95b158 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -80,14 +80,12 @@ def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_par @pytest.mark.parametrize("manual_wrapping", [True, False]) def test_fsdp_train_save_load(manual_wrapping, precision): """Test FSDP training, saving and loading with different wrapping and precision settings.""" - strategy = FSDPStrategy() if manual_wrapping else FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy) + strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy) lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) lite.launch() - lite.manual_wrapping = manual_wrapping - with lite.sharded_model(): - model = _get_model() + model = _get_model(manual_wrapping) dataloader = DataLoader(RandomDataset(32, 64)) From 17bd7140705f1a22d1b0f4e4f640c819263ed373 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:32:52 +0100 Subject: [PATCH 77/84] avoid double wrap --- tests/tests_lite/strategies/test_fsdp_integration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index dea777c95b158..030876e71fb2c 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -34,7 +34,6 @@ def _get_model(manual_wrapping=False): for i, layer in enumerate(model): if i % 2 == 0: model[i] = wrap(layer) - model = wrap(model) return model From a97c56e56661e9f06aabf05b199692a699bf640e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:38:32 +0100 Subject: [PATCH 78/84] fix mypy --- src/lightning_lite/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 8f678b5174842..30e77a7148627 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -304,5 +304,5 @@ def no_backward_sync(self, module: Module) -> Generator: f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`." f" Got: {module.__class__.__name__}." ) - with module.no_sync(): # type: ignore[operator] + with module.no_sync(): yield From fac22b4543d67b8ae45432e4718c9a69827de610 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:38:47 +0100 Subject: [PATCH 79/84] revert original test --- tests/tests_lite/strategies/test_fsdp_integration.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 030876e71fb2c..628d1b225a7e3 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -39,6 +39,7 @@ def _get_model(manual_wrapping=False): def _step(lite, model, batch): forward_module = model._forward_module + original_module = model.module assert isinstance(forward_module, FullyShardedDataParallel) assert isinstance(lite._precision, FSDPPrecision) @@ -48,10 +49,10 @@ def _step(lite, model, batch): assert forward_module.mixed_precision.buffer_dtype == precision for layer_num in [0, 2]: - assert isinstance(forward_module[layer_num], FullyShardedDataParallel) - assert forward_module[layer_num].mixed_precision.param_dtype == precision - assert forward_module[layer_num].mixed_precision.reduce_dtype == precision - assert forward_module[layer_num].mixed_precision.buffer_dtype == precision + assert isinstance(original_module[layer_num], FullyShardedDataParallel) + assert original_module[layer_num].mixed_precision.param_dtype == precision + assert original_module[layer_num].mixed_precision.reduce_dtype == precision + assert original_module[layer_num].mixed_precision.buffer_dtype == precision output = model(batch) loss = torch.nn.functional.mse_loss(output, torch.ones_like(output)) From ffaba94ed1bfd61da7585c21bdb397b94511a9f2 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:51:10 +0100 Subject: [PATCH 80/84] skip import on torch <1.12 --- tests/tests_lite/strategies/test_fsdp_integration.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 628d1b225a7e3..052133e265e4c 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -17,13 +17,16 @@ import torch from tests_lite.helpers.models import RandomDataset from tests_lite.helpers.runif import RunIf -from torch.distributed.fsdp import FullyShardedDataParallel -from torch.distributed.fsdp.wrap import wrap from torch.utils.data import DataLoader from lightning_lite import LightningLite from lightning_lite.plugins import FSDPPrecision from lightning_lite.strategies import FSDPStrategy +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 + +if _TORCH_GREATER_EQUAL_1_12: + from torch.distributed.fsdp import FullyShardedDataParallel + from torch.distributed.fsdp.wrap import wrap def _get_model(manual_wrapping=False): From 133bb9c1b3401c051afd5e2f04579cde3db680f4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:53:53 +0100 Subject: [PATCH 81/84] torch compatibility --- tests/tests_lite/strategies/test_fsdp.py | 1 + tests/tests_lite/strategies/test_registry.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 0ffc90c276dec..8f609d53c253a 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -43,6 +43,7 @@ def test_fsdp_custom_mixed_precision(*_): assert strategy.mixed_precision_config == config +@RunIf(min_torch="1.12") def test_fsdp_setup_optimizer_validation(): """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters.""" module = nn.Linear(2, 2) diff --git a/tests/tests_lite/strategies/test_registry.py b/tests/tests_lite/strategies/test_registry.py index 36bd11b45b339..bf5dfce73228c 100644 --- a/tests/tests_lite/strategies/test_registry.py +++ b/tests/tests_lite/strategies/test_registry.py @@ -13,6 +13,7 @@ # limitations under the License. from lightning_lite.strategies import STRATEGY_REGISTRY +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 def test_strategy_registry_with_new_strategy(): @@ -41,7 +42,7 @@ def __init__(self, param1, param2): def test_available_strategies_in_registry(): - assert set(STRATEGY_REGISTRY.available_strategies()) == { + expected = { "ddp_sharded_find_unused_parameters_false", "ddp_sharded", "ddp_find_unused_parameters_false", @@ -65,6 +66,7 @@ def test_available_strategies_in_registry(): "tpu_spawn", "xla", "dp", - "fsdp", - "fsdp_full_shard_offload", } + if _TORCH_GREATER_EQUAL_1_12: + expected += {"fsdp", "fsdp_full_shard_offload"} + assert set(STRATEGY_REGISTRY.available_strategies()) == expected From 07d4998c03428b6b06e4225ecd76fcdad2261741 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 20 Nov 2022 21:57:01 +0100 Subject: [PATCH 82/84] fix --- tests/tests_lite/strategies/test_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_registry.py b/tests/tests_lite/strategies/test_registry.py index bf5dfce73228c..81a49eec08934 100644 --- a/tests/tests_lite/strategies/test_registry.py +++ b/tests/tests_lite/strategies/test_registry.py @@ -68,5 +68,5 @@ def test_available_strategies_in_registry(): "dp", } if _TORCH_GREATER_EQUAL_1_12: - expected += {"fsdp", "fsdp_full_shard_offload"} + expected |= {"fsdp", "fsdp_full_shard_offload"} assert set(STRATEGY_REGISTRY.available_strategies()) == expected From 0ece17d71a0f81c3dfe9921bb1e7ea1467b2cf67 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 21 Nov 2022 14:13:01 +0100 Subject: [PATCH 83/84] revert comments in pytorch tests --- .../strategies/test_ddp_fully_sharded_native.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index eb0dd19d08390..5bb6b84d9e0f0 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -148,7 +148,6 @@ def custom_auto_wrap_policy( return unwrapped_params >= 2 -# lite: adopted @RunIf(min_torch="1.12") def test_invalid_on_cpu(tmpdir): """Test to ensure that we raise Misconfiguration for Native FSDP on CPU.""" @@ -162,7 +161,6 @@ def test_invalid_on_cpu(tmpdir): trainer.strategy.setup_environment() -# lite: adopted @RunIf(min_torch="1.12", min_cuda_gpus=1) @pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) def test_precision_plugin_config(precision, expected): @@ -173,7 +171,6 @@ def test_precision_plugin_config(precision, expected): assert config.reduce_dtype == expected -# lite: adopted @RunIf(min_torch="1.12") def test_fsdp_custom_mixed_precision(tmpdir): """Test to ensure that passing a custom mixed precision config works.""" @@ -182,7 +179,6 @@ def test_fsdp_custom_mixed_precision(tmpdir): assert strategy.mixed_precision_config == config -# lite: skipped @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run.""" @@ -200,7 +196,6 @@ def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) -# lite: adopted @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision): @@ -212,7 +207,6 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision): _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) -# lite: adopted @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize( "model, strategy", @@ -245,7 +239,6 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, stra _run_multiple_stages(trainer, model) -# lite: adopted @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") def test_invalid_parameters_in_optimizer(tmpdir): trainer = Trainer(strategy="fsdp_native", accelerator="cuda", devices=1) From e47d3c89c14779f2363a02cd33b377d33e7273b0 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 21 Nov 2022 14:15:24 +0100 Subject: [PATCH 84/84] fix copy-paste error in docstring --- src/lightning_lite/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 30e77a7148627..8053992d18525 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -295,7 +295,7 @@ class _FSDPBackwardSyncControl(_BackwardSyncControl): @contextmanager def no_backward_sync(self, module: Module) -> Generator: """Blocks gradient synchronization inside the - :class:`~torch.nn.parallel.distributed.DistributedDataParallel` wrapper.""" + :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper.""" from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel if not isinstance(module, FullyShardedDataParallel):