
Commit 7ac1580

Author: Allard Hendriksen
Add automatic GPU choice to trainer (#1426)
* Add automatic GPU choice to trainer

  This commit adds the `gpu_choice` parameter to Trainer. By default, this
  parameter is set to 'manual', which causes no observable difference in
  behavior. When `gpu_choice` is set to 'auto' and `gpus` is an int, the
  trainer automatically allocates the first available GPU. This is
  especially useful when GPUs are configured to be in "exclusive mode",
  which means that only one process at a time can use them.

* Rename gpu_choice -> auto_select_gpus
1 parent e79ae18 commit 7ac1580
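
For illustration, a minimal usage sketch of the renamed flag; MyModel is a placeholder LightningModule, not part of this commit:

from pytorch_lightning import Trainer

model = MyModel()  # placeholder LightningModule subclass

# With an int `gpus` and auto_select_gpus=True, the Trainer probes devices
# and picks two GPUs that actually accept an allocation, which matters on
# machines whose GPUs run in exclusive compute mode.
trainer = Trainer(gpus=2, auto_select_gpus=True)
trainer.fit(model)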

File tree: 4 files changed, +84 −4 lines

  CHANGELOG.md
  pytorch_lightning/trainer/distrib_parts.py
  pytorch_lightning/trainer/trainer.py
  tests/trainer/test_trainer.py

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+- Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
+
 ## [0.7.3] - 2020-04-09
 
 ### Added

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 43 additions & 1 deletion
@@ -339,7 +339,8 @@
 
 import os
 from abc import ABC, abstractmethod
-
+import time
+import random
 import torch
 
 from pytorch_lightning import _logger as log
@@ -648,3 +649,44 @@ def determine_root_gpu_device(gpus):
     root_gpu = gpus[0]
 
     return root_gpu
+
+
+def retry_jittered_backoff(f, num_retries=5):
+    # Based on:
+    # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
+    cap = 1.0    # max sleep time is 1 s
+    base = 0.01  # initial sleep time is 10 ms
+    sleep = base
+
+    for i in range(num_retries):
+        try:
+            return f()
+        except RuntimeError as e:
+            if i == num_retries - 1:
+                # give up after the final retry
+                raise e
+        # back off with jitter before the next attempt
+        time.sleep(sleep)
+        sleep = min(cap, random.uniform(base, sleep * 3))
+
+
+def pick_single_gpu(exclude_gpus=[]):
+    for i in range(torch.cuda.device_count()):
+        if i in exclude_gpus:
+            continue
+        # Try to allocate on device:
+        device = torch.device(f"cuda:{i}")
+        try:
+            torch.ones(1).to(device)
+        except RuntimeError:
+            continue
+        return i
+    raise RuntimeError("No GPUs available.")
+
+
+def pick_multiple_gpus(n):
+    picked = []
+    for _ in range(n):
+        picked.append(pick_single_gpu(exclude_gpus=picked))
+
+    return picked
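
A rough sketch of how these helpers behave, taking the import path from the diff above as given. Note that `retry_jittered_backoff` lands here as a standalone utility; the wrapper below is a hypothetical caller, not something this commit wires up:

import torch

from pytorch_lightning.trainer.distrib_parts import (
    pick_multiple_gpus,
    retry_jittered_backoff,
)

# pick_multiple_gpus(2) probes cuda:0, cuda:1, ... with a tiny allocation
# and returns the first two indices that accept it, e.g. [0, 1].
free_gpus = pick_multiple_gpus(2)

# Hypothetical use of the backoff helper: retry a flaky allocation up to
# five times, sleeping between 10 ms and 1 s (with jitter) between attempts.
probe = retry_jittered_backoff(lambda: torch.ones(1).to("cuda:0"))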

pytorch_lightning/trainer/trainer.py

Lines changed: 20 additions & 3 deletions
@@ -22,7 +22,12 @@
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8, TrainerDeprecatedAPITillVer0_9
 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
-from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
+from pytorch_lightning.trainer.distrib_parts import (
+    TrainerDPMixin,
+    parse_gpu_ids,
+    determine_root_gpu_device,
+    pick_multiple_gpus,
+)
 from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -85,6 +90,7 @@ def __init__(
        process_position: int = 0,
        num_nodes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
+       auto_select_gpus: bool = False,
        num_tpu_cores: Optional[int] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
@@ -158,6 +164,13 @@ def __init__(
 
        gpus: Which GPUs to train on.
 
+       auto_select_gpus:
+
+           If enabled and `gpus` is an integer, pick available
+           gpus automatically. This is especially useful when
+           GPUs are configured to be in "exclusive mode", such
+           that only one process at a time can access them.
+
        num_tpu_cores: How many TPU cores to train on (1 or 8).
 
        log_gpu_memory: None, 'min_max', 'all'. Might slow performance
@@ -384,8 +397,12 @@ def __init__(
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)
 
-       # allow int, string and gpu list
-       self.gpus = gpus
+       # for gpus allow int, string and gpu list
+       if auto_select_gpus and isinstance(gpus, int):
+           self.gpus = pick_multiple_gpus(gpus)
+       else:
+           self.gpus = gpus
+
        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
        self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
        self.root_device = torch.device("cpu")
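
To make the new branch concrete, the following sketch spells out the equivalence it introduces; this is an illustration of the diff above, not additional committed code:

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.distrib_parts import pick_multiple_gpus

# With an int, the new branch resolves `gpus` before parsing, so these
# two constructions are roughly equivalent on a machine with free GPUs:
trainer_a = Trainer(gpus=2, auto_select_gpus=True)
trainer_b = Trainer(gpus=pick_multiple_gpus(2))

# A list or string bypasses auto-selection, because the branch only
# fires when isinstance(gpus, int):
trainer_c = Trainer(gpus=[0, 1], auto_select_gpus=True)  # uses [0, 1] as-is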

tests/trainer/test_trainer.py

Lines changed: 15 additions & 0 deletions
@@ -685,3 +685,18 @@ def _optimizer_step(*args, **kwargs):
     model.prev_called_batch_idx = 0
 
     trainer.fit(model)
+
+
+def test_gpu_choice(tmpdir):
+    trainer_options = dict(
+        default_save_path=tmpdir,
+    )
+    # Only run if CUDA is available
+    if not torch.cuda.is_available():
+        return
+
+    num_gpus = torch.cuda.device_count()
+    Trainer(**trainer_options, gpus=num_gpus, auto_select_gpus=True)
+
+    with pytest.raises(RuntimeError, match=r'.*No GPUs available.*'):
+        Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True)
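
For context on the failing case: `pick_multiple_gpus(num_gpus + 1)` eventually calls `pick_single_gpu` with every real index excluded, the loop over `torch.cuda.device_count()` finds no candidate, and the RuntimeError propagates out of `Trainer.__init__`. A minimal standalone check of the same behavior (a sketch that assumes at least one free CUDA device):

import pytest
import torch

from pytorch_lightning.trainer.distrib_parts import pick_multiple_gpus

if torch.cuda.is_available():
    n = torch.cuda.device_count()
    assert len(pick_multiple_gpus(n)) == n  # every device gets picked once
    with pytest.raises(RuntimeError, match="No GPUs available."):
        pick_multiple_gpus(n + 1)  # one more than exists exhausts the indices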
