diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 7d5a00523ef9e..2d5b7f25c01f7 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -37,6 +37,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
 
 try:
     from apex import amp
@@ -471,9 +472,46 @@ def pick_single_gpu(exclude_gpus: list):
     raise RuntimeError("No GPUs available.")
 
 
-def pick_multiple_gpus(nb):
+def pick_single_gpu_realist_workload(exclude_gpus: list, model: LightningModule) -> int:
+    batch = next(iter(model.train_dataloader()))
+    for i in range(torch.cuda.device_count()):
+        if i in exclude_gpus:
+            continue
+        # Try to run one real training step on this device:
+        device = torch.device(f"cuda:{i}")
+        try:
+            with torch.set_grad_enabled(True):
+                model_device = model.to(device)
+                batch_device = tuple(itup.to(device) for itup in batch)
+                model_device.train()  # training mode so grads are recorded
+                model_device.training_step(batch_device, 0)
+        except RuntimeError as exception:
+            if is_oom_error(exception):  # clean up after the failed attempt
+                garbage_collection_cuda()
+            else:
+                raise
+            continue
+        return i
+    raise RuntimeError("None of the GPUs could accept the given workload.")
+
+
+def pick_multiple_gpus(nb: int, model: Optional[LightningModule] = None) -> list:
+    r"""Pick available GPUs.
+
+    Args:
+        nb: the maximum number of GPUs to pick
+        model: (optional) a LightningModule with the model and train dataloader attached
+
+    Return:
+        a list of available GPU indices
+
+    Note: if `model` is not None, a GPU is considered available if it can run one batch in `train` mode.
+    """
     picked = []
     for _ in range(nb):
-        picked.append(pick_single_gpu(exclude_gpus=picked))
-
+        if not model:
+            picked.append(pick_single_gpu(exclude_gpus=picked))
+        else:
+            assert hasattr(model, 'train_dataloader')
+            picked.append(pick_single_gpu_realist_workload(exclude_gpus=picked, model=model))
     return picked
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4b342328df297..ba58cee894c67 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -493,6 +493,7 @@
         self.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
+        self.auto_select_gpus = auto_select_gpus
         # for gpus allow int, string and gpu list
         if auto_select_gpus and isinstance(gpus, int):
             self.gpus = pick_multiple_gpus(gpus)
@@ -970,6 +971,10 @@ def fit(
             model.prepare_data()
             self._is_data_prepared = True
 
+        # Run updated auto GPU selection with the actual model/input data
+        if self.auto_select_gpus:
+            self.update_auto_selected_gpus(model)
+
         # Run auto batch size scaling
         if self.auto_scale_batch_size:
             if isinstance(self.auto_scale_batch_size, bool):
@@ -1407,6 +1412,15 @@ def call_setup_hook(self, model):
         self.setup(stage_name)
         model.setup(stage_name)
 
+    def update_auto_selected_gpus(self, model):
+        # Called once the model and data are known. Ensures the selected GPUs have enough VRAM.
+        self.gpus = pick_multiple_gpus(len(self.gpus), model)  # at most the current number of GPUs
+
+        self.data_parallel_device_ids = _parse_gpu_ids(self.gpus)
+        self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
+        self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False
+        self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
+
 
 class _PatchDataLoader(object):
     r"""
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index d6641c2f7ab24..2f9d3239fac48 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -748,13 +748,11 @@ def _optimizer_step(*args, **kwargs):
     trainer.fit(model)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_gpu_choice(tmpdir):
     trainer_options = dict(
         default_root_dir=tmpdir,
     )
-    # Only run if CUDA is available
-    if not torch.cuda.is_available():
-        return
 
     num_gpus = torch.cuda.device_count()
     Trainer(**trainer_options, gpus=num_gpus, auto_select_gpus=True)
@@ -763,6 +761,29 @@ def test_gpu_choice(tmpdir):
         Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_gpu_choice_workload(tmpdir):
+    """ Test that training is refused by `pick_single_gpu_realist_workload` when the batch size is overly large.
+    TODO: not adapted to new-generation GPUs with very large VRAM
+    """
+    class CurrentModel(EvalModelTemplate):
+        def train_dataloader(self):
+            # Aim to overload the VRAM by loading the whole of MNIST as a single batch
+            from tests.base.datasets import MNIST
+            return torch.utils.data.DataLoader(MNIST(root=self.data_root, train=True, download=True), batch_size=60000)
+
+    model = CurrentModel()
+    trainer = Trainer(
+        max_steps=1,
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        gpus=1,
+        auto_select_gpus=True,
+    )
+    with pytest.raises(RuntimeError, match=r'.*None of the GPUs could accept the given workload.*'):
+        trainer.fit(model)
+
+
 @pytest.mark.parametrize(['tpu_cores', 'expected_tpu_id', 'error_expected'], [
     pytest.param(1, None, False),
     pytest.param(8, None, False),
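Usage note (not part of the patch): with these changes, `auto_select_gpus=True` still picks GPUs at `Trainer` construction time via `pick_multiple_gpus`, and `fit()` then re-runs the selection against the actual model so that only GPUs able to run one real `training_step` are kept. A minimal sketch of how a user would exercise this, assuming a hypothetical `MyModel` LightningModule that defines `train_dataloader()`:

```python
import torch
from pytorch_lightning import Trainer

from my_project.models import MyModel  # hypothetical LightningModule with train_dataloader()

model = MyModel()

# Ask for up to 2 GPUs. The Trainer first calls pick_multiple_gpus(2) without the model,
# then fit() re-validates the choice with a real training batch via update_auto_selected_gpus.
trainer = Trainer(gpus=2, auto_select_gpus=True, max_epochs=1)

try:
    trainer.fit(model)
except RuntimeError as err:
    # Raised by pick_single_gpu_realist_workload when no GPU can fit one training batch
    print(f"No suitable GPU found: {err}")
```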