@@ -11,10 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional

import torch

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
else:
MixedPrecision = None


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
@@ -29,3 +38,18 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)

@property
def mixed_precision_config(self) -> Optional[MixedPrecision]:
assert MixedPrecision is not None
if self.precision == PrecisionType.HALF:
dtype = torch.float16
elif self.precision == PrecisionType.BFLOAT:
dtype = torch.bfloat16
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
return MixedPrecision(
param_dtype=dtype,
reduce_dtype=dtype,
buffer_dtype=dtype,
)
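
A minimal standalone sketch (illustration only, not part of the diff) of the dtype mapping the new `mixed_precision_config` property implements. The helper name is hypothetical; it only requires `torch.distributed.fsdp` from PyTorch 1.12+:

```python
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision


def fsdp_mixed_precision_for(precision) -> MixedPrecision:
    """Map a Lightning precision flag (16 / "16" / "bf16") to an FSDP config."""
    dtype = {"16": torch.float16, "bf16": torch.bfloat16}[str(precision)]
    # Parameters, gradient reduction, and buffers all use the same reduced dtype,
    # mirroring the ``MixedPrecision(...)`` returned by the property above.
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)


print(fsdp_mixed_precision_for("bf16"))
```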
109 changes: 84 additions & 25 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -23,6 +23,8 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
@@ -35,18 +37,23 @@
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed

if _TORCH_GREATER_EQUAL_1_11:
if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
FullyShardedDataParallel,
MixedPrecision,
)
from torch.distributed.fsdp.wrap import enable_wrap

else:
MixedPrecision = None
BackwardPrefetch = None # type: ignore[misc,assignment]
CPUOffload = None # type: ignore[misc,assignment]

log = logging.getLogger(__name__)

@@ -56,18 +63,20 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
strategy_name = "fsdp_native"
_registered_strategies: List[str] = []

def __init__( # type: ignore[no-untyped-def]
def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
process_group_backend: Optional[str] = None,
cpu_offload=None,
backward_prefetch=None,
cpu_offload: Optional[CPUOffload] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None,
**kwargs: Any,
) -> None:
"""Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
r"""Strategy for Fully Sharded Data Parallel provided by torch.Distributed.

Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
@@ -84,22 +93,29 @@ def __init__( # type: ignore[no-untyped-def]
`https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html`

Arguments:
cpu_offload (Optional [CPUOffload]):
cpu_offload:
CPU offloading config. Currently, only parameter and gradient CPU
offload is supported. It can be enabled via passing in
``cpu_offload=CPUOffload(offload_params=True)``. Note that this
currently implicitly enables gradient offloading to CPU in order for
params and grads to be on the same device to work with the optimizer. This
API is subject to change. Default is ``None`` in which case there
will be no offloading.
backward_prefetch: (Optional[BackwardPrefetch]):
backward_prefetch:
This is an experimental feature that is subject to change in the
near future. It allows users to enable two different backward_prefetch
algorithms to help backward communication and computation overlapping.
The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
mixed_precision:
Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
or BF16 if ``precision="bf16"`` unless a config is passed in.
This is only available in PyTorch 1.12 and later.
\**kwargs: Passed to the FSDP Context manager which will configure the FSDP class when wrapping modules.
"""
if not _TORCH_GREATER_EQUAL_1_11:
raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.11.0 onwards.")
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException(
"`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards."
)

super().__init__(
accelerator=accelerator,
@@ -109,16 +125,23 @@
precision_plugin=precision_plugin,
)
self._process_group = None
self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
self._process_group_backend: Optional[str] = process_group_backend
self.cpu_offload: Optional[CPUOffload] = cpu_offload
self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch
self.num_nodes = 1
self._process_group_backend = process_group_backend
self.cpu_offload = cpu_offload
self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision
self._rank_0_will_call_children_scripts: bool = False
self.kwargs = kwargs

@property
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]

@property
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0

@property
def process_group(self) -> Optional[ProcessGroup]:
if self._process_group is None:
@@ -130,10 +153,28 @@ def process_group(self) -> Optional[ProcessGroup]:
def process_group_backend(self) -> Optional[str]:
return self._process_group_backend

@property
def mixed_precision_config(self) -> Optional[MixedPrecision]:
if self.mixed_precision:
return self.mixed_precision
plugin = self.precision_plugin
if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin):
return plugin.mixed_precision_config

@property
def distributed_sampler_kwargs(self) -> Dict:
return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)

def setup_environment(self) -> None:
log.detail(f"{self.__class__.__name__}: setting up distributed...")
reset_seed()

# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
init_dist_connection(self.cluster_environment, self._process_group_backend)
@@ -146,36 +187,51 @@ def _get_process_group_backend(self) -> str:
or get_default_process_group_backend_for_device(self.root_device)
)

def set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

def _configure_launcher(self) -> None:
assert self.cluster_environment is not None
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
self._rank_0_will_call_children_scripts = True

def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)
# share ddp pids to all processes
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)

if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
assert self.model is not None
self.model = self._layer_sync.apply(self.model)

if not self.cpu_offload:
self.model_to_device()
# we set the device so that optimizers can be created with distributed comms.
assert self.lightning_module is not None
self.lightning_module._device = self.root_device

self.barrier()
self.setup_optimizers(trainer)
optimizers_to_device(self.optimizers, self.root_device)
self.setup_precision_plugin()

def model_to_device(self) -> None:
# ensure we update the device type in the lightning module
assert self.lightning_module is not None
log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
self.lightning_module.to(self.root_device)
pass

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")

with enable_wrap(
wrapper_cls=FullyShardedDataParallel,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
backward_prefetch=self.backward_prefetch,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self.kwargs,
):
yield

@@ -219,7 +275,7 @@ def _determine_device_ids(self) -> List[int]:
return [self.root_device.index]

def teardown(self) -> None:
log.info(f"{self.__class__.__name__}: tearing down strategy...")
rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...")
if (
self.lightning_module is not None
and self.lightning_module.trainer is not None
@@ -229,15 +285,18 @@ def teardown(self) -> None:
assert self.model is not None
self.model = self._layer_sync.revert(self.model)

super().teardown()
assert self.cluster_environment is not None
self.cluster_environment.teardown()
self.precision_plugin.teardown()
self.accelerator.teardown()

@classmethod
def get_registered_strategies(cls) -> List[str]:
return cls._registered_strategies

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
if _TORCH_GREATER_EQUAL_1_11:
if _TORCH_GREATER_EQUAL_1_12:
strategy_registry.register(
"fsdp_native",
cls,
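For context on the new constructor arguments, a usage sketch (not part of this PR) of passing the FSDP-native configs straight through to the strategy. It assumes PyTorch 1.12+, a CUDA machine, and that Lightning's optional sharded-AMP dependencies (e.g. fairscale, which the parent precision plugin may still require here) are installed; `MyModel` is a placeholder LightningModule:

```python
import pytorch_lightning as pl
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload

from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

strategy = DDPFullyShardedNativeStrategy(
    # offload parameters (and, implicitly, gradients) to CPU
    cpu_offload=CPUOffload(offload_params=True),
    # prefetch the next all-gather before the current backward computation finishes
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    # ``mixed_precision`` is omitted here, so it is derived from ``precision=16`` below
)

trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=strategy, precision=16)
# trainer.fit(MyModel())  # MyModel: any LightningModule
```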
@@ -700,17 +700,13 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
if self._precision_flag == 16
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
if isinstance(self.strategy, DDPFullyShardedNativeStrategy):
raise MisconfigurationException(
"DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
)

if self._amp_type_flag == AMPType.NATIVE:
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"

if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
if isinstance(self.strategy, DDPFullyShardedStrategy):
if isinstance(self.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
return NativeMixedPrecisionPlugin(self._precision_flag, device)

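A hedged illustration of the connector change above: with the hard error removed, 16-bit or bf16 precision combined with the native FSDP strategy should now resolve to `FullyShardedNativeMixedPrecisionPlugin` instead of raising. The snippet assumes at least one CUDA GPU, PyTorch 1.12+, and the optional dependencies noted earlier:

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (
    FullyShardedNativeMixedPrecisionPlugin,
)

trainer = Trainer(accelerator="gpu", devices=1, strategy="fsdp_native", precision="bf16")
plugin = trainer.strategy.precision_plugin
assert isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin)
# The strategy can hand this config to FSDP when wrapping the model.
assert plugin.mixed_precision_config.param_dtype == torch.bfloat16
```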
21 changes: 1 addition & 20 deletions tests/tests_pytorch/accelerators/test_accelerator_connector.py
@@ -574,7 +574,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
assert trainer.strategy.local_rank == 0


@RunIf(min_torch="1.11")
@RunIf(min_torch="1.12")
def test_check_native_fsdp_strategy_and_fallback():
with pytest.raises(
MisconfigurationException,
@@ -584,25 +584,6 @@ def test_check_native_fsdp_strategy_and_fallback():
Trainer(accelerator="cpu", strategy="fsdp_native")


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
@mock.patch("torch.cuda.device_count", return_value=1)
@mock.patch("torch.cuda.is_available", return_value=True)
@RunIf(min_torch="1.11")
def test_mixed_precision_support_with_native_fsdp_strategy(device_count_mock, mock_cuda_available, tmpdir):
with pytest.raises(
MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
):
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
strategy="fsdp_native",
accelerator="gpu",
devices=1,
precision=16,
)
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
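
Since the test asserting the old exception is deleted above, here is a rough sketch of a replacement test one could add (hypothetical, not part of this PR; names and import paths are assumed to match the surrounding module, and the mocks mirror the deleted test):

```python
import os
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (
    FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from tests_pytorch.helpers.runif import RunIf


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
@mock.patch("torch.cuda.device_count", return_value=1)
@mock.patch("torch.cuda.is_available", return_value=True)
@RunIf(min_torch="1.12")
def test_native_fsdp_mixed_precision_plugin_selected(device_count_mock, mock_cuda_available, tmpdir):
    # Mixed precision with the native FSDP strategy should now select the
    # FSDP-aware precision plugin rather than raising MisconfigurationException.
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        strategy="fsdp_native",
        accelerator="gpu",
        devices=1,
        precision=16,
    )
    assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
    assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)
```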


@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
def test_unsupported_tpu_choice(mock_tpu_acc_avail):
