Commit 50ecc4a

Fix root GPU property (#5908)
* Move root GPU to a property; remove the Horovod assignment, as this is handled in the Horovod plugin; ensure we mock correctly so the GPU accelerator is selected
* Add missing tests back
1 parent 8f3947b commit 50ecc4a
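For context, the heart of the change is turning the mutable root_gpu attribute into a read-only property that derives its value from the accelerator's root_device, so the index can never drift out of sync with the device actually in use. A minimal sketch of the pattern, using simplified stand-ins rather than the real Lightning classes:

import torch


class Accelerator:
    """Simplified stand-in: owns the device the model runs on."""

    def __init__(self, root_device: torch.device):
        self.root_device = root_device


class Connector:
    """Simplified stand-in for the accelerator connector."""

    def __init__(self, accelerator: Accelerator):
        self.accelerator = accelerator

    @property
    def root_gpu(self) -> int:
        # Computed on demand from the accelerator, not cached at __init__ time.
        return self.accelerator.root_device.index


connector = Connector(Accelerator(torch.device("cuda", 1)))
print(connector.root_gpu)  # 1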

2 files changed: +10 -3 lines changed
pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 6 additions & 3 deletions
@@ -113,7 +113,6 @@ def __init__(
             self.gpus = pick_multiple_gpus(gpus)

         self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
-        self.root_gpu = device_parser.determine_root_gpu_device(self.parallel_device_ids)

         self.set_distributed_mode()
         self.configure_slurm_ddp()
@@ -276,6 +275,10 @@ def parallel_devices(self):
             devices = [torch.device("cpu")] * self.num_processes
         return devices

+    @property
+    def root_gpu(self) -> int:
+        return self.accelerator.root_device.index
+
     @property
     def is_using_torchelastic(self):
         te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ)
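One subtlety with the new property: torch.device.index is None when the device was created without an explicit ordinal (and always for a plain "cpu" device), so root_gpu is only meaningful while a GPU accelerator is active. Plain PyTorch, no Lightning needed, illustrates the behavior:

import torch

print(torch.device("cuda", 0).index)  # 0
print(torch.device("cuda:3").index)   # 3
print(torch.device("cpu").index)      # None
print(torch.device("cuda").index)     # None -- no explicit ordinal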
@@ -375,7 +378,8 @@ def select_training_type_plugin(self):
         elif self.on_tpu:
             plugin = SingleTPUPlugin(self.tpu_id)
         else:
-            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
+            single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
+            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
         return plugin

     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
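The inlined single_gpu_ordinal replaces the old attribute read. determine_root_gpu_device picks the first entry of the parsed device-ID list; the sketch below is a re-implementation of that behavior for illustration only, not the real helper from pytorch_lightning.utilities.device_parser:

from typing import List, Optional


def determine_root_gpu_device(gpu_ids: Optional[List[int]]) -> Optional[int]:
    # The "root" GPU is by convention the first requested device ID.
    if not gpu_ids:
        return None
    return gpu_ids[0]


assert determine_root_gpu_device([2, 3]) == 2
assert determine_root_gpu_device(None) is None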
@@ -525,7 +529,6 @@ def _set_horovod_backend(self):
         if self.on_gpu:
             # Horovod assigns one local GPU per process
             self.parallel_device_ids = list(range(hvd.local_size()))
-            self.root_gpu = hvd.local_rank()
         else:
             self.num_processes = hvd.local_size()
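Dropping self.root_gpu = hvd.local_rank() is safe because, per the commit message, the Horovod plugin already handles per-process device assignment: with one process per GPU, the local rank doubles as the CUDA ordinal. Outside of Lightning, the standard Horovod pattern looks roughly like this (plain horovod.torch API, shown here independent of Lightning internals):

import horovod.torch as hvd
import torch

hvd.init()

if torch.cuda.is_available():
    # One worker process per GPU: the local rank doubles as the CUDA
    # device ordinal, so no separate root_gpu bookkeeping is required.
    torch.cuda.set_device(hvd.local_rank())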

tests/models/test_gpu.py

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,10 @@ def mocked_device_count(monkeypatch):
     def device_count():
         return PRETEND_N_OF_GPUS

+    def is_available():
+        return True
+
+    monkeypatch.setattr(torch.cuda, 'is_available', is_available)
     monkeypatch.setattr(torch.cuda, 'device_count', device_count)
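The added mock matters because accelerator selection typically checks torch.cuda.is_available() before it ever asks for torch.cuda.device_count(); patching only the count would leave GPU-less CI machines falling back to CPU. A self-contained version of the fixture pattern (the PRETEND_N_OF_GPUS value and the test function are illustrative assumptions, not copied from the test module):

import pytest
import torch

PRETEND_N_OF_GPUS = 16  # illustrative value


@pytest.fixture
def mocked_device_count(monkeypatch):
    def device_count():
        return PRETEND_N_OF_GPUS

    def is_available():
        return True

    # Patch both entry points: availability gates the accelerator choice,
    # the count gates how many GPU ids can be parsed.
    monkeypatch.setattr(torch.cuda, 'is_available', is_available)
    monkeypatch.setattr(torch.cuda, 'device_count', device_count)


def test_sees_fake_gpus(mocked_device_count):
    assert torch.cuda.is_available()
    assert torch.cuda.device_count() == PRETEND_N_OF_GPUS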
