6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -265,6 +265,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))


- Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))


- Fixed Native FSDP calling distributed communications early when setting device within the `LightningModule` ([#13387](https://github.com/PyTorchLightning/pytorch-lightning/pull/13387))


-


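For context on the FSDP entry above: native FSDP (`torch.distributed.fsdp.FullyShardedDataParallel`) overrides `Module.apply` to gather the full, sharded parameters before running the callback, which involves collective communication, so routing the device/dtype bookkeeping through `self.apply(...)` could touch the process group before it was ready. The sketch below is a toy illustration of that difference, not Lightning or FSDP code; `FSDPLikeWrapper` is a hypothetical stand-in that raises where real FSDP would communicate.

    from typing import Callable

    import torch
    from torch import nn


    class FSDPLikeWrapper(nn.Module):
        """Toy wrapper that, like FSDP, overrides ``apply`` to do extra work first.

        Real FSDP gathers the full (sharded) parameters before running ``fn``,
        which needs collective communication across ranks; here we just raise
        to make the call path visible.
        """

        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            self.module = module

        def apply(self, fn: Callable[[nn.Module], None]) -> "FSDPLikeWrapper":
            raise RuntimeError("would gather full params here (collective comms)")


    wrapped = FSDPLikeWrapper(nn.Linear(4, 4))

    # Old pattern: self.apply(fn) dispatches through the wrapper's override.
    try:
        wrapped.apply(lambda m: setattr(m, "_device", torch.device("cpu")))
    except RuntimeError as err:
        print("apply() path:", err)

    # New pattern: iterating modules() sets the attribute on each submodule
    # directly and never calls the overridden apply(), so nothing is gathered.
    for m in wrapped.modules():
        m._device = torch.device("cpu")
    print("modules() path tagged", sum(1 for _ in wrapped.modules()), "modules")

The device_dtype_mixin.py diff below makes the same switch: `self.apply(apply_fn)` is replaced by a loop over `self.modules()`.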
14 changes: 8 additions & 6 deletions src/pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -189,14 +189,16 @@ def half(self) -> Self:
     def __update_properties(
         self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
     ) -> None:
-        def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
-            # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
-            # work when using `init_meta_context`.
-            if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
-                return
+        def apply_fn(module: Union["DeviceDtypeModuleMixin", pl.LightningModule]) -> None:
             if device is not None:
                 module._device = device
             if dtype is not None:
                 module._dtype = dtype

-        self.apply(apply_fn)
+        # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
+        # work when using `init_meta_context`.
+        for m in self.modules():
+            if isinstance(m, (DeviceDtypeModuleMixin, pl.LightningModule)):

Member commented:
Suggested change:
-            if isinstance(m, (DeviceDtypeModuleMixin, pl.LightningModule)):
+            if isinstance(m, DeviceDtypeModuleMixin):

should be sufficient, I guess, as `LightningModule` inherits from `DeviceDtypeModuleMixin`:

https://github.com/Lightning-AI/lightning/blob/511f1a651506af2e12e15346de0a3715fed7e814/src/pytorch_lightning/core/module.py#L60
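(Aside, not part of the diff: the premise is easy to check at runtime with a hypothetical snippet like the one below; if it prints True, the `pl.LightningModule` entry in the isinstance tuple is indeed redundant.)

    import pytorch_lightning as pl
    from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin

    # True would confirm the reviewer's point that LightningModule instances are
    # already covered by the DeviceDtypeModuleMixin check.
    print(issubclass(pl.LightningModule, DeviceDtypeModuleMixin))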

+                apply_fn(m)
+
+        apply_fn(self)

Member commented:
If we immediately apply this on `self` only, there is no reason to have this defined as a local function :)
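Taken together, the two comments point toward a version without the local helper and with a single isinstance check. Below is a rough sketch of that shape, using a minimal stand-in class so it runs on its own; this is a hypothetical consolidation of the review feedback, not the code that was merged.

    from typing import Optional, Union

    import torch
    from torch import nn


    class DeviceDtypeModuleMixin(nn.Module):  # stand-in, not the real mixin
        _device: torch.device = torch.device("cpu")
        _dtype: Union[str, torch.dtype] = torch.float32

        def _update_properties(  # double underscore (name-mangled) in the real code
            self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
        ) -> None:
            # One isinstance check (LightningModule already subclasses the mixin)
            # and no local apply_fn: update the bookkeeping attributes inline.
            for m in self.modules():
                if isinstance(m, DeviceDtypeModuleMixin):
                    if device is not None:
                        m._device = device
                    if dtype is not None:
                        m._dtype = dtype


    module = DeviceDtypeModuleMixin()
    module._update_properties(dtype=torch.float16)
    print(module._dtype)  # torch.float16

Iterating `self.modules()` keeps the behaviour of the diff above (every nested mixin instance gets updated) while addressing both review points.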