diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index f01cecac1615a..33a3cce7e3a31 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -16,7 +16,6 @@ def setup(self, trainer, model):
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         self.set_nvidia_flags()
         torch.cuda.set_device(self.root_device)
-        model.to(self.root_device)
         return super().setup(trainer, model)

     def on_train_start(self):
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index 54258a8bc1563..76b1247293113 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -27,6 +27,8 @@ def __init__(self, parallel_devices: List[torch.device]):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)

     def setup(self, model):
+        # model needs to be moved to the device before it is wrapped
+        model.to(self.root_device)
         self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)

     def reduce(self, output, *args, **kwargs):
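
For context, a minimal plain-PyTorch sketch (not part of this PR) of the ordering the DP plugin now enforces: the module must already be on the root device before nn.DataParallel wraps it, because replication to the other GPUs starts from that device on each forward pass.

import torch
import torch.nn as nn

# Minimal sketch, assuming a machine with at least two CUDA devices.
model = nn.Linear(32, 2)

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    root_device = torch.device("cuda:0")
    model = model.to(root_device)                      # move to the root device first ...
    model = nn.DataParallel(model, device_ids=[0, 1])  # ... then wrap

    x = torch.randn(8, 32, device=root_device)
    out = model(x)  # input is scattered across GPUs, outputs gathered on cuda:0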