diff --git a/CHANGELOG.md b/CHANGELOG.md index 906603f55e3db..9351dda511550 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -265,6 +265,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350)) +- Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350)) + + +- Fixed Native FSDP calling distributed communications early when setting device within the `LightningModule` ([#13387](https://github.com/PyTorchLightning/pytorch-lightning/pull/13387)) + + - diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index 5f6397e4562e5..24aed2549dc7b 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -189,14 +189,16 @@ def half(self) -> Self: def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: - def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: - # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't - # work when using `init_meta_context`. - if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): - return + def apply_fn(module: Union["DeviceDtypeModuleMixin", pl.LightningModule]) -> None: if device is not None: module._device = device if dtype is not None: module._dtype = dtype - self.apply(apply_fn) + # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't + # work when using `init_meta_context`. + for m in self.modules(): + if isinstance(m, (DeviceDtypeModuleMixin, pl.LightningModule)): + apply_fn(m) + + apply_fn(self)