CHANGELOG.md (6 additions, 0 deletions)
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [unreleased] - YYYY-MM-DD

### Added

- Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.

## [0.7.3] - 2020-04-09

### Added
pytorch_lightning/trainer/distrib_parts.py (43 additions, 1 deletion)
@@ -339,7 +339,8 @@

import os
from abc import ABC, abstractmethod

import time
import random
import torch

from pytorch_lightning import _logger as log
@@ -648,3 +649,44 @@ def determine_root_gpu_device(gpus):
    root_gpu = gpus[0]

    return root_gpu


def retry_jittered_backoff(f, num_retries=5):
    # Exponential backoff with jitter, based on:
    # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    cap = 1.0    # max sleep time is 1s
    base = 0.01  # initial sleep time is 10ms
    sleep = base

    for i in range(num_retries):
        try:
            return f()
        except RuntimeError as e:
            # Re-raise on the final attempt; otherwise sleep, then retry
            # with a jittered, capped backoff.
            if i == num_retries - 1:
                raise e
            time.sleep(sleep)
            sleep = min(cap, random.uniform(base, sleep * 3))


def pick_single_gpu(exclude_gpus=[]):
    for i in range(torch.cuda.device_count()):
        if i in exclude_gpus:
            continue
        # Try to allocate on device:
        device = torch.device(f"cuda:{i}")
        try:
            torch.ones(1).to(device)
        except RuntimeError:
            continue
        return i
    raise RuntimeError("No GPUs available.")


def pick_multiple_gpus(n):
    picked = []
    for _ in range(n):
        picked.append(pick_single_gpu(exclude_gpus=picked))

    return picked
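
Taken together, the new helpers let a caller probe for free devices and retry transient allocation failures. A minimal usage sketch, assuming a CUDA machine (the two-GPU request below is illustrative and not part of the diff):

```python
import torch

from pytorch_lightning.trainer.distrib_parts import (
    pick_multiple_gpus,
    pick_single_gpu,
    retry_jittered_backoff,
)

if torch.cuda.is_available():
    # Claim one free GPU, retrying with jittered backoff in case every
    # device is momentarily held by another exclusive-mode process.
    free_gpu = retry_jittered_backoff(pick_single_gpu)

    # Reserve two distinct free GPUs, e.g. [0, 1] on an idle machine.
    free_gpus = pick_multiple_gpus(2)
```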
pytorch_lightning/trainer/trainer.py (20 additions, 3 deletions)
@@ -22,7 +22,12 @@
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8, TrainerDeprecatedAPITillVer0_9
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.trainer.distrib_parts import (
    TrainerDPMixin,
    parse_gpu_ids,
    determine_root_gpu_device,
    pick_multiple_gpus,
)
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -85,6 +90,7 @@ def __init__(
        process_position: int = 0,
        num_nodes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        num_tpu_cores: Optional[int] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
@@ -158,6 +164,13 @@

            gpus: Which GPUs to train on.

            auto_select_gpus:

                If enabled and `gpus` is an integer, pick available
                gpus automatically. This is especially useful when
                GPUs are configured to be in "exclusive mode", such
                that only one process at a time can access them.

            num_tpu_cores: How many TPU cores to train on (1 or 8).

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance
@@ -384,8 +397,12 @@ def __init__(
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # allow int, string and gpu list
        self.gpus = gpus
        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = pick_multiple_gpus(gpus)
        else:
            self.gpus = gpus

        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
        self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
        self.root_device = torch.device("cpu")
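
The `__init__` change above only takes effect when `gpus` is an integer, so explicit index lists keep their existing behavior. A hedged usage sketch of the new flag (`MyModel` is a hypothetical `LightningModule`, not defined in this PR):

```python
from pytorch_lightning import Trainer

model = MyModel()  # hypothetical LightningModule, not part of this diff

# Ask for any two free GPUs instead of hard-coding indices; useful on
# exclusive-mode systems where some devices are already claimed.
trainer = Trainer(gpus=2, auto_select_gpus=True)
trainer.fit(model)
```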
tests/trainer/test_trainer.py (15 additions, 0 deletions)
@@ -685,3 +685,18 @@ def _optimizer_step(*args, **kwargs):
    model.prev_called_batch_idx = 0

    trainer.fit(model)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a CUDA-capable machine")
def test_gpu_choice(tmpdir):
    trainer_options = dict(
        default_save_path=tmpdir,
    )

    num_gpus = torch.cuda.device_count()
    Trainer(**trainer_options, gpus=num_gpus, auto_select_gpus=True)

    with pytest.raises(RuntimeError, match=r'.*No GPUs available.*'):
        Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True)