Commit 817ffe3

awaelchli authored and carmocca committed
Enable self.device access in setup hook (#18021)
1 parent a6ce061 commit 817ffe3

File tree: 3 files changed, +20 -3 lines

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Dropped support for `wandb` versions older than 0.12.0 in `WandbLogger` ([#17876](https://github.com/Lightning-AI/lightning/pull/17876))
 
 
+- During `LightningModule.setup()`, the `self.device` now returns the device the module will be placed on instead of `cpu` ([#18021](https://github.com/Lightning-AI/lightning/pull/18021))
+
+
 ### Deprecated
 
 - Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))
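
The changelog entry above describes a user-facing behavior change. A minimal sketch of what it enables (the class name, layer sizes, and Trainer flags below are illustrative, not taken from the commit): a `LightningModule` can now read `self.device` inside `setup` to learn which device it will be placed on, while its parameters are still on the CPU.

import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset


class SetupAwareModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def setup(self, stage: str) -> None:
        # With #18021, `self.device` already reports the accelerator's root device here
        # (e.g. cuda:0), even though the parameters have not been moved yet.
        print(f"setup({stage}): self.device={self.device}, weights on {self.layer.weight.device}")
        # Layers created in this hook are built on the CPU and moved by the Trainer afterwards.
        self.head = torch.nn.Linear(2, 2)

    def training_step(self, batch, batch_idx):
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self.head(self.layer(x)).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# A quick run that triggers `setup("fit")` before the module is moved to the device:
data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer = pl.Trainer(accelerator="auto", devices=1, fast_dev_run=True)
trainer.fit(SetupAwareModel(), data)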

src/lightning/pytorch/trainer/call.py

Lines changed: 7 additions & 0 deletions

@@ -18,6 +18,7 @@
 from packaging.version import Version
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
 from lightning.pytorch.trainer.states import TrainerStatus
 from lightning.pytorch.utilities.exceptions import _TunerExitException
@@ -72,6 +73,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
     assert trainer.state.fn is not None
     fn = trainer.state.fn
 
+    # It is too early to move the model to the device, but we fake the `LightningModule.device` property
+    # so the user can access it in the `LightningModule.setup` hook
+    for module in trainer.lightning_module.modules():
+        if isinstance(module, _DeviceDtypeModuleMixin):
+            module._device = trainer.strategy.root_device
+
     # Trigger lazy creation of experiment in loggers so loggers have their metadata available
     for logger in trainer.loggers:
         _ = logger.experiment
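
The hunk above works because `LightningModule` inherits from `_DeviceDtypeModuleMixin`, whose `device` property reports the value of the private `_device` attribute; assigning `_device` therefore changes what `device` returns without moving any parameters. A minimal sketch of that mechanism outside the Trainer, under that assumption (the `TinyModule` class is illustrative, not part of the commit):

import torch
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin


class TinyModule(_DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)


module = TinyModule()
print(module.device)               # cpu
# This is what `_call_setup_hook` does for every mixin submodule before `setup` runs:
module._device = torch.device("cuda", 0)
print(module.device)               # cuda:0 is reported ...
print(module.layer.weight.device)  # ... but the weights are still on the cpu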

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 10 additions & 3 deletions

@@ -1478,16 +1478,23 @@ def configure_optimizers(self):
 @pytest.mark.parametrize(
     "accelerator",
     [
-        pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("cuda", marks=RunIf(min_cuda_gpus=1)),
         pytest.param("mps", marks=RunIf(mps=True)),
     ],
 )
-def test_setup_hook_move_to_device_correctly(tmpdir, accelerator):
-    """Verify that if a user defines a layer in the setup hook function, this is moved to the correct device."""
+def test_setup_hook_device_and_layers(tmpdir, accelerator):
+    """Test `LightningModule.device` access and creation of layers in `LightningModule.setup` hook."""
+    expected_device = torch.device(accelerator, 0)
 
     class TestModel(BoringModel):
         def setup(self, stage: str) -> None:
+            # The `self.device` attribute already points to what device the model will land on
+            assert self.device == expected_device
+            # However, the model parameters have not yet been moved to that device
+            assert self.layer.weight.device == torch.device("cpu")
+            # Can create new layers in this hook (on CPU)
             self.new_layer = torch.nn.Linear(2, 2)
+            assert self.new_layer.weight.device == torch.device("cpu")
 
         def training_step(self, batch, batch_idx):
             output = self.layer(batch)
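
A side note on the parametrization change in this hunk: the test now builds `expected_device = torch.device(accelerator, 0)` directly from the parameter, and `"gpu"` is not a valid `torch.device` type (Lightning accepts `accelerator="gpu"`, but `torch.device` does not), which is presumably why the values were switched to `"cuda"` and `"mps"`. A small check of that constraint, assuming a reasonably recent PyTorch:

import torch

print(torch.device("cuda", 0))  # device(type='cuda', index=0) -- constructing the device object needs no GPU
print(torch.device("mps", 0))   # device(type='mps', index=0)
try:
    torch.device("gpu", 0)
except RuntimeError as err:
    print(err)                  # "gpu" is not a recognized device type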
