
Commit c76c381

Fix device placement when .cuda() called without specifying index (#14128)
1 parent afe40c0 commit c76c381

File tree (3 files changed: +32 −5 lines)

src/pytorch_lightning/CHANGELOG.md
src/pytorch_lightning/core/mixins/device_dtype_mixin.py
tests/tests_pytorch/utilities/test_dtype_device_mixin.py

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug that caused `ddp_find_unused_parameters` to be set `False`, whereas the intended default is `True` ([#14095](https://github.com/Lightning-AI/lightning/pull/14095))
 
 
+- Fixed the device placement when `LightningModule.cuda()` gets called without specifying a device index and the current cuda device was not 0 ([#14128](https://github.com/Lightning-AI/lightning/pull/14128))
+
+
 ## [1.7.0] - 2022-08-02
 
 ### Added

src/pytorch_lightning/core/mixins/device_dtype_mixin.py

Lines changed: 6 additions & 4 deletions
@@ -118,14 +118,16 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # ty
         while being optimized.
 
         Arguments:
-            device: if specified, all parameters will be
-                copied to that device
+            device: If specified, all parameters will be copied to that device. If `None`, the current CUDA device
+                index will be used.
 
         Returns:
             Module: self
         """
-        if device is None or isinstance(device, int):
-            device = torch.device("cuda", index=(device or 0))
+        if device is None:
+            device = torch.device("cuda", torch.cuda.current_device())
+        elif isinstance(device, int):
+            device = torch.device("cuda", index=device)
         self.__update_properties(device=device)
         return super().cuda(device=device)
 
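To see the behavioral change in isolation, here is a minimal, self-contained sketch of the resolution logic before and after this commit (the `resolve_*` helper names are hypothetical, introduced only for illustration; the real logic lives in `DeviceDtypeModuleMixin.cuda`):

import torch


def resolve_device_old(device=None):
    # Pre-fix: `device or 0` maps None to index 0, so a previously
    # set non-zero current CUDA device was silently ignored.
    if device is None or isinstance(device, int):
        device = torch.device("cuda", index=(device or 0))
    return device


def resolve_device_fixed(device=None):
    # Post-fix: fall back to the *current* CUDA device, not index 0.
    if device is None:
        device = torch.device("cuda", torch.cuda.current_device())
    elif isinstance(device, int):
        device = torch.device("cuda", index=device)
    return device


if torch.cuda.device_count() >= 2:
    torch.cuda.set_device(1)
    assert resolve_device_old().index == 0     # bug: lands on cuda:0
    assert resolve_device_fixed().index == 1   # respects current device

The difference is only observable when the current CUDA device has been set to something other than 0, which is why the new test below requires two GPUs.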

tests/tests_pytorch/utilities/test_dtype_device_mixin.py

Lines changed: 23 additions & 1 deletion
@@ -113,7 +113,7 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir):
     ],
 )
 @RunIf(min_cuda_gpus=1)
-def test_gpu_cuda_device(device):
+def test_cuda_device(device):
     model = TopModule()
 
     model.cuda(device)
@@ -122,3 +122,25 @@ def test_gpu_cuda_device(device):
     assert device.type == "cuda"
     assert device.index is not None
     assert device.index == torch.cuda.current_device()
+
+
+@RunIf(min_cuda_gpus=2)
+def test_cuda_current_device():
+    """Test that calling .cuda() moves the model to the correct device and respects current cuda device setting."""
+
+    class CudaModule(DeviceDtypeModuleMixin):
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(1, 1)
+
+    model = CudaModule()
+
+    torch.cuda.set_device(0)
+    model.cuda(1)
+    assert model.device == torch.device("cuda", 1)
+    assert model.layer.weight.device == torch.device("cuda", 1)
+
+    torch.cuda.set_device(1)
+    model.cuda()  # model is already on device 1, and calling .cuda() without device index should not move model
+    assert model.device == torch.device("cuda", 1)
+    assert model.layer.weight.device == torch.device("cuda", 1)
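As a user-facing illustration of the scenario this test covers, a rough sketch using the public `LightningModule` API (the module definition is hypothetical, and at least two CUDA devices are assumed):

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule


class DemoModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 1)


torch.cuda.set_device(1)
model = DemoModel().cuda()  # no device index given
# Before this commit the parameters landed on cuda:0; after it,
# both the module and its weights report cuda:1.
print(model.device, model.layer.weight.device)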
