diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index f2d27a249f6f2..cc265f3cbf3b9 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -348,13 +348,21 @@ def _check_config_and_set_final_flags(
                 else:
                     self._cluster_environment_flag = getattr(self._strategy_flag, "cluster_environment")
 
-            # TODO: RFC existing accel_conn doesn't handle this, should we add conflict check?
-            # eg: parallel_device is torch.device(cpu) but accelerator=gpu
             if hasattr(self._strategy_flag, "parallel_devices"):
                 if self._strategy_flag.parallel_devices:
                     if self._strategy_flag.parallel_devices[0].type == "cpu":
+                        if self._accelerator_flag and self._accelerator_flag not in ("auto", "cpu"):
+                            raise MisconfigurationException(
+                                f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
+                                f" but accelerator set to {self._accelerator_flag}, please choose one device type"
+                            )
                         self._accelerator_flag = "cpu"
                     if self._strategy_flag.parallel_devices[0].type == "cuda":
+                        if self._accelerator_flag and self._accelerator_flag not in ("auto", "gpu"):
+                            raise MisconfigurationException(
+                                f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
+                                f" but accelerator set to {self._accelerator_flag}, please choose one device type"
+                            )
                         self._accelerator_flag = "gpu"
 
         amp_type = amp_type if isinstance(amp_type, str) else None
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 0e13b4af0f8d2..78376942b5278 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -933,6 +933,15 @@ def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock):
     assert trainer.gpus == 2
 
 
+@pytest.mark.parametrize(
+    ["parallel_devices", "accelerator"],
+    [([torch.device("cpu")], "gpu"), ([torch.device("cuda", i) for i in range(8)], "tpu")],
+)
+def test_parallel_devices_in_strategy_conflict_with_accelerator(parallel_devices, accelerator):
+    with pytest.raises(MisconfigurationException, match=r"parallel_devices set through"):
+        Trainer(strategy=DDPStrategy(parallel_devices=parallel_devices), accelerator=accelerator)
+
+
 def test_passing_zero_and_empty_list_to_devices_flag():
     with pytest.warns(UserWarning, match=r"switching to `cpu` accelerator"):
         Trainer(accelerator="gpu", devices=0)
diff --git a/tests/strategies/test_ddp_strategy.py b/tests/strategies/test_ddp_strategy.py
index 157908309f0e6..a54780bd5505c 100644
--- a/tests/strategies/test_ddp_strategy.py
+++ b/tests/strategies/test_ddp_strategy.py
@@ -115,6 +115,7 @@ def test_ddp_configure_ddp():
     # in DDPStrategy configure_ddp(), model wrapped by DistributedDataParallel
     assert isinstance(trainer.model, DistributedDataParallel)
 
+    ddp_strategy = DDPStrategy()
     trainer = Trainer(
         max_epochs=1,
         strategy=ddp_strategy,