44 changes: 41 additions & 3 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -37,6 +37,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda

try:
from apex import amp
@@ -471,9 +472,46 @@ def pick_single_gpu(exclude_gpus: list):
raise RuntimeError("No GPUs available.")


def pick_single_gpu_realist_workload(exclude_gpus: list, model: LightningModule) -> int:
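    """Return the index of the first GPU, not in ``exclude_gpus``, that can run one real training step of ``model``."""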
batch = next(iter(model.train_dataloader()))
for i in range(torch.cuda.device_count()):
if i in exclude_gpus:
continue
# Try to allocate on device:
device = torch.device(f"cuda:{i}")
try:
with torch.set_grad_enabled(True):
model_device = model.to(device)
batch_device = tuple(itup.to(device) for itup in batch)
                model_device.train()  # training mode, with grads enabled above, to mimic a real training step
model_device.training_step(batch_device, 0)
except RuntimeError as exception:
if is_oom_error(exception): # clean after the failed attempt
garbage_collection_cuda()
else:
raise
continue
return i
raise RuntimeError("None of the GPUs could accept the given workload.")


def pick_multiple_gpus(nb: int, model: Optional[LightningModule] = None) -> list:
r""" Pick available GPUs

Args:
nb: the max number of GPU to pick
model: (optional) a LightningModule with model and train_loader attached

Return:
a list of GPU index availables

Note: if model is not None, a GPU is considered available if it is able to run in `train` mode a batch
"""
picked = []
for _ in range(nb):
if not model:
picked.append(pick_single_gpu(exclude_gpus=picked))
else:
assert hasattr(model, 'train_dataloader')
picked.append(pick_single_gpu_realist_workload(exclude_gpus=picked, model=model))
return picked
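
For reference, a minimal usage sketch of the two helpers above, outside of the Trainer. This is illustrative only: `MyLitModel` is a hypothetical LightningModule that defines `train_dataloader()` and `training_step()`, and the direct import from `pytorch_lightning.trainer.distrib_parts` simply mirrors where this PR places the functions.

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.distrib_parts import pick_multiple_gpus

model = MyLitModel()  # hypothetical LightningModule with train_dataloader()/training_step()

# Without a model: select 2 GPUs based only on the lightweight allocation probe in pick_single_gpu.
gpus = pick_multiple_gpus(2)

# With a model: each selected GPU must also survive one real training_step, so
# devices that would OOM on an actual batch are skipped, and a RuntimeError is
# raised if no GPU can accept the workload.
gpus = pick_multiple_gpus(2, model=model)

trainer = Trainer(gpus=gpus)
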
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def __init__(
self.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

self.auto_select_gpus = auto_select_gpus
# for gpus allow int, string and gpu list
if auto_select_gpus and isinstance(gpus, int):
self.gpus = pick_multiple_gpus(gpus)
@@ -970,6 +971,10 @@ def fit(
model.prepare_data()
self._is_data_prepared = True

# Run updated auto GPU selection with the actual model/input data
if self.auto_select_gpus:
self.update_auto_selected_gpus(model)

# Run auto batch size scaling
if self.auto_scale_batch_size:
if isinstance(self.auto_scale_batch_size, bool):
@@ -1407,6 +1412,15 @@ def call_setup_hook(self, model):
self.setup(stage_name)
model.setup(stage_name)

def update_auto_selected_gpus(self, model):
Collaborator suggested change:
    def update_auto_selected_gpus(self, model):
    def update_auto_selected_gpus(self, model: LightningModule):

        # Called once the model/data are known. Ensure the selected GPUs have enough VRAM.
self.gpus = pick_multiple_gpus(len(self.gpus), model) # At most the current number of GPUs

self.data_parallel_device_ids = _parse_gpu_ids(self.gpus)
Collaborator: Could this function be used during __init__ to not have duplicate code?

self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False
self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)


class _PatchDataLoader(object):
r"""
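Following up on the reviewer question above about reusing this logic in __init__: below is a rough sketch of one possible deduplication. It is not part of the PR; the helper name and call sites are assumptions, and __init__ would still need its existing isinstance(gpus, int) guard because it receives an int rather than a list of ids.

    def _auto_select_gpus(self, nb, model=None):
        # Hypothetical shared helper (illustration only): pick at most `nb` GPUs;
        # when a model is given, each pick must also survive one real training_step.
        self.gpus = pick_multiple_gpus(nb, model)
        self.data_parallel_device_ids = _parse_gpu_ids(self.gpus)
        self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
        self.on_gpu = bool(self.data_parallel_device_ids and torch.cuda.is_available())
        self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

With such a helper, __init__ could call self._auto_select_gpus(gpus) when gpus is an int, and update_auto_selected_gpus(model) would reduce to self._auto_select_gpus(len(self.gpus), model).
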
27 changes: 24 additions & 3 deletions tests/trainer/test_trainer.py
@@ -748,13 +748,11 @@ def _optimizer_step(*args, **kwargs):
trainer.fit(model)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_gpu_choice(tmpdir):
trainer_options = dict(
default_root_dir=tmpdir,
)

num_gpus = torch.cuda.device_count()
Trainer(**trainer_options, gpus=num_gpus, auto_select_gpus=True)
@@ -763,6 +761,29 @@ def test_gpu_choice(tmpdir):
Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_gpu_choice_workload(tmpdir):
""" Test if the training is not allowed by `pick_single_gpu_realist_workload` by using an overly large batch-size
TODO: not adapted for new gen GPUs with very large VRAM
"""
class CurrentModel(EvalModelTemplate):
def train_dataloader(self):
            # Aim to overload VRAM by loading the entire MNIST training set as a single batch
from tests.base.datasets import MNIST
return torch.utils.data.DataLoader(MNIST(root=self.data_root, train=True, download=True), batch_size=60000)

model = CurrentModel()
trainer = Trainer(
max_steps=1,
max_epochs=1,
default_root_dir=tmpdir,
gpus=1,
auto_select_gpus=True,
)
with pytest.raises(RuntimeError, match=r'.*None of the GPUs could accept the given workload.*'):
trainer.fit(model)


@pytest.mark.parametrize(['tpu_cores', 'expected_tpu_id', 'error_expected'], [
pytest.param(1, None, False),
pytest.param(8, None, False),