Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug that caused `ddp_find_unused_parameters` to be set `False`, whereas the intended default is `True` ([#14095](https://github.com/Lightning-AI/lightning/pull/14095))


- Fixed the device placement when `LightningModule.cuda()` gets called without specifying a device index and the current cuda device was not 0 ([#14128](https://github.com/Lightning-AI/lightning/pull/14128))


## [1.7.0] - 2022-08-02

### Added
Expand Down
10 changes: 6 additions & 4 deletions src/pytorch_lightning/core/mixins/device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,16 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # ty
while being optimized.

Arguments:
device: if specified, all parameters will be
copied to that device
device: If specified, all parameters will be copied to that device. If `None`, the current CUDA device
index will be used.

Returns:
Module: self
"""
if device is None or isinstance(device, int):
device = torch.device("cuda", index=(device or 0))
if device is None:
device = torch.device("cuda", torch.cuda.current_device())
elif isinstance(device, int):
device = torch.device("cuda", index=device)
self.__update_properties(device=device)
return super().cuda(device=device)

Expand Down
24 changes: 23 additions & 1 deletion tests/tests_pytorch/utilities/test_dtype_device_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir):
],
)
@RunIf(min_cuda_gpus=1)
def test_gpu_cuda_device(device):
def test_cuda_device(device):
model = TopModule()

model.cuda(device)
Expand All @@ -122,3 +122,25 @@ def test_gpu_cuda_device(device):
assert device.type == "cuda"
assert device.index is not None
assert device.index == torch.cuda.current_device()


@RunIf(min_cuda_gpus=2)
def test_cuda_current_device():
    """Check that ``.cuda()`` honors an explicit index and, when called without one, keeps the model on the
    currently selected CUDA device instead of moving it back to device 0."""

    class _CudaModule(DeviceDtypeModuleMixin):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(1, 1)

    module = _CudaModule()
    expected = torch.device("cuda", 1)

    # Explicit index wins regardless of the globally selected device.
    torch.cuda.set_device(0)
    module.cuda(1)
    assert module.device == expected
    assert module.layer.weight.device == expected

    # Module already lives on device 1; a bare .cuda() must leave it there.
    torch.cuda.set_device(1)
    module.cuda()
    assert module.device == expected
    assert module.layer.weight.device == expected