
Commit d786985

Author: Sean Naren
[FIX] Native FSDP precision + tests (#12985)
1 parent c028ff3 commit d786985

5 files changed: +155 -86 lines changed


src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py

Lines changed: 25 additions & 1 deletion
@@ -11,10 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional
+
+import torch
 
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+
+if _TORCH_GREATER_EQUAL_1_12:
+    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+else:
+    MixedPrecision = None
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
@@ -29,3 +38,18 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
         raise MisconfigurationException(
             f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
         )
+
+    @property
+    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+        assert MixedPrecision is not None
+        if self.precision == PrecisionType.HALF:
+            dtype = torch.float16
+        elif self.precision == PrecisionType.BFLOAT:
+            dtype = torch.bfloat16
+        else:
+            raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
+        return MixedPrecision(
+            param_dtype=dtype,
+            reduce_dtype=dtype,
+            buffer_dtype=dtype,
+        )
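
The new mixed_precision_config property maps Lightning's precision flag (16 or bf16) onto a single dtype that FSDP then uses for parameters, gradient reduction, and buffers. A rough standalone sketch of that mapping, assuming PyTorch 1.12+ (the infer_mixed_precision helper and its string-keyed lookup are illustrative only, not part of the commit):

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision


def infer_mixed_precision(precision: str) -> MixedPrecision:
    # One dtype drives parameter storage, gradient reduction, and buffers,
    # mirroring the plugin property in the diff above.
    dtype = {"16": torch.float16, "bf16": torch.bfloat16}.get(str(precision))
    if dtype is None:
        raise ValueError(f"Was unable to infer precision type, received {precision!r}.")
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)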

src/pytorch_lightning/strategies/fully_sharded_native.py

Lines changed: 84 additions & 25 deletions
@@ -23,6 +23,8 @@
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
+from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
@@ -35,18 +37,23 @@
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
+from pytorch_lightning.utilities.rank_zero import rank_zero_info
 from pytorch_lightning.utilities.seed import reset_seed
 
-if _TORCH_GREATER_EQUAL_1_11:
+if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         BackwardPrefetch,
         CPUOffload,
         FullyShardedDataParallel,
+        MixedPrecision,
     )
     from torch.distributed.fsdp.wrap import enable_wrap
-
+else:
+    MixedPrecision = None
+    BackwardPrefetch = None  # type: ignore[misc,assignment]
+    CPUOffload = None  # type: ignore[misc,assignment]
 
 log = logging.getLogger(__name__)
 
@@ -56,18 +63,20 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
     strategy_name = "fsdp_native"
     _registered_strategies: List[str] = []
 
-    def __init__(  # type: ignore[no-untyped-def]
+    def __init__(
         self,
         accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         process_group_backend: Optional[str] = None,
-        cpu_offload=None,
-        backward_prefetch=None,
+        cpu_offload: Optional[CPUOffload] = None,
+        backward_prefetch: Optional[BackwardPrefetch] = None,
+        mixed_precision: Optional[MixedPrecision] = None,
+        **kwargs: Any,
     ) -> None:
-        """Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
+        r"""Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
 
         Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
         size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
@@ -84,22 +93,29 @@ def __init__(  # type: ignore[no-untyped-def]
         `https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html`
 
         Arguments:
-            cpu_offload (Optional [CPUOffload]):
+            cpu_offload:
                 CPU offloading config. Currently, only parameter and gradient CPU
                 offload is supported. It can be enabled via passing in
                 ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
                 currently implicitly enables gradient offloading to CPU in order for
                 params and grads to be on same device to work with optimizer. This
                 API is subject to change. Default is ``None`` in which case there
                 will be no offloading.
-            backward_prefetch: (Optional[BackwardPrefetch]):
+            backward_prefetch:
                 This is an experimental feature that is subject to change in the
                 the near future. It allows users to enable two different backward_prefetch
                 algorithms to help backward communication and computation overlapping.
                 Pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
+            mixed_precision:
+                Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
+                or BF16 if ``precision=bf16`` unless a config is passed in.
+                This is only available in PyTorch 1.12 and later.
+            \**kwargs: Passed to the FSDP Context manager which will configure the FSDP class when wrapping modules.
         """
-        if not _TORCH_GREATER_EQUAL_1_11:
-            raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.11.0 onwards.")
+        if not _TORCH_GREATER_EQUAL_1_12:
+            raise MisconfigurationException(
+                "`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards."
+            )
 
         super().__init__(
             accelerator=accelerator,
@@ -109,16 +125,23 @@ def __init__(  # type: ignore[no-untyped-def]
             precision_plugin=precision_plugin,
         )
         self._process_group = None
-        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
-        self._process_group_backend: Optional[str] = process_group_backend
-        self.cpu_offload: Optional[CPUOffload] = cpu_offload
-        self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch
+        self.num_nodes = 1
+        self._process_group_backend = process_group_backend
+        self.cpu_offload = cpu_offload
+        self.backward_prefetch = backward_prefetch
+        self.mixed_precision = mixed_precision
+        self._rank_0_will_call_children_scripts: bool = False
+        self.kwargs = kwargs
 
     @property
     def root_device(self) -> torch.device:
         assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
+    @property
+    def num_processes(self) -> int:
+        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
     @property
     def process_group(self) -> Optional[ProcessGroup]:
         if self._process_group is None:
@@ -130,10 +153,28 @@ def process_group(self) -> Optional[ProcessGroup]:
     def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend
 
+    @property
+    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+        if self.mixed_precision:
+            return self.mixed_precision
+        plugin = self.precision_plugin
+        if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin):
+            return plugin.mixed_precision_config
+
+    @property
+    def distributed_sampler_kwargs(self) -> Dict:
+        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
+
     def setup_environment(self) -> None:
+        log.detail(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
+
+        # determine which process we are and world size
+        self.set_world_ranks()
+
         # set warning rank
         rank_zero_only.rank = self.global_rank
+
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         init_dist_connection(self.cluster_environment, self._process_group_backend)
@@ -146,36 +187,51 @@ def _get_process_group_backend(self) -> str:
             or get_default_process_group_backend_for_device(self.root_device)
         )
 
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
+    def _configure_launcher(self) -> None:
+        assert self.cluster_environment is not None
+        if not self.cluster_environment.creates_processes_externally:
+            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
+            self._rank_0_will_call_children_scripts = True
+
     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
+        # share ddp pids to all processes
+        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
 
         if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
             assert self.model is not None
             self.model = self._layer_sync.apply(self.model)
 
-        if not self.cpu_offload:
-            self.model_to_device()
+        # we set the device so that optimizers can be created with distributed comms.
+        assert self.lightning_module is not None
+        self.lightning_module._device = self.root_device
 
         self.barrier()
         self.setup_optimizers(trainer)
         optimizers_to_device(self.optimizers, self.root_device)
         self.setup_precision_plugin()
 
     def model_to_device(self) -> None:
-        # ensure we update the device type in the lightning module
-        assert self.lightning_module is not None
-        log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
-        self.lightning_module.to(self.root_device)
+        pass
 
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
         log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
-
         with enable_wrap(
             wrapper_cls=FullyShardedDataParallel,
             process_group=self.process_group,
             cpu_offload=self.cpu_offload,
             backward_prefetch=self.backward_prefetch,
+            mixed_precision=self.mixed_precision_config,
+            device_id=self.root_device.index,
+            **self.kwargs,
         ):
             yield
 
@@ -219,7 +275,7 @@ def _determine_device_ids(self) -> List[int]:
         return [self.root_device.index]
 
     def teardown(self) -> None:
-        log.info(f"{self.__class__.__name__}: tearing down strategy...")
+        rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...")
         if (
             self.lightning_module is not None
             and self.lightning_module.trainer is not None
@@ -229,15 +285,18 @@ def teardown(self) -> None:
             assert self.model is not None
             self.model = self._layer_sync.revert(self.model)
 
-        super().teardown()
+        assert self.cluster_environment is not None
+        self.cluster_environment.teardown()
+        self.precision_plugin.teardown()
+        self.accelerator.teardown()
 
     @classmethod
     def get_registered_strategies(cls) -> List[str]:
         return cls._registered_strategies
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
-        if _TORCH_GREATER_EQUAL_1_11:
+        if _TORCH_GREATER_EQUAL_1_12:
             strategy_registry.register(
                 "fsdp_native",
                 cls,
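
Since the strategy now accepts a mixed_precision argument and forwards any extra keyword arguments to enable_wrap, an explicit FSDP precision config can be supplied instead of relying on the Trainer's precision flag. A usage sketch, assuming at least two CUDA devices, PyTorch 1.12+, and that DDPFullyShardedNativeStrategy is importable from pytorch_lightning.strategies; the bf16-params/fp32-reduce combination is only an example:

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

# Keep parameters and buffers in bf16 but reduce gradients in fp32;
# an explicit config takes priority over whatever the precision plugin would infer.
strategy = DDPFullyShardedNativeStrategy(
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    ),
)
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=strategy)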

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 1 addition & 5 deletions
@@ -700,17 +700,13 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
                 if self._precision_flag == 16
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
-            if isinstance(self.strategy, DDPFullyShardedNativeStrategy):
-                raise MisconfigurationException(
-                    "DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
-                )
 
             if self._amp_type_flag == AMPType.NATIVE:
                 device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
 
                 if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                     return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
-                if isinstance(self.strategy, DDPFullyShardedStrategy):
+                if isinstance(self.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
                     return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
                 return NativeMixedPrecisionPlugin(self._precision_flag, device)
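
With the connector guard removed, precision=16 or precision="bf16" combined with strategy="fsdp_native" now resolves to FullyShardedNativeMixedPrecisionPlugin instead of raising. A minimal sketch of the resulting selection, assuming a CUDA-capable machine and PyTorch 1.12+:

import torch
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin

# 16-bit native AMP with the native FSDP strategy now selects the fully
# sharded precision plugin rather than raising MisconfigurationException.
trainer = pl.Trainer(accelerator="gpu", devices=1, strategy="fsdp_native", precision=16)
assert isinstance(trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)
# The strategy derives its FSDP MixedPrecision config from that plugin.
assert trainer.strategy.mixed_precision_config.param_dtype == torch.float16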

tests/tests_pytorch/accelerators/test_accelerator_connector.py

Lines changed: 1 addition & 20 deletions
@@ -574,7 +574,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
     assert trainer.strategy.local_rank == 0
 
 
-@RunIf(min_torch="1.11")
+@RunIf(min_torch="1.12")
 def test_check_native_fsdp_strategy_and_fallback():
     with pytest.raises(
         MisconfigurationException,
@@ -584,25 +584,6 @@ def test_check_native_fsdp_strategy_and_fallback():
         Trainer(accelerator="cpu", strategy="fsdp_native")
 
 
-@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
-@mock.patch("torch.cuda.device_count", return_value=1)
-@mock.patch("torch.cuda.is_available", return_value=True)
-@RunIf(min_torch="1.11")
-def test_mixed_precision_support_with_native_fsdp_strategy(device_count_mock, mock_cuda_available, tmpdir):
-    with pytest.raises(
-        MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
-    ):
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            fast_dev_run=True,
-            strategy="fsdp_native",
-            accelerator="gpu",
-            devices=1,
-            precision=16,
-        )
-        assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
-
-
 @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
 def test_unsupported_tpu_choice(mock_tpu_acc_avail):
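
The deleted test only asserted the old "doesn't support Mixed Precision" error. A replacement check in the same spirit (hypothetical; the commit's actual new tests presumably live in the fifth changed file, which is not shown in this view) could look roughly like:

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin


@pytest.mark.parametrize("precision", [16, "bf16"])
def test_native_fsdp_precision_plugin_selection(precision, tmpdir):
    # Hypothetical test name and body; requires a CUDA device and PyTorch 1.12+
    # (guard with RunIf in a real suite).
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=1,
        strategy="fsdp_native",
        precision=precision,
    )
    assert isinstance(trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)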
