Skip to content

Commit d7ec33e

Browse files
author
SeanNaren
committed
For single device move in pre_dispatch after setup function
1 parent 4651e57 commit d7ec33e

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ def model_to_device(self) -> None:
6464

6565
self._model.to(self.root_device)
6666

67+
def pre_dispatch(self) -> None:
68+
self.model_to_device()
69+
6770
def connect(self, model: torch.nn.Module) -> torch.nn.Module:
6871
self._model = model
69-
self.model_to_device()
7072
return self.model
7173

7274
@property

pytorch_lightning/trainer/trainer.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -430,12 +430,8 @@ def fit(
430430
# ----------------------------
431431
self.call_hook("on_before_accelerator_backend_setup", model)
432432
self.accelerator.setup(self, model) # note: this sets up self.lightning_module
433-
# TODO If we call the setup hook here, we might not move our model to the correct device since setup is called
434-
# TODO On One GPU
435433
self.call_setup_hook(model)
436434

437-
438-
439435
# ----------------------------
440436
# INSPECT THE CORE LOOPS
441437
# ----------------------------

0 commit comments

Comments
 (0)