
Commit b285087

Author: Sean Naren
Merge c53873d into 2708c39 (2 parents: 2708c39 + c53873d)

File tree: 3 files changed (+50, -6 lines)


pytorch_lightning/plugins/training_type/ddp.py (11 additions, 2 deletions)

@@ -242,8 +242,9 @@ def pre_dispatch(self):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
 
-        # move the model to the correct device
-        self.model_to_device()
+        if self.call_move_to_device_hook_in_pre_dispatch:
+            # move the model to the correct device
+            self.model_to_device()
 
         self.configure_ddp()
 
@@ -302,3 +303,11 @@ def predict(self, *args, **kwargs):
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True
+
+    @property
+    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
+        """
+        Call the ``model_to_device`` function within pre_dispatch if this is set to True.
+        Useful when a plugin would like to call ``model_to_device`` at another time, or to skip the call.
+        """
+        return True
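The property added here is an extension seam: a subclass can override it to defer or skip device placement during ``pre_dispatch``. A minimal sketch of how a custom plugin might use it (the ``DeferredDevicePlugin`` name and its override are hypothetical, not part of this commit):

from pytorch_lightning.plugins import DDPPlugin


class DeferredDevicePlugin(DDPPlugin):
    """Hypothetical plugin that handles device placement itself, outside pre_dispatch."""

    @property
    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
        # Returning False skips the model_to_device() call in pre_dispatch;
        # this plugin is assumed to move the model at a later point itself.
        return False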

pytorch_lightning/plugins/training_type/ddp_spawn.py (12 additions, 3 deletions)

@@ -54,7 +54,7 @@ def __init__(
         self.sync_batchnorm = sync_batchnorm
         self._ddp_kwargs = kwargs
         self.dist = LightningDistributed()
-        self.num_processes = len(parallel_devices)
+        self.num_processes = len(parallel_devices) if parallel_devices is not None else None
         self.node_rank = 0
         self.mp_queue = None
 
@@ -146,8 +146,9 @@ def new_process(self, process_idx, trainer, mp_queue):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
 
-        # move the model to the correct device
-        self.model_to_device()
+        if self.call_move_to_device_hook_in_pre_dispatch:
+            # move the model to the correct device
+            self.model_to_device()
 
         self.configure_ddp()
 
@@ -285,3 +286,11 @@ def predict(self, *args, **kwargs):
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
            self.model.require_backward_grad_sync = True
+
+    @property
+    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
+        """
+        Call the ``model_to_device`` function within pre_dispatch if this is set to True.
+        Useful when a plugin would like to call ``model_to_device`` at another time, or to skip the call.
+        """
+        return True
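Note the ``num_processes`` guard in ``__init__``: the spawn plugin can now be constructed before its devices are known. A small illustration of the behaviour this change permits (a sketch, assuming ``parallel_devices`` may be passed as ``None`` at construction time):

from pytorch_lightning.plugins import DDPSpawnPlugin

# Before this change, len(None) raised a TypeError at construction time.
# Now num_processes stays None until parallel_devices is filled in later.
plugin = DDPSpawnPlugin(parallel_devices=None)
assert plugin.num_processes is None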

tests/accelerators/test_ddp.py (27 additions, 1 deletion)

@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
 import torch
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
 from tests.accelerators import ddp_model, DDPLauncher
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -91,7 +93,6 @@ def test_torch_distributed_backend_env_variables(tmpdir):
     _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"}
     with patch.dict(os.environ, _environ), \
             patch('torch.cuda.device_count', return_value=2):
-
         with pytest.raises(ValueError, match="Invalid backend: 'undefined'"):
             model = BoringModel()
             trainer = Trainer(
 
@@ -102,3 +103,28 @@
             logger=False,
         )
         trainer.fit(model)
+
+
+@pytest.mark.parametrize('move_to_device_pre_dispatch_enabled', [True, False])
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.model_to_device')
+def test_move_to_device_in_pre_dispatch(mock_model_to_device, tmpdir, move_to_device_pre_dispatch_enabled):
+    """
+    Test that when ``call_move_to_device_hook_in_pre_dispatch`` is disabled, the model is not moved
+    to the device until later in training.
+    """
+
+    with mock.patch(
+        'pytorch_lightning.plugins.DDPPlugin.call_move_to_device_hook_in_pre_dispatch',
+        move_to_device_pre_dispatch_enabled
+    ):
+        model = BoringModel()
+        trainer = Trainer(
+            default_root_dir=tmpdir, fast_dev_run=True, accelerator='ddp', plugins=DDPPlugin(), num_processes=1
+        )
+        trainer.fit(model)
+
+        # Check whether the mocked hook was called. Since we are on CPU, model_to_device is a no-op anyway.
+        if move_to_device_pre_dispatch_enabled:
+            mock_model_to_device.assert_called()
+        else:
+            mock_model_to_device.assert_not_called()
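Note how the test toggles the hook: since ``call_move_to_device_hook_in_pre_dispatch`` is defined on the class, ``mock.patch`` can swap the property out for a plain boolean for the duration of the ``with`` block. The same pattern works in isolation (a minimal sketch, independent of the Trainer; assumes ``DDPPlugin`` accepts an empty ``parallel_devices`` list):

from unittest import mock

from pytorch_lightning.plugins import DDPPlugin

# Patching the class attribute replaces the property object itself,
# so every instance sees the plain value until the patch is undone.
with mock.patch('pytorch_lightning.plugins.DDPPlugin.call_move_to_device_hook_in_pre_dispatch', False):
    plugin = DDPPlugin(parallel_devices=[])
    assert plugin.call_move_to_device_hook_in_pre_dispatch is False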
