Skip to content

Commit bd2a53a

Browse files
author
SeanNaren
committed
Add additional model to device hook if any additional parameters have been set
1 parent 72097ba commit bd2a53a

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,12 @@ def model_to_device(self) -> None:
6565
self._model.to(self.root_device)
6666

6767
def pre_dispatch(self) -> None:
68+
# Ensures any additional parameters defined in setup are moved to the correct device.
6869
self.model_to_device()
6970

7071
def connect(self, model: torch.nn.Module) -> torch.nn.Module:
7172
self._model = model
73+
self.model_to_device()
7274
return self.model
7375

7476
@property

0 commit comments

Comments
 (0)