1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import os
15+ from unittest import mock
1516from unittest .mock import patch
1617
1718import pytest
1819import torch
1920
2021from pytorch_lightning import Trainer
22+ from pytorch_lightning .plugins import DDPPlugin , DDPSpawnPlugin
2123from tests .accelerators import ddp_model , DDPLauncher
2224from tests .helpers .boring_model import BoringModel
2325from tests .helpers .runif import RunIf
@@ -91,7 +93,6 @@ def test_torch_distributed_backend_env_variables(tmpdir):
9193 _environ = {"PL_TORCH_DISTRIBUTED_BACKEND" : "undefined" , "CUDA_VISIBLE_DEVICES" : "0,1" , "WORLD_SIZE" : "2" }
9294 with patch .dict (os .environ , _environ ), \
9395 patch ('torch.cuda.device_count' , return_value = 2 ):
94-
9596 with pytest .raises (ValueError , match = "Invalid backend: 'undefined'" ):
9697 model = BoringModel ()
9798 trainer = Trainer (
@@ -102,3 +103,28 @@ def test_torch_distributed_backend_env_variables(tmpdir):
102103 logger = False ,
103104 )
104105 trainer .fit (model )
106+
107+
108+ @pytest .mark .parametrize ('move_to_device_pre_dispatch_enabled' , [True , False ])
109+ @mock .patch ('pytorch_lightning.plugins.DDPPlugin.model_to_device' )
110+ def test_move_to_device_in_pre_dispatch (mock_model_to_device , tmpdir , move_to_device_pre_dispatch_enabled ):
111+ """
112+ Test if ``call_move_to_device_hook_in_pre_dispatch`` is disabled we do not move to device till later
113+ in training.
114+ """
115+
116+ with mock .patch (
117+ f'pytorch_lightning.plugins.DDPPlugin.call_move_to_device_hook_in_pre_dispatch' ,
118+ move_to_device_pre_dispatch_enabled
119+ ):
120+ model = BoringModel ()
121+ trainer = Trainer (
122+ default_root_dir = tmpdir , fast_dev_run = True , accelerator = 'ddp' , plugins = DDPPlugin (), num_processes = 1
123+ )
124+ trainer .fit (model )
125+
126+ # Check if mocked device was called. Since we're on CPU, model_to_device does nothing anyway.
127+ if move_to_device_pre_dispatch_enabled :
128+ mock_model_to_device .assert_called ()
129+ else :
130+ mock_model_to_device .assert_not_called ()
0 commit comments