Commit 2d5db37

Allard Hendriksen committed
Add automatic GPU choice to trainer
This commit adds the `gpu_choice` parameter to `Trainer`. By default, this parameter is set to `'manual'`, which causes no observable difference in behavior. When `gpu_choice` is set to `'auto'` and `gpus` is an int, the trainer automatically allocates the first available GPUs. This is especially useful when GPUs are configured to be in "exclusive mode", which means that only one process at a time can use them.
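In use, this looks as follows (a minimal sketch based on the diff below; `gpus=2` is an arbitrary example value):

```python
from pytorch_lightning import Trainer

# Default: `gpus` is interpreted exactly as before.
trainer = Trainer(gpus=2, gpu_choice='manual')

# New: probe the visible devices and train on the first two that
# accept an allocation, e.g. on a node running in exclusive mode.
trainer = Trainer(gpus=2, gpu_choice='auto')
```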
1 parent b5c6d0e commit 2d5db37

4 files changed: +85 −4 lines

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+- Added `gpu_choice` to trainer which can enable automatically picking the first available GPU on exclusive mode systems.
+
 ## [0.7.2] - 2020-04-07
 
 ### Added

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 43 additions & 1 deletion
@@ -339,7 +339,8 @@
 
 import os
 from abc import ABC, abstractmethod
-
+import time
+import random
 import torch
 
 from pytorch_lightning import _logger as log
@@ -646,3 +647,44 @@ def determine_root_gpu_device(gpus):
     root_gpu = gpus[0]
 
     return root_gpu
+
+
+def retry_jittered_backoff(f, num_retries=5):
+    # Retry f on RuntimeError, sleeping with exponential backoff
+    # and jitter between attempts. Based on:
+    # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
+    cap = 1.0    # max sleep time is 1s
+    base = 0.01  # initial sleep time is 10ms
+    sleep = base  # current sleep time
+
+    for i in range(num_retries):
+        try:
+            return f()
+        except RuntimeError as e:
+            if i == num_retries - 1:
+                raise e
+            time.sleep(sleep)
+            sleep = min(cap, random.uniform(base, sleep * 3))
+
+
+def pick_single_gpu(exclude_gpus=[]):
+    for i in range(torch.cuda.device_count()):
+        if i in exclude_gpus:
+            continue
+        # Try to allocate on device:
+        device = torch.device(f"cuda:{i}")
+        try:
+            torch.ones(1).to(device)
+        except RuntimeError:
+            continue
+        return i
+    raise RuntimeError("No GPUs available.")
+
+
+def pick_multiple_gpus(n):
+    picked = []
+    for _ in range(n):
+        picked.append(pick_single_gpu(exclude_gpus=picked))
+
+    return picked
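Note that `retry_jittered_backoff` is defined here but not yet called from `pick_single_gpu`. A minimal sketch of how the two new helpers could be combined on a busy exclusive-mode node (this wiring is an assumption, not part of the diff):

```python
import torch

from pytorch_lightning.trainer.distrib_parts import (
    pick_single_gpu,
    retry_jittered_backoff,
)

# pick_single_gpu raises RuntimeError when every device rejects the
# allocation probe; retry_jittered_backoff retries on exactly that
# exception, sleeping between attempts to avoid a thundering herd.
gpu_index = retry_jittered_backoff(pick_single_gpu, num_retries=5)
device = torch.device(f"cuda:{gpu_index}")
```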

pytorch_lightning/trainer/trainer.py

Lines changed: 21 additions & 3 deletions
@@ -23,7 +23,12 @@
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8, TrainerDeprecatedAPITillVer0_9
 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
-from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
+from pytorch_lightning.trainer.distrib_parts import (
+    TrainerDPMixin,
+    parse_gpu_ids,
+    determine_root_gpu_device,
+    pick_multiple_gpus,
+)
 from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -85,6 +90,7 @@ def __init__(
         process_position: int = 0,
         num_nodes: int = 1,
         gpus: Optional[Union[List[int], str, int]] = None,
+        gpu_choice: str = 'manual',
         num_tpu_cores: Optional[int] = None,
         log_gpu_memory: Optional[str] = None,
         progress_bar_refresh_rate: int = 1,
@@ -158,6 +164,14 @@ def __init__(
 
         gpus: Which GPUs to train on.
 
+        gpu_choice: 'manual' (default) or 'auto'.
+
+            If 'auto' and `gpus` is an integer, pick the first
+            available GPUs automatically. This is especially
+            useful when GPUs are configured to be in "exclusive
+            mode", which means that only one process at a time can
+            use them.
+
         num_tpu_cores: How many TPU cores to train on (1 or 8).
 
         log_gpu_memory: None, 'min_max', 'all'. Might slow performance
@@ -385,8 +399,12 @@ def __init__(
         self.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
-        # allow int, string and gpu list
-        self.gpus = gpus
+        # for gpus allow int, string and gpu list
+        if gpu_choice == "auto" and isinstance(gpus, int):
+            self.gpus = pick_multiple_gpus(gpus)
+        else:
+            self.gpus = gpus
+
         self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
         self.root_device = torch.device("cpu")
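Concretely, the new branch resolves an int `gpus` into an explicit list of free device indices before `parse_gpu_ids` runs. A sketch of the equivalence (illustrative only; which indices get picked depends on which devices accept the allocation probe):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.distrib_parts import pick_multiple_gpus

# What the new branch does internally: resolve the int to a concrete
# list of free device indices, then hand that list to parse_gpu_ids.
trainer = Trainer(gpus=pick_multiple_gpus(2))

# Equivalent, with the resolution done inside Trainer.__init__:
trainer = Trainer(gpus=2, gpu_choice="auto")
```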

tests/trainer/test_trainer.py

Lines changed: 15 additions & 0 deletions
@@ -658,3 +658,18 @@ def on_batch_start(self, trainer, pl_module):
     assert not trainer.interrupted
     trainer.fit(model)
     assert trainer.interrupted
+
+
+def test_gpu_choice(tmpdir):
+    trainer_options = dict(
+        default_save_path=tmpdir,
+    )
+    # Only run if CUDA is available
+    if not torch.cuda.is_available():
+        return
+
+    num_gpus = torch.cuda.device_count()
+    Trainer(**trainer_options, gpus=num_gpus, gpu_choice="auto")
+
+    with pytest.raises(RuntimeError, match=r'.*No GPUs available.*'):
+        Trainer(**trainer_options, gpus=num_gpus + 1, gpu_choice="auto")
