From e7595b5784a7409ceb3b14f69ab3cc41b45f78e3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 6 Feb 2021 22:51:46 +0100
Subject: [PATCH] fix memory issue with ddp_spawn

---
 pytorch_lightning/accelerators/gpu.py         | 1 -
 pytorch_lightning/plugins/training_type/dp.py | 2 ++
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index f01cecac1615a..33a3cce7e3a31 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -16,7 +16,6 @@ def setup(self, trainer, model):
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         self.set_nvidia_flags()
         torch.cuda.set_device(self.root_device)
-        model.to(self.root_device)
         return super().setup(trainer, model)
 
     def on_train_start(self):
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index 54258a8bc1563..76b1247293113 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -27,6 +27,8 @@ def __init__(self, parallel_devices: List[torch.device]):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
 
     def setup(self, model):
+        # model needs to be moved to the device before it is wrapped
+        model.to(self.root_device)
         self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
 
     def reduce(self, output, *args, **kwargs):
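
Illustration (a minimal sketch, not part of the patch itself): the change moves model.to(self.root_device) out of GPUAccelerator.setup and into the DP plugin's setup, so the move to the root device happens immediately before the DataParallel wrap. The snippet below shows that ordering with made-up names (ToyModel, devices) and assumes two visible CUDA devices.

    import torch
    from torch import nn
    from torch.nn.parallel import DataParallel

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)

        def forward(self, x):
            return self.layer(x)

    devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
    model = ToyModel()
    # move the model to the first (root) device before wrapping,
    # mirroring the new DP plugin setup in the diff above
    model.to(devices[0])
    wrapped = DataParallel(model, device_ids=devices)
    out = wrapped(torch.randn(8, 4, device=devices[0]))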