From c5065ed3ce10eedad6669df25d3caa053484f713 Mon Sep 17 00:00:00 2001 From: sebastienwood Date: Wed, 5 Aug 2020 15:41:39 -0400 Subject: [PATCH 1/7] init --- pytorch_lightning/trainer/distrib_parts.py | 29 ++++++++++++++++++++-- pytorch_lightning/trainer/trainer.py | 13 ++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 7d5a00523ef9e..b8c1bceddc759 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -37,6 +37,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda try: from apex import amp @@ -471,9 +472,33 @@ def pick_single_gpu(exclude_gpus: list): raise RuntimeError("No GPUs available.") -def pick_multiple_gpus(nb): +def pick_single_gpu_realist_workload(exclude_gpus: list, model, batch): + for i in range(torch.cuda.device_count()): + if i in exclude_gpus: + continue + # Try to allocate on device: + device = torch.device(f"cuda:{i}") + try: + model_device = model.to(device) + batch_device = batch.to(device) + model_device.train() # record grads + model_device(batch_device) + except RuntimeError as exception: + if is_oom_error(exception): # clean after the failed attempt + garbage_collection_cuda() + else: raise + continue + return i + raise RuntimeError("No GPUs available.") + + +def pick_multiple_gpus(nb, model=None): picked = [] for _ in range(nb): - picked.append(pick_single_gpu(exclude_gpus=picked)) + if not model: picked.append(pick_single_gpu(exclude_gpus=picked)) + else : + assert hasattr(model, 'train_dataloader') + picked.append(pick_single_gpu_realist_workload(exclude_gpus=picked, model=model, batch=next(iter(model.train_dataloader)))) + if len(picked) < 1: raise RuntimeError("None of the GPUs could accept the given workload.") return picked diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 323f2866b1cab..52219cb84a99d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -487,6 +487,7 @@ def __init__( self.accumulate_grad_batches = accumulate_grad_batches self.configure_accumulated_gradients(accumulate_grad_batches) + self.auto_select_gpus = auto_select_gpus # for gpus allow int, string and gpu list if auto_select_gpus and isinstance(gpus, int): self.gpus = pick_multiple_gpus(gpus) @@ -964,6 +965,9 @@ def fit( model.prepare_data() self._is_data_prepared = True + # Run updated auto GPU selection with the actual model/input data + if self.auto_select_gpus: self.update_auto_selected_gpus(model) + # Run auto batch size scaling if self.auto_scale_batch_size: if isinstance(self.auto_scale_batch_size, bool): @@ -1401,6 +1405,15 @@ def call_setup_hook(self, model): self.setup(stage_name) model.setup(stage_name) + def update_auto_selected_gpus(self, model): + # Called when model/data is known. Ensure the GPU used have enough VRAM. 
+ self.gpus = pick_multiple_gpus(len(self.gpus), model) # At most the current number of GPUs + + self.data_parallel_device_ids = _parse_gpu_ids(self.gpus) + self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids) + self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False + self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) + class _PatchDataLoader(object): r""" From ac2ea563d0d021dfdd7fe84e1a2de2b768659060 Mon Sep 17 00:00:00 2001 From: sebastienwood Date: Wed, 5 Aug 2020 16:19:58 -0400 Subject: [PATCH 2/7] smol doc --- pytorch_lightning/trainer/distrib_parts.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index b8c1bceddc759..cfe9ce52e2f0c 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -25,7 +25,7 @@ import random import torch from torch.optim.lr_scheduler import _LRScheduler -from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence +from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence, NoneType from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning import _logger as log @@ -472,12 +472,13 @@ def pick_single_gpu(exclude_gpus: list): raise RuntimeError("No GPUs available.") -def pick_single_gpu_realist_workload(exclude_gpus: list, model, batch): +def pick_single_gpu_realist_workload(exclude_gpus: list, model:LightningModule): for i in range(torch.cuda.device_count()): if i in exclude_gpus: continue # Try to allocate on device: device = torch.device(f"cuda:{i}") + batch=next(iter(model.train_dataloader)) try: model_device = model.to(device) batch_device = batch.to(device) @@ -492,13 +493,24 @@ def pick_single_gpu_realist_workload(exclude_gpus: list, model, batch): raise RuntimeError("No GPUs available.") -def pick_multiple_gpus(nb, model=None): +def pick_multiple_gpus(nb:int, model:Optional[LightningModule] = None) -> list: + r""" Pick available GPUs + + Args: + nb: the max number of GPU to pick + model: (optional) a LightningModule with model and train_loader attached + + Return: + a list of GPU index availables + + Note: if model is not None, a GPU is considered available if it is able to run in `train` mode a batch + """ picked = [] for _ in range(nb): if not model: picked.append(pick_single_gpu(exclude_gpus=picked)) else : assert hasattr(model, 'train_dataloader') - picked.append(pick_single_gpu_realist_workload(exclude_gpus=picked, model=model, batch=next(iter(model.train_dataloader)))) + picked.append(pick_single_gpu_realist_workload(exclude_gpus=picked, model=model)) if len(picked) < 1: raise RuntimeError("None of the GPUs could accept the given workload.") return picked From 74ae43f468fefef0cf19f8e36c6653190e1cf71a Mon Sep 17 00:00:00 2001 From: Sebastien Henwood Date: Wed, 5 Aug 2020 17:49:40 -0400 Subject: [PATCH 3/7] small updates --- pytorch_lightning/trainer/distrib_parts.py | 9 +++++---- tests/trainer/test_trainer.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index cfe9ce52e2f0c..f777c4093ce05 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -480,10 +480,11 @@ def pick_single_gpu_realist_workload(exclude_gpus: list, model:LightningModule): 
device = torch.device(f"cuda:{i}") batch=next(iter(model.train_dataloader)) try: - model_device = model.to(device) - batch_device = batch.to(device) - model_device.train() # record grads - model_device(batch_device) + with torch.set_grad_enabled(True): + model_device = model.to(device) + batch_device = batch.to(device) + model_device.train() # record grads + model_device(batch_device) except RuntimeError as exception: if is_oom_error(exception): # clean after the failed attempt garbage_collection_cuda() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c7652ebecf3f9..1553af6680b48 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -763,6 +763,20 @@ def test_gpu_choice(tmpdir): Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True) +def test_gpu_choice_workload(tmpdir): + """ Test if the training is not allowed by `pick_single_gpu_realist_workload` by using an overly large batch-size """ + model = EvalModelTemplate() + model.train_dataloader.batch_size = 50000 # the size of MNIST trainset + trainer = Trainer( + max_steps=1, + max_epochs=1, + default_root_dir=tmpdir, + ) + + with pytest.raises(RuntimeError, match=r'.*None of the GPUs could accept the given workload.*'): + trainer.fit(model) + + @pytest.mark.parametrize(['tpu_cores', 'expected_tpu_id', 'error_expected'], [ pytest.param(1, None, False), pytest.param(8, None, False), From 484f09ba58fd7aa62be22b9de98e0ef41a5274bd Mon Sep 17 00:00:00 2001 From: Sebastien Henwood Date: Wed, 5 Aug 2020 18:04:44 -0400 Subject: [PATCH 4/7] fix useless import --- pytorch_lightning/trainer/distrib_parts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index f777c4093ce05..ad7112657313a 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -25,7 +25,7 @@ import random import torch from torch.optim.lr_scheduler import _LRScheduler -from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence, NoneType +from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning import _logger as log From 2fa02bdaa553611ba87bfcc11281151134dd3e9f Mon Sep 17 00:00:00 2001 From: Sebastien Henwood Date: Thu, 6 Aug 2020 12:16:20 -0400 Subject: [PATCH 5/7] fixes + flake8 --- pytorch_lightning/trainer/distrib_parts.py | 34 ++++++++++++-------- pytorch_lightning/trainer/trainer.py | 9 +++--- pytorch_lightning/trainer/training_tricks.py | 2 +- tests/trainer/test_trainer.py | 23 ++++++++----- 4 files changed, 42 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index ad7112657313a..29e0c3f0fe842 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -472,31 +472,32 @@ def pick_single_gpu(exclude_gpus: list): raise RuntimeError("No GPUs available.") -def pick_single_gpu_realist_workload(exclude_gpus: list, model:LightningModule): +def pick_single_gpu_realist_workload(exclude_gpus: list, model:LightningModule) -> int: + batch = next(iter(model.train_dataloader)) for i in range(torch.cuda.device_count()): if i in exclude_gpus: continue # Try to allocate on device: device = torch.device(f"cuda:{i}") - batch=next(iter(model.train_dataloader)) try: with torch.set_grad_enabled(True): - model_device = 
model.to(device) + model_device = model.to(device) batch_device = batch.to(device) - model_device.train() # record grads + model_device.train() # record grads model_device(batch_device) except RuntimeError as exception: - if is_oom_error(exception): # clean after the failed attempt + if is_oom_error(exception): # clean after the failed attempt garbage_collection_cuda() - else: raise + else: + raise continue return i - raise RuntimeError("No GPUs available.") + return -1 def pick_multiple_gpus(nb:int, model:Optional[LightningModule] = None) -> list: - r""" Pick available GPUs - + r""" Pick available GPUs + Args: nb: the max number of GPU to pick model: (optional) a LightningModule with model and train_loader attached @@ -508,10 +509,17 @@ def pick_multiple_gpus(nb:int, model:Optional[LightningModule] = None) -> list: """ picked = [] for _ in range(nb): - if not model: picked.append(pick_single_gpu(exclude_gpus=picked)) - else : + if not model: + picked.append(pick_single_gpu(exclude_gpus=picked)) + else: assert hasattr(model, 'train_dataloader') - picked.append(pick_single_gpu_realist_workload(exclude_gpus=picked, model=model)) + pick = pick_single_gpu_realist_workload(exclude_gpus=picked, model=model) + if pick != -1: + picked.append(pick) + else: + print(f'There were less than {nb} GPUs capable for this workload') + break - if len(picked) < 1: raise RuntimeError("None of the GPUs could accept the given workload.") + if len(picked) < 1: + raise RuntimeError("None of the GPUs could accept the given workload.") return picked diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 52219cb84a99d..487cef667855f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -966,7 +966,8 @@ def fit( self._is_data_prepared = True # Run updated auto GPU selection with the actual model/input data - if self.auto_select_gpus: self.update_auto_selected_gpus(model) + if self.auto_select_gpus: + self.update_auto_selected_gpus(model) # Run auto batch size scaling if self.auto_scale_batch_size: @@ -1406,12 +1407,12 @@ def call_setup_hook(self, model): model.setup(stage_name) def update_auto_selected_gpus(self, model): - # Called when model/data is known. Ensure the GPU used have enough VRAM. - self.gpus = pick_multiple_gpus(len(self.gpus), model) # At most the current number of GPUs + # Called when model/data is known. Ensure the GPU used have enough VRAM. 
+ self.gpus = pick_multiple_gpus(len(self.gpus), model) # At most the current number of GPUs self.data_parallel_device_ids = _parse_gpu_ids(self.gpus) self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids) - self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False + self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 20eeff3878cc2..5bea8fbc1a3cd 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -269,7 +269,7 @@ def _adjust_batch_size(trainer, if hasattr(model, batch_arg_name): setattr(model, batch_arg_name, value) else: - setattr(model.hparams, batch_arg_name, value) + setattr(model.hparams, batch_arg_name, value) new_size = value if desc: log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}') diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1553af6680b48..ae7d3b523132e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -748,13 +748,11 @@ def _optimizer_step(*args, **kwargs): trainer.fit(model) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_gpu_choice(tmpdir): trainer_options = dict( default_root_dir=tmpdir, ) - # Only run if CUDA is available - if not torch.cuda.is_available(): - return num_gpus = torch.cuda.device_count() Trainer(**trainer_options, gpus=num_gpus, auto_select_gpus=True) @@ -763,19 +761,28 @@ def test_gpu_choice(tmpdir): Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_gpu_choice_workload(tmpdir): - """ Test if the training is not allowed by `pick_single_gpu_realist_workload` by using an overly large batch-size """ - model = EvalModelTemplate() - model.train_dataloader.batch_size = 50000 # the size of MNIST trainset + """ Test if the training is not allowed by `pick_single_gpu_realist_workload` by using an overly large batch-size + TODO: not adapted for new gen GPUs with very large VRAM + """ + class CurrentModel(EvalModelTemplate): + def train_dataloader(self): + # Aim to overload the VRAM with the whole MNIST + from tests.base.dataloaders import MNIST + return torch.utils.data.DataLoader(MNIST(root=self.data_root, train=True, download=True), batch_size=60000) + + model = CurrentModel() trainer = Trainer( max_steps=1, max_epochs=1, default_root_dir=tmpdir, + gpus=1, + auto_select_gpus=True, ) - with pytest.raises(RuntimeError, match=r'.*None of the GPUs could accept the given workload.*'): trainer.fit(model) - + @pytest.mark.parametrize(['tpu_cores', 'expected_tpu_id', 'error_expected'], [ pytest.param(1, None, False), From 5693187efdb6b419c3b92bbb37f4d649e419f167 Mon Sep 17 00:00:00 2001 From: sebastienwood Date: Thu, 6 Aug 2020 12:32:32 -0400 Subject: [PATCH 6/7] Update distrib_parts.py pep8 fix --- pytorch_lightning/trainer/distrib_parts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 29e0c3f0fe842..dbd7964740826 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -472,7 +472,7 @@ def 
pick_single_gpu(exclude_gpus: list): raise RuntimeError("No GPUs available.") -def pick_single_gpu_realist_workload(exclude_gpus: list, model:LightningModule) -> int: +def pick_single_gpu_realist_workload(exclude_gpus: list, model: LightningModule) -> int: batch = next(iter(model.train_dataloader)) for i in range(torch.cuda.device_count()): if i in exclude_gpus: @@ -495,7 +495,7 @@ def pick_single_gpu_realist_workload(exclude_gpus: list, model:LightningModule) return -1 -def pick_multiple_gpus(nb:int, model:Optional[LightningModule] = None) -> list: +def pick_multiple_gpus(nb: int, model: Optional[LightningModule] = None) -> list: r""" Pick available GPUs Args: From 96f3a6677428dfd56c5f4317463f25af50b1e7ec Mon Sep 17 00:00:00 2001 From: Sebastien Henwood Date: Fri, 7 Aug 2020 14:01:01 -0400 Subject: [PATCH 7/7] training step delegation + small fixes --- pytorch_lightning/trainer/distrib_parts.py | 18 +++++------------- tests/trainer/test_trainer.py | 2 +- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index dbd7964740826..2d5b7f25c01f7 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -473,7 +473,7 @@ def pick_single_gpu(exclude_gpus: list): def pick_single_gpu_realist_workload(exclude_gpus: list, model: LightningModule) -> int: - batch = next(iter(model.train_dataloader)) + batch = next(iter(model.train_dataloader())) for i in range(torch.cuda.device_count()): if i in exclude_gpus: continue @@ -482,9 +482,9 @@ def pick_single_gpu_realist_workload(exclude_gpus: list, model: LightningModule) try: with torch.set_grad_enabled(True): model_device = model.to(device) - batch_device = batch.to(device) + batch_device = tuple(itup.to(device) for itup in batch) model_device.train() # record grads - model_device(batch_device) + model_device.training_step(batch_device, 0) except RuntimeError as exception: if is_oom_error(exception): # clean after the failed attempt garbage_collection_cuda() @@ -492,7 +492,7 @@ def pick_single_gpu_realist_workload(exclude_gpus: list, model: LightningModule) raise continue return i - return -1 + raise RuntimeError("None of the GPUs could accept the given workload.") def pick_multiple_gpus(nb: int, model: Optional[LightningModule] = None) -> list: @@ -513,13 +513,5 @@ def pick_multiple_gpus(nb: int, model: Optional[LightningModule] = None) -> list picked.append(pick_single_gpu(exclude_gpus=picked)) else: assert hasattr(model, 'train_dataloader') - pick = pick_single_gpu_realist_workload(exclude_gpus=picked, model=model) - if pick != -1: - picked.append(pick) - else: - print(f'There were less than {nb} GPUs capable for this workload') - break - - if len(picked) < 1: - raise RuntimeError("None of the GPUs could accept the given workload.") + picked.append(pick_single_gpu_realist_workload(exclude_gpus=picked, model=model)) return picked diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index aa2b45731380d..2f9d3239fac48 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -769,7 +769,7 @@ def test_gpu_choice_workload(tmpdir): class CurrentModel(EvalModelTemplate): def train_dataloader(self): # Aim to overload the VRAM with the whole MNIST - from tests.base.dataloaders import MNIST + from tests.base.datasets import MNIST return torch.utils.data.DataLoader(MNIST(root=self.data_root, train=True, download=True), batch_size=60000) model = CurrentModel()
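
The selection loop added to distrib_parts.py above works by attempting one realistic training-style pass per candidate GPU and treating a CUDA out-of-memory RuntimeError as "this GPU cannot take the workload". A minimal standalone sketch of that probing pattern, outside Lightning, follows; the helper name probe_gpu_with_workload is invented for illustration, and the message check plus gc.collect()/torch.cuda.empty_cache() are only rough stand-ins for the is_oom_error and garbage_collection_cuda utilities imported by the patch.

import gc
import torch

def probe_gpu_with_workload(device_index: int, model: torch.nn.Module, batch: torch.Tensor) -> bool:
    """Return True if one training-style pass fits on the given GPU (illustrative sketch only)."""
    device = torch.device(f"cuda:{device_index}")
    try:
        model = model.to(device)
        batch = batch.to(device)
        model.train()                     # training mode, mirroring the patch
        model(batch)                      # trial forward pass with grads enabled
        return True
    except RuntimeError as err:
        if "out of memory" in str(err):   # rough stand-in for is_oom_error
            gc.collect()
            torch.cuda.empty_cache()      # rough stand-in for garbage_collection_cuda
            return False
        raise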
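
From the user side the public API is unchanged: the same auto_select_gpus flag now additionally re-runs selection inside fit(), once the model and its train_dataloader are known, via update_auto_selected_gpus in trainer.py. A hedged usage sketch against the Trainer API current at the time of these patches, with a made-up TinyModule standing in for a real LightningModule:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    # Made-up module, only to show the flow; any LightningModule with a train_dataloader works.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

    def train_dataloader(self):
        # The first batch of this loader is what the patched selection runs on each candidate GPU.
        x, y = torch.randn(4096, 32), torch.randint(0, 2, (4096,))
        return DataLoader(TensorDataset(x, y), batch_size=1024)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# gpus as an int plus auto_select_gpus=True still picks devices at construction time;
# fit() then re-checks them against a real training step via update_auto_selected_gpus.
trainer = pl.Trainer(gpus=2, auto_select_gpus=True, max_epochs=1)
trainer.fit(TinyModule())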