From 9fb150680703f192e14d980fd318ba389c92db6b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 5 May 2022 16:25:12 +0100 Subject: [PATCH 01/18] Fix Native FSDP precision + tests --- .../precision/fully_sharded_native_amp.py | 24 ++++- .../strategies/fully_sharded_native.py | 29 ++++-- .../connectors/accelerator_connector.py | 6 +- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/imports.py | 2 + .../test_ddp_fully_sharded_native.py | 89 +++++++++++-------- 6 files changed, 100 insertions(+), 51 deletions(-) diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 870e658bfc9c3..e7245e7a003e8 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -11,10 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional + +import torch from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 + +MixedPrecision = None +if _TORCH_GREATER_EQUAL_1_12: + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): @@ -29,3 +37,17 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: raise MisconfigurationException( f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`" ) + + @property + def mixed_precision_config(self) -> Optional[MixedPrecision]: + if self.precision == PrecisionType.HALF: + dtype = torch.float16 + elif self.precision == PrecisionType.BFLOAT: + dtype = torch.bfloat16 + else: + raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision}.") + return MixedPrecision( + param_dtype=dtype, + reduce_dtype=dtype, + buffer_dtype=dtype, + ) diff --git a/pytorch_lightning/strategies/fully_sharded_native.py b/pytorch_lightning/strategies/fully_sharded_native.py index 7d4a037826ab3..c6860993e4e95 100644 --- a/pytorch_lightning/strategies/fully_sharded_native.py +++ b/pytorch_lightning/strategies/fully_sharded_native.py @@ -22,6 +22,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn @@ -34,15 +35,17 @@ from pytorch_lightning.utilities.distributed import group as _group from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11 +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.seed import reset_seed -if _TORCH_GREATER_EQUAL_1_11: +MixedPrecision = None +if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp.fully_sharded_data_parallel import ( BackwardPrefetch, CPUOffload, FullyShardedDataParallel, + MixedPrecision, ) from torch.distributed.fsdp.wrap import enable_wrap @@ -65,6 +68,7 @@ def __init__( # type: ignore[no-untyped-def] process_group_backend: Optional[str] = None, cpu_offload=None, backward_prefetch=None, + mixed_precision=None, ) -> None: """Strategy for Fully Sharded Data Parallel provided by torch.Distributed. @@ -96,9 +100,13 @@ def __init__( # type: ignore[no-untyped-def] the near future. It allows users to enable two different backward_prefetch algorithms to help backward communication and computation overlapping. Pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. + mixed_precision: (Optional[MixedPrecision]): + Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16` + or BF16 if ``precision=bf16`` unless a config is passed in. + This is only available in PyTorch 1.12 and later. """ - if not _TORCH_GREATER_EQUAL_1_11: - raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.11.0 onwards.") + if not _TORCH_GREATER_EQUAL_1_12: + raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.12.0 onwards.") super().__init__( accelerator=accelerator, @@ -112,6 +120,7 @@ def __init__( # type: ignore[no-untyped-def] self._process_group_backend: Optional[str] = process_group_backend self.cpu_offload: Optional[CPUOffload] = cpu_offload self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch + self.mixed_precision: Optional[MixedPrecision] = mixed_precision @property def root_device(self) -> torch.device: @@ -128,6 +137,14 @@ def process_group(self) -> Optional[ProcessGroup]: def process_group_backend(self) -> Optional[str]: return self._process_group_backend + @property + def mixed_precision_config(self) -> Optional[MixedPrecision]: + if self.mixed_precision: + return self.mixed_precision + plugin = self.precision_plugin + if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin): + return plugin.mixed_precision_config + def setup_environment(self) -> None: reset_seed() # set warning rank @@ -168,12 +185,12 @@ def model_to_device(self) -> None: @contextlib.contextmanager def model_sharded_context(self) -> Generator: log.detail(f"{self.__class__.__name__}: entered model_sharded_context.") - with enable_wrap( wrapper_cls=FullyShardedDataParallel, process_group=self.process_group, cpu_offload=self.cpu_offload, backward_prefetch=self.backward_prefetch, + mixed_precision=self.mixed_precision_config, ): yield @@ -235,7 +252,7 @@ def get_registered_strategies(cls) -> List[str]: @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: - if _TORCH_GREATER_EQUAL_1_11: + if _TORCH_GREATER_EQUAL_1_12: strategy_registry.register( "fsdp_native", cls, diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 753234bc21d83..32c92d8b1eb09 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -679,17 +679,13 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if self._precision_flag == 16 else "Using bfloat16 Automatic Mixed Precision (AMP)" ) - if isinstance(self.strategy, DDPFullyShardedNativeStrategy): - raise MisconfigurationException( - "DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision" - ) if self._amp_type_flag == AMPType.NATIVE: device = "cpu" if self._accelerator_flag == "cpu" else "cuda" if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)): return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device) - if isinstance(self.strategy, DDPFullyShardedStrategy): + if isinstance(self.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)): return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device) return NativeMixedPrecisionPlugin(self._precision_flag, device) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 87947ac9a10f3..9e34149d227ec 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -50,6 +50,7 @@ _TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11, + _TORCH_GREATER_EQUAL_1_12, _TORCH_QUANTIZE_AVAILABLE, _TORCHTEXT_AVAILABLE, _TORCHVISION_AVAILABLE, diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 835e56f1816da..b841455a9a5e6 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -94,6 +94,8 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: _TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0") _TORCH_LESSER_EQUAL_1_10_2 = _compare_version("torch", operator.le, "1.10.2") _TORCH_GREATER_EQUAL_1_11 = _compare_version("torch", operator.ge, "1.11.0") +# todo: remove "dev" when PyTorch 1.12 is released +_TORCH_GREATER_EQUAL_1_12 = _compare_version("torch", operator.ge, "1.12.0dev") _APEX_AVAILABLE = _module_available("apex.amp") _BAGUA_AVAILABLE = _package_available("bagua") diff --git a/tests/strategies/test_ddp_fully_sharded_native.py b/tests/strategies/test_ddp_fully_sharded_native.py index cf4973e5ae035..f68f5d39afb1e 100644 --- a/tests/strategies/test_ddp_fully_sharded_native.py +++ b/tests/strategies/test_ddp_fully_sharded_native.py @@ -1,26 +1,27 @@ import os from typing import Any, Dict, Optional -from unittest import mock import pytest import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11 +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from pytorch_lightning.utilities.types import STEP_OUTPUT from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf -if _TORCH_GREATER_EQUAL_1_11: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel +if _TORCH_GREATER_EQUAL_1_12: + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision from torch.distributed.fsdp.wrap import wrap -@RunIf(min_torch="1.11") +@RunIf(min_torch="1.12dev") def test_invalid_on_cpu(tmpdir): - """Test to ensure that to raise Misconfiguration for Native FSDP on CPU.""" + """Test to ensure that we raise Misconfiguration for Native FSDP on CPU.""" with pytest.raises( MisconfigurationException, match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, " @@ -31,29 +32,27 @@ def test_invalid_on_cpu(tmpdir): trainer.strategy.setup_environment() -@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) -@mock.patch("torch.cuda.device_count", return_value=1) -@mock.patch("torch.cuda.is_available", return_value=True) -@RunIf(min_torch="1.11") -def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): - """Test to ensure that plugin native amp plugin raises Misconfiguration error.""" - with pytest.raises( - MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision" - ): - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - strategy="fsdp_native", - accelerator="gpu", - devices=1, - precision=16, - ) - assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy) +@RunIf(min_torch="1.12dev") +@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) +def test_precision_plugin_config(precision, expected): + plugin = FullyShardedNativeMixedPrecisionPlugin(precision=precision, device="cuda") + config = plugin.mixed_precision_config + assert config.param_dtype == expected + assert config.buffer_dtype == expected + assert config.reduce_dtype == expected + + +@RunIf(min_torch="1.12dev") +def test_fsdp_custom_mixed_precision(tmpdir): + """Test to ensure that passing a custom mixed precision config works.""" + config = MixedPrecision() + strategy = DDPFullyShardedNativeStrategy(mixed_precision=config) + assert strategy.mixed_precision_config == config class TestFSDPModel(BoringModel): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self): + super().__init__() self.layer: Optional[torch.nn.Module] = None def _init_model(self) -> None: @@ -79,16 +78,20 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) - def on_train_start(self) -> None: + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: self._assert_layer_fsdp_instance() - def on_test_start(self) -> None: + def on_test_batch_end( + self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: self._assert_layer_fsdp_instance() - def on_validation_start(self) -> None: + def on_validation_batch_end( + self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: self._assert_layer_fsdp_instance() - def on_prediction_start(self) -> None: + def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None: self._assert_layer_fsdp_instance() def _assert_layer_fsdp_instance(self) -> None: @@ -101,9 +104,15 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer.module[0].reshard_after_forward is True assert self.layer.module[2].reshard_after_forward is True + precision = torch.float16 if self.precision == 16 else torch.bfloat16 + assert self.layer.mixed_precision.param_dtype == precision + assert self.layer.mixed_precision.reduce_dtype == precision + assert self.layer.mixed_precision.buffer_dtype == precision + -@RunIf(min_gpus=2, skip_windows=True, standalone=True, min_torch="1.11") -def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): +@RunIf(min_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev") +@pytest.mark.parametrize("precision", [16, "bf16"]) +def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir, precision): """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run.""" model = TestFSDPModel() @@ -112,26 +121,28 @@ def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): accelerator="gpu", devices=2, strategy="fsdp_native", - precision=16, + precision=precision, max_epochs=1, sync_batchnorm=True, ) _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) -@RunIf(min_gpus=1, skip_windows=True, standalone=True, min_torch="1.11") -def test_fully_sharded_native_strategy_checkpoint(tmpdir): +@RunIf(min_gpus=1, skip_windows=True, standalone=True, min_torch="1.12dev") +@pytest.mark.parametrize("precision", [16, "bf16"]) +def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" model = TestFSDPModel() trainer = Trainer( - default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp_native", precision=16, max_epochs=1 + default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp_native", precision=precision, max_epochs=1 ) _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) -@RunIf(min_gpus=2, skip_windows=True, standalone=True, min_torch="1.11") -def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir): +@RunIf(min_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev") +@pytest.mark.parametrize("precision", [16, "bf16"]) +def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, precision): """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run.""" model = TestFSDPModel() @@ -141,7 +152,7 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir): accelerator="gpu", devices=2, strategy="fsdp_native", - precision=16, + precision=precision, max_epochs=1, callbacks=[ck], ) From 83df2047902f67da8ff4b9f290a3b9c5be64a8c3 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 5 May 2022 19:29:26 +0100 Subject: [PATCH 02/18] Fix typing --- .../plugins/precision/fully_sharded_native_amp.py | 5 +++-- pytorch_lightning/strategies/fully_sharded_native.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index e7245e7a003e8..337636f7fa688 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -22,7 +22,7 @@ MixedPrecision = None if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision # type: ignore[no-redef] class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): @@ -39,7 +39,8 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: ) @property - def mixed_precision_config(self) -> Optional[MixedPrecision]: + def mixed_precision_config(self) -> Optional[MixedPrecision]: # type: ignore + assert MixedPrecision is not None if self.precision == PrecisionType.HALF: dtype = torch.float16 elif self.precision == PrecisionType.BFLOAT: diff --git a/pytorch_lightning/strategies/fully_sharded_native.py b/pytorch_lightning/strategies/fully_sharded_native.py index c6860993e4e95..b52ce1241360f 100644 --- a/pytorch_lightning/strategies/fully_sharded_native.py +++ b/pytorch_lightning/strategies/fully_sharded_native.py @@ -41,11 +41,11 @@ MixedPrecision = None if _TORCH_GREATER_EQUAL_1_12: + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision # type: ignore[no-redef] from torch.distributed.fsdp.fully_sharded_data_parallel import ( BackwardPrefetch, CPUOffload, FullyShardedDataParallel, - MixedPrecision, ) from torch.distributed.fsdp.wrap import enable_wrap @@ -120,7 +120,7 @@ def __init__( # type: ignore[no-untyped-def] self._process_group_backend: Optional[str] = process_group_backend self.cpu_offload: Optional[CPUOffload] = cpu_offload self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch - self.mixed_precision: Optional[MixedPrecision] = mixed_precision + self.mixed_precision: Optional[MixedPrecision] = mixed_precision # type: ignore @property def root_device(self) -> torch.device: @@ -138,7 +138,7 @@ def process_group_backend(self) -> Optional[str]: return self._process_group_backend @property - def mixed_precision_config(self) -> Optional[MixedPrecision]: + def mixed_precision_config(self) -> Optional[MixedPrecision]: # type: ignore if self.mixed_precision: return self.mixed_precision plugin = self.precision_plugin From e514c9a370c653cacd06514ce77a69f137175bec Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 6 May 2022 15:52:20 +0100 Subject: [PATCH 03/18] Add support for children scripts, work on fixing tests --- .../strategies/fully_sharded_native.py | 49 +++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/strategies/fully_sharded_native.py b/pytorch_lightning/strategies/fully_sharded_native.py index b52ce1241360f..6818afb832206 100644 --- a/pytorch_lightning/strategies/fully_sharded_native.py +++ b/pytorch_lightning/strategies/fully_sharded_native.py @@ -23,6 +23,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin +from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn @@ -116,16 +117,30 @@ def __init__( # type: ignore[no-untyped-def] precision_plugin=precision_plugin, ) self._process_group = None - self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0 + self._num_nodes = 1 self._process_group_backend: Optional[str] = process_group_backend self.cpu_offload: Optional[CPUOffload] = cpu_offload self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch self.mixed_precision: Optional[MixedPrecision] = mixed_precision # type: ignore + self._rank_0_will_call_children_scripts: bool = False @property def root_device(self) -> torch.device: return self.parallel_devices[self.local_rank] + @property + def num_nodes(self) -> int: + return self._num_nodes + + @num_nodes.setter + def num_nodes(self, num_nodes: int) -> None: + # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks + self._num_nodes = num_nodes + + @property + def num_processes(self): + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + @property def process_group(self) -> Optional[ProcessGroup]: if self._process_group is None: @@ -145,14 +160,26 @@ def mixed_precision_config(self) -> Optional[MixedPrecision]: # type: ignore if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin): return plugin.mixed_precision_config + @property + def distributed_sampler_kwargs(self) -> Dict: + return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) + def setup_environment(self) -> None: + self.setup_distributed() + super().setup_environment() + + def setup_distributed(self) -> None: + log.detail(f"{self.__class__.__name__}: setting up distributed...") reset_seed() + + # determine which process we are and world size + self.set_world_ranks() + # set warning rank rank_zero_only.rank = self.global_rank + self._process_group_backend = self._get_process_group_backend() - assert self.cluster_environment is not None init_dist_connection(self.cluster_environment, self._process_group_backend) - super().setup_environment() def _get_process_group_backend(self) -> str: return ( @@ -161,8 +188,22 @@ def _get_process_group_backend(self) -> str: or get_default_process_group_backend_for_device(self.root_device) ) + def set_world_ranks(self) -> None: + if self.cluster_environment is None: + return + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() + + def _configure_launcher(self) -> None: + if not self.cluster_environment.creates_processes_externally: + self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + self._rank_0_will_call_children_scripts = True + def setup(self, trainer: "pl.Trainer") -> None: self.accelerator.setup(trainer) + # share ddp pids to all processes + self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) if trainer.state.fn == TrainerFn.FITTING and self._layer_sync: assert self.model is not None @@ -178,7 +219,7 @@ def setup(self, trainer: "pl.Trainer") -> None: def model_to_device(self) -> None: # ensure we update the device type in the lightning module - assert self.lightning_module is not None + assert self.model is not None log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...") self.lightning_module.to(self.root_device) From b715e5ce5b085538d3052ef74e48628cca92a0a2 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sat, 7 May 2022 11:10:36 +0100 Subject: [PATCH 04/18] Fix tests, fix FSDP integration --- pytorch_lightning/core/mixins/device_dtype_mixin.py | 10 ++++++---- pytorch_lightning/strategies/fully_sharded_native.py | 2 +- tests/strategies/test_ddp_fully_sharded_native.py | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index ed9f1f7683555..06ab2b8ebe30c 100644 --- a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -188,14 +188,16 @@ def half(self) -> Self: def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: - def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: + def apply_fn(module: Union["DeviceDtypeModuleMixin", pl.LightningModule]) -> None: # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't # work when using `init_meta_context`. - if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): - return if device is not None: module._device = device if dtype is not None: module._dtype = dtype - self.apply(apply_fn) + for m in self.modules(): + if isinstance(m, (DeviceDtypeModuleMixin, pl.LightningModule)): + apply_fn(m) + + apply_fn(self) diff --git a/pytorch_lightning/strategies/fully_sharded_native.py b/pytorch_lightning/strategies/fully_sharded_native.py index 6818afb832206..79d25729e68cf 100644 --- a/pytorch_lightning/strategies/fully_sharded_native.py +++ b/pytorch_lightning/strategies/fully_sharded_native.py @@ -219,7 +219,7 @@ def setup(self, trainer: "pl.Trainer") -> None: def model_to_device(self) -> None: # ensure we update the device type in the lightning module - assert self.model is not None + assert self.lightning_module is not None log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...") self.lightning_module.to(self.root_device) diff --git a/tests/strategies/test_ddp_fully_sharded_native.py b/tests/strategies/test_ddp_fully_sharded_native.py index f68f5d39afb1e..f7d3378d7e0f8 100644 --- a/tests/strategies/test_ddp_fully_sharded_native.py +++ b/tests/strategies/test_ddp_fully_sharded_native.py @@ -161,7 +161,7 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, precision): def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): trainer.fit(model) - + model_path = trainer.strategy.broadcast(model_path) model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path trainer.save_checkpoint(model_path, weights_only=True) @@ -169,7 +169,7 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): _assert_save_equality(trainer, model_path, cls=TestFSDPModel) # Test entry point - trainer.test(model) # model is wrapped, will not call configure_shared_model + trainer.test(model) # model is wrapped, will not call `configure_sharded_model` # provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap trainer.test(ckpt_path=model_path) From ef0add9357fd667d3fcf3c362f02fa2e3e3704ec Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 21 Jun 2022 15:04:56 +0100 Subject: [PATCH 05/18] Only print on rank 0 --- src/pytorch_lightning/strategies/fully_sharded_native.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index be0caec066ea2..054b9360abbde 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -32,7 +32,7 @@ from pytorch_lightning.utilities.distributed import ( _get_process_group_backend_from_env, distributed_available, - get_default_process_group_backend_for_device, + get_default_process_group_backend_for_device, rank_zero_info, ) from pytorch_lightning.utilities.distributed import group as _group from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available @@ -276,7 +276,7 @@ def _determine_device_ids(self) -> List[int]: return [self.root_device.index] def teardown(self) -> None: - log.info(f"{self.__class__.__name__}: tearing down strategy...") + rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...") if ( self.lightning_module is not None and self.lightning_module.trainer is not None From 5f9437ccfcf6770b246bb526382ac852b1d9306d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Jun 2022 14:06:25 +0000 Subject: [PATCH 06/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/fully_sharded_native.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 054b9360abbde..40ee71466a2fc 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -32,10 +32,15 @@ from pytorch_lightning.utilities.distributed import ( _get_process_group_backend_from_env, distributed_available, - get_default_process_group_backend_for_device, rank_zero_info, + get_default_process_group_backend_for_device, ) from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available +from pytorch_lightning.utilities.distributed import ( + init_dist_connection, + rank_zero_info, + ReduceOp, + sync_ddp_if_available, +) from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.optimizer import optimizers_to_device From 503d7ef8c3061e6606aada0f5b2f0c2df39919ba Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 23 Jun 2022 12:03:10 +0100 Subject: [PATCH 07/18] Pass additional arguments to the FSDP wrapper --- src/pytorch_lightning/strategies/fully_sharded_native.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 40ee71466a2fc..e8f3d6c89b55a 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -76,6 +76,7 @@ def __init__( # type: ignore[no-untyped-def] cpu_offload=None, backward_prefetch=None, mixed_precision=None, + **kwargs ) -> None: """Strategy for Fully Sharded Data Parallel provided by torch.Distributed. @@ -111,6 +112,7 @@ def __init__( # type: ignore[no-untyped-def] Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16` or BF16 if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. + **kwargs: Passed to the FSDP Context manager which will configure the FSDP class when wrapping modules. """ if not _TORCH_GREATER_EQUAL_1_12: raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.12.0 onwards.") @@ -129,6 +131,7 @@ def __init__( # type: ignore[no-untyped-def] self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch self.mixed_precision: Optional[MixedPrecision] = mixed_precision # type: ignore self._rank_0_will_call_children_scripts: bool = False + self.kwargs = kwargs @property def root_device(self) -> torch.device: @@ -238,6 +241,7 @@ def model_sharded_context(self) -> Generator: cpu_offload=self.cpu_offload, backward_prefetch=self.backward_prefetch, mixed_precision=self.mixed_precision_config, + **self.kwargs ): yield From eaf1729a0e0a314628edd11679832820b43772de Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Jun 2022 11:04:56 +0000 Subject: [PATCH 08/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/fully_sharded_native.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index e8f3d6c89b55a..cd5113f23b0cb 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -76,7 +76,7 @@ def __init__( # type: ignore[no-untyped-def] cpu_offload=None, backward_prefetch=None, mixed_precision=None, - **kwargs + **kwargs, ) -> None: """Strategy for Fully Sharded Data Parallel provided by torch.Distributed. @@ -241,7 +241,7 @@ def model_sharded_context(self) -> Generator: cpu_offload=self.cpu_offload, backward_prefetch=self.backward_prefetch, mixed_precision=self.mixed_precision_config, - **self.kwargs + **self.kwargs, ): yield From cd898a564b7c7b362cfefa6e8f77179e8e00866a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 23 Jun 2022 15:03:34 +0100 Subject: [PATCH 09/18] Fix tests references, work around touching internals --- .../core/mixins/device_dtype_mixin.py | 10 ++++----- .../strategies/fully_sharded_native.py | 22 ++++++++----------- .../test_ddp_fully_sharded_native.py | 12 +++++----- 3 files changed, 19 insertions(+), 25 deletions(-) diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index 0c0cd10d1aaef..5f6397e4562e5 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -189,16 +189,14 @@ def half(self) -> Self: def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: - def apply_fn(module: Union["DeviceDtypeModuleMixin", pl.LightningModule]) -> None: + def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't # work when using `init_meta_context`. + if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): + return if device is not None: module._device = device if dtype is not None: module._dtype = dtype - for m in self.modules(): - if isinstance(m, (DeviceDtypeModuleMixin, pl.LightningModule)): - apply_fn(m) - - apply_fn(self) + self.apply(apply_fn) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index cd5113f23b0cb..d2e3149c12190 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -35,15 +35,11 @@ get_default_process_group_backend_for_device, ) from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import ( - init_dist_connection, - rank_zero_info, - ReduceOp, - sync_ddp_if_available, -) +from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.optimizer import optimizers_to_device +from pytorch_lightning.utilities.rank_zero import rank_zero_info from pytorch_lightning.utilities.seed import reset_seed MixedPrecision = None @@ -218,8 +214,8 @@ def setup(self, trainer: "pl.Trainer") -> None: assert self.model is not None self.model = self._layer_sync.apply(self.model) - if not self.cpu_offload: - self.model_to_device() + # we set the device so that optimizers can be created with distributed comms. + self.lightning_module._device = self.root_device self.barrier() self.setup_optimizers(trainer) @@ -227,10 +223,7 @@ def setup(self, trainer: "pl.Trainer") -> None: self.setup_precision_plugin() def model_to_device(self) -> None: - # ensure we update the device type in the lightning module - assert self.lightning_module is not None - log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...") - self.lightning_module.to(self.root_device) + pass @contextlib.contextmanager def model_sharded_context(self) -> Generator: @@ -241,6 +234,7 @@ def model_sharded_context(self) -> Generator: cpu_offload=self.cpu_offload, backward_prefetch=self.backward_prefetch, mixed_precision=self.mixed_precision_config, + device_id=self.root_device.index, **self.kwargs, ): yield @@ -295,7 +289,9 @@ def teardown(self) -> None: assert self.model is not None self.model = self._layer_sync.revert(self.model) - super().teardown() + self.cluster_environment.teardown() + self.precision_plugin.teardown() + self.accelerator.teardown() @classmethod def get_registered_strategies(cls) -> List[str]: diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index 5ac160c1c989c..34d1296f2b42d 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -3,16 +3,16 @@ import pytest import torch -from tests.helpers.boring_model import BoringModel -from tests.helpers.runif import RunIf from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.types import STEP_OUTPUT +from tests_pytorch.helpers.runif import RunIf if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision @@ -110,7 +110,7 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer.mixed_precision.buffer_dtype == precision -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.11") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev") def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run.""" @@ -120,7 +120,7 @@ def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): accelerator="gpu", devices=2, strategy="fsdp_native", - precision=precision, + precision=16, max_epochs=1, sync_batchnorm=True, ) @@ -139,7 +139,7 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision): _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.11") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev") def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir): """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run.""" @@ -150,7 +150,7 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir): accelerator="gpu", devices=2, strategy="fsdp_native", - precision=precision, + precision=16, max_epochs=1, callbacks=[ck], ) From 771b8cf288c2325df074fe4ca4d2c3c6cb29bd15 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 24 Jun 2022 11:11:14 +0100 Subject: [PATCH 10/18] Fix checks --- .../tests_pytorch/accelerators/test_accelerator_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/accelerators/test_accelerator_connector.py b/tests/tests_pytorch/accelerators/test_accelerator_connector.py index c5480fad089fc..a11678966c7d0 100644 --- a/tests/tests_pytorch/accelerators/test_accelerator_connector.py +++ b/tests/tests_pytorch/accelerators/test_accelerator_connector.py @@ -643,7 +643,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock assert trainer.strategy.local_rank == 0 -@RunIf(min_torch="1.11") +@RunIf(min_torch="1.12") def test_check_native_fsdp_strategy_and_fallback(): with pytest.raises( MisconfigurationException, @@ -656,7 +656,7 @@ def test_check_native_fsdp_strategy_and_fallback(): @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) @mock.patch("torch.cuda.device_count", return_value=1) @mock.patch("torch.cuda.is_available", return_value=True) -@RunIf(min_torch="1.11") +@RunIf(min_torch="1.12") def test_mixed_precision_support_with_native_fsdp_strategy(device_count_mock, mock_cuda_available, tmpdir): with pytest.raises( MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision" From a8828d3963fbf947b932338fb9462a425aa53ae3 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 19 Jul 2022 10:28:43 +0100 Subject: [PATCH 11/18] Fix typing issues --- src/pytorch_lightning/strategies/fully_sharded_native.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 8dca6557e794f..51c85a5ac7218 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -185,6 +185,7 @@ def setup_distributed(self) -> None: rank_zero_only.rank = self.global_rank self._process_group_backend = self._get_process_group_backend() + assert self.cluster_environment is not None init_dist_connection(self.cluster_environment, self._process_group_backend) def _get_process_group_backend(self) -> str: @@ -202,6 +203,7 @@ def set_world_ranks(self) -> None: rank_zero_only.rank = self.cluster_environment.global_rank() def _configure_launcher(self) -> None: + assert self.cluster_environment is not None if not self.cluster_environment.creates_processes_externally: self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) self._rank_0_will_call_children_scripts = True @@ -216,6 +218,7 @@ def setup(self, trainer: "pl.Trainer") -> None: self.model = self._layer_sync.apply(self.model) # we set the device so that optimizers can be created with distributed comms. + assert self.lightning_module is not None self.lightning_module._device = self.root_device self.barrier() @@ -290,6 +293,7 @@ def teardown(self) -> None: assert self.model is not None self.model = self._layer_sync.revert(self.model) + assert self.cluster_environment is not None self.cluster_environment.teardown() self.precision_plugin.teardown() self.accelerator.teardown() From 033ce1507008372c0f6cb74d1461408e9331d3df Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 19 Jul 2022 10:41:19 +0100 Subject: [PATCH 12/18] Update min torch --- .../strategies/test_ddp_fully_sharded_native.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index 34d1296f2b42d..eaed63725c4bc 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -19,7 +19,7 @@ from torch.distributed.fsdp.wrap import wrap -@RunIf(min_torch="1.12dev") +@RunIf(min_torch="1.12") def test_invalid_on_cpu(tmpdir): """Test to ensure that we raise Misconfiguration for Native FSDP on CPU.""" with pytest.raises( @@ -32,7 +32,7 @@ def test_invalid_on_cpu(tmpdir): trainer.strategy.setup_environment() -@RunIf(min_torch="1.12dev") +@RunIf(min_torch="1.12") @pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) def test_precision_plugin_config(precision, expected): plugin = FullyShardedNativeMixedPrecisionPlugin(precision=precision, device="cuda") @@ -42,7 +42,7 @@ def test_precision_plugin_config(precision, expected): assert config.reduce_dtype == expected -@RunIf(min_torch="1.12dev") +@RunIf(min_torch="1.12") def test_fsdp_custom_mixed_precision(tmpdir): """Test to ensure that passing a custom mixed precision config works.""" config = MixedPrecision() @@ -110,7 +110,7 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer.mixed_precision.buffer_dtype == precision -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run.""" @@ -127,7 +127,7 @@ def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) -@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12dev") +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize("precision", [16, "bf16"]) def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" @@ -139,7 +139,7 @@ def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision): _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir): """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run.""" From 86ec744b0bccb00629dd622962f7d1f6cbc8e1ac Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 19 Jul 2022 10:42:21 +0100 Subject: [PATCH 13/18] Add type --- src/pytorch_lightning/strategies/fully_sharded_native.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 51c85a5ac7218..f391bf201e1d5 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -144,7 +144,7 @@ def num_nodes(self, num_nodes: int) -> None: self._num_nodes = num_nodes @property - def num_processes(self): + def num_processes(self) -> int: return len(self.parallel_devices) if self.parallel_devices is not None else 0 @property From b275f032a6a6d8b55da1b36d4e296f0c703bfc67 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 19 Jul 2022 11:09:52 +0100 Subject: [PATCH 14/18] Fix tests --- .../test_accelerator_connector.py | 19 ------------------- .../test_ddp_fully_sharded_native.py | 2 +- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/tests/tests_pytorch/accelerators/test_accelerator_connector.py b/tests/tests_pytorch/accelerators/test_accelerator_connector.py index b61afaf0f91a0..f3dbf0877aa6f 100644 --- a/tests/tests_pytorch/accelerators/test_accelerator_connector.py +++ b/tests/tests_pytorch/accelerators/test_accelerator_connector.py @@ -584,25 +584,6 @@ def test_check_native_fsdp_strategy_and_fallback(): Trainer(accelerator="cpu", strategy="fsdp_native") -@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) -@mock.patch("torch.cuda.device_count", return_value=1) -@mock.patch("torch.cuda.is_available", return_value=True) -@RunIf(min_torch="1.12") -def test_mixed_precision_support_with_native_fsdp_strategy(device_count_mock, mock_cuda_available, tmpdir): - with pytest.raises( - MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision" - ): - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - strategy="fsdp_native", - accelerator="gpu", - devices=1, - precision=16, - ) - assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy) - - @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True) def test_unsupported_tpu_choice(mock_tpu_acc_avail): diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index eaed63725c4bc..1ac7ad0b6660b 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -32,7 +32,7 @@ def test_invalid_on_cpu(tmpdir): trainer.strategy.setup_environment() -@RunIf(min_torch="1.12") +@RunIf(min_torch="1.12", min_cuda_gpus=1) @pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) def test_precision_plugin_config(precision, expected): plugin = FullyShardedNativeMixedPrecisionPlugin(precision=precision, device="cuda") From 2e948427350061c243ca033c83984566d740bb17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 19 Jul 2022 14:12:01 +0200 Subject: [PATCH 15/18] Mypy cleanup --- .../precision/fully_sharded_native_amp.py | 7 ++-- .../strategies/fully_sharded_native.py | 34 ++++++++++--------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 337636f7fa688..7621c2c00cca9 100644 --- a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -20,9 +20,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 -MixedPrecision = None if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision # type: ignore[no-redef] + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision +else: + MixedPrecision = None class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): @@ -39,7 +40,7 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: ) @property - def mixed_precision_config(self) -> Optional[MixedPrecision]: # type: ignore + def mixed_precision_config(self) -> Optional[MixedPrecision]: assert MixedPrecision is not None if self.precision == PrecisionType.HALF: dtype = torch.float16 diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index f391bf201e1d5..b1545aceef1fe 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -42,16 +42,18 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_info from pytorch_lightning.utilities.seed import reset_seed -MixedPrecision = None if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision # type: ignore[no-redef] from torch.distributed.fsdp.fully_sharded_data_parallel import ( BackwardPrefetch, CPUOffload, FullyShardedDataParallel, + MixedPrecision, ) from torch.distributed.fsdp.wrap import enable_wrap - +else: + MixedPrecision = None + BackwardPrefetch = None # type: ignore[misc,assignment] + CPUOffload = None # type: ignore[misc,assignment] log = logging.getLogger(__name__) @@ -61,7 +63,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): strategy_name = "fsdp_native" _registered_strategies: List[str] = [] - def __init__( # type: ignore[no-untyped-def] + def __init__( self, accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None, parallel_devices: Optional[List[torch.device]] = None, @@ -69,10 +71,10 @@ def __init__( # type: ignore[no-untyped-def] checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, process_group_backend: Optional[str] = None, - cpu_offload=None, - backward_prefetch=None, - mixed_precision=None, - **kwargs, + cpu_offload: Optional[CPUOffload] = None, + backward_prefetch: Optional[BackwardPrefetch] = None, + mixed_precision: Optional[MixedPrecision] = None, + **kwargs: Any, ) -> None: """Strategy for Fully Sharded Data Parallel provided by torch.Distributed. @@ -91,7 +93,7 @@ def __init__( # type: ignore[no-untyped-def] `https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html` Arguments: - cpu_offload (Optional [CPUOffload]): + cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this @@ -99,12 +101,12 @@ def __init__( # type: ignore[no-untyped-def] params and grads to be on same device to work with optimizer. This API is subject to change. Default is ``None`` in which case there will be no offloading. - backward_prefetch: (Optional[BackwardPrefetch]): + backward_prefetch: This is an experimental feature that is subject to change in the the near future. It allows users to enable two different backward_prefetch algorithms to help backward communication and computation overlapping. Pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. - mixed_precision: (Optional[MixedPrecision]): + mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16` or BF16 if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. @@ -122,10 +124,10 @@ def __init__( # type: ignore[no-untyped-def] ) self._process_group = None self._num_nodes = 1 - self._process_group_backend: Optional[str] = process_group_backend - self.cpu_offload: Optional[CPUOffload] = cpu_offload - self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch - self.mixed_precision: Optional[MixedPrecision] = mixed_precision # type: ignore + self._process_group_backend = process_group_backend + self.cpu_offload = cpu_offload + self.backward_prefetch = backward_prefetch + self.mixed_precision = mixed_precision self._rank_0_will_call_children_scripts: bool = False self.kwargs = kwargs @@ -159,7 +161,7 @@ def process_group_backend(self) -> Optional[str]: return self._process_group_backend @property - def mixed_precision_config(self) -> Optional[MixedPrecision]: # type: ignore + def mixed_precision_config(self) -> Optional[MixedPrecision]: if self.mixed_precision: return self.mixed_precision plugin = self.precision_plugin From b9f2ba1afe0a37c88448d59c3a10089c542d3b15 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 19 Jul 2022 13:51:08 +0100 Subject: [PATCH 16/18] Address reviews --- .../strategies/fully_sharded_native.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index b1545aceef1fe..1f0f4e0711a71 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -123,7 +123,7 @@ def __init__( precision_plugin=precision_plugin, ) self._process_group = None - self._num_nodes = 1 + self.num_nodes = 1 self._process_group_backend = process_group_backend self.cpu_offload = cpu_offload self.backward_prefetch = backward_prefetch @@ -136,15 +136,6 @@ def root_device(self) -> torch.device: assert self.parallel_devices is not None return self.parallel_devices[self.local_rank] - @property - def num_nodes(self) -> int: - return self._num_nodes - - @num_nodes.setter - def num_nodes(self, num_nodes: int) -> None: - # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks - self._num_nodes = num_nodes - @property def num_processes(self) -> int: return len(self.parallel_devices) if self.parallel_devices is not None else 0 @@ -173,10 +164,6 @@ def distributed_sampler_kwargs(self) -> Dict: return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) def setup_environment(self) -> None: - self.setup_distributed() - super().setup_environment() - - def setup_distributed(self) -> None: log.detail(f"{self.__class__.__name__}: setting up distributed...") reset_seed() @@ -189,6 +176,7 @@ def setup_distributed(self) -> None: self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None init_dist_connection(self.cluster_environment, self._process_group_backend) + super().setup_environment() def _get_process_group_backend(self) -> str: return ( From bb1c7ce8f63922dfe92f08a2d4ce6f078f55cfba Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Tue, 19 Jul 2022 14:28:37 +0100 Subject: [PATCH 17/18] Apply suggestions from code review Co-authored-by: Rohit Gupta --- .../plugins/precision/fully_sharded_native_amp.py | 2 +- src/pytorch_lightning/strategies/fully_sharded_native.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 7621c2c00cca9..8c693f2975bbd 100644 --- a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -47,7 +47,7 @@ def mixed_precision_config(self) -> Optional[MixedPrecision]: elif self.precision == PrecisionType.BFLOAT: dtype = torch.bfloat16 else: - raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision}.") + raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.") return MixedPrecision( param_dtype=dtype, reduce_dtype=dtype, diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 1f0f4e0711a71..251ea1202a035 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -110,10 +110,10 @@ def __init__( Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16` or BF16 if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. - **kwargs: Passed to the FSDP Context manager which will configure the FSDP class when wrapping modules. + \**kwargs: Passed to the FSDP Context manager which will configure the FSDP class when wrapping modules. """ if not _TORCH_GREATER_EQUAL_1_12: - raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.12.0 onwards.") + raise MisconfigurationException("`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards.") super().__init__( accelerator=accelerator, From 7a3533fbe4a18008804672c38e1dd235a9652ceb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Jul 2022 13:31:19 +0000 Subject: [PATCH 18/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/fully_sharded_native.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 251ea1202a035..7528d5b95903e 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -76,7 +76,7 @@ def __init__( mixed_precision: Optional[MixedPrecision] = None, **kwargs: Any, ) -> None: - """Strategy for Fully Sharded Data Parallel provided by torch.Distributed. + r"""Strategy for Fully Sharded Data Parallel provided by torch.Distributed. Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size, whilst using efficient communication to reduce overhead. In practice, this means we can remain @@ -113,7 +113,9 @@ def __init__( \**kwargs: Passed to the FSDP Context manager which will configure the FSDP class when wrapping modules. """ if not _TORCH_GREATER_EQUAL_1_12: - raise MisconfigurationException("`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards.") + raise MisconfigurationException( + "`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards." + ) super().__init__( accelerator=accelerator,