From 92b5807e4cd3bdf0c4891570b71467a98690fc60 Mon Sep 17 00:00:00 2001
From: Siyu Wang
Date: Thu, 24 Feb 2022 14:46:30 -0800
Subject: [PATCH 1/3] Check parallel_devices passed through strategy is
 consistent with accelerator flag

---
 .../trainer/connectors/accelerator_connector.py  | 12 ++++++++++--
 tests/accelerators/test_accelerator_connector.py |  9 +++++++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index f2d27a249f6f2..2adda3245777e 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -348,13 +348,21 @@ def _check_config_and_set_final_flags(
             else:
                 self._cluster_environment_flag = getattr(self._strategy_flag, "cluster_environment")
 
-            # TODO: RFC existing accel_conn doesn't handle this, should we add conflict check?
-            # eg: parallel_device is torch.device(cpu) but accelerator=gpu
             if hasattr(self._strategy_flag, "parallel_devices"):
                 if self._strategy_flag.parallel_devices:
                     if self._strategy_flag.parallel_devices[0].type == "cpu":
+                        if self._accelerator_flag and self._accelerator_flag not in ["auto", "cpu"]:
+                            raise MisconfigurationException(
+                                f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
+                                f" but accelerator set to {self._accelerator_flag}, please choose one device type"
+                            )
                         self._accelerator_flag = "cpu"
                     if self._strategy_flag.parallel_devices[0].type == "cuda":
+                        if self._accelerator_flag and self._accelerator_flag not in ["auto", "gpu"]:
+                            raise MisconfigurationException(
+                                f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
+                                f" but accelerator set to {self._accelerator_flag}, please choose one device type"
+                            )
                         self._accelerator_flag = "gpu"
 
         amp_type = amp_type if isinstance(amp_type, str) else None

diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 0e13b4af0f8d2..78376942b5278 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -933,6 +933,15 @@ def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock):
     assert trainer.gpus == 2
 
 
+@pytest.mark.parametrize(
+    ["parallel_devices", "accelerator"],
+    [([torch.device("cpu")], "gpu"), ([torch.device("cuda", i) for i in range(8)], "tpu")],
+)
+def test_parallel_devices_in_strategy_conflict_with_accelerator(parallel_devices, accelerator):
+    with pytest.raises(MisconfigurationException, match=r"parallel_devices set through"):
+        Trainer(strategy=DDPStrategy(parallel_devices=parallel_devices), accelerator=accelerator)
+
+
 def test_passing_zero_and_empty_list_to_devices_flag():
     with pytest.warns(UserWarning, match=r"switching to `cpu` accelerator"):
         Trainer(accelerator="gpu", devices=0)

From 70e226d6402895b3fac8d85378eca0585b98d5b2 Mon Sep 17 00:00:00 2001
From: Siyu Wang
Date: Thu, 24 Feb 2022 15:08:59 -0800
Subject: [PATCH 2/3] fix test

---
 tests/strategies/test_ddp_strategy.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/strategies/test_ddp_strategy.py b/tests/strategies/test_ddp_strategy.py
index 157908309f0e6..a54780bd5505c 100644
--- a/tests/strategies/test_ddp_strategy.py
+++ b/tests/strategies/test_ddp_strategy.py
@@ -115,6 +115,7 @@ def test_ddp_configure_ddp():
     # in DDPStrategy configure_ddp(), model wrapped by DistributedDataParallel
     assert isinstance(trainer.model, DistributedDataParallel)
 
+    ddp_strategy = DDPStrategy()
     trainer = Trainer(
         max_epochs=1,
         strategy=ddp_strategy,

From a12b42f8670c0f4bf901e5d802d7cfc45c6fe189 Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Fri, 25 Feb 2022 14:03:33 -0800
Subject: [PATCH 3/3] Apply suggestions from code review

Co-authored-by: ananthsub

---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 2adda3245777e..cc265f3cbf3b9 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -351,14 +351,14 @@ def _check_config_and_set_final_flags(
             if hasattr(self._strategy_flag, "parallel_devices"):
                 if self._strategy_flag.parallel_devices:
                     if self._strategy_flag.parallel_devices[0].type == "cpu":
-                        if self._accelerator_flag and self._accelerator_flag not in ["auto", "cpu"]:
+                        if self._accelerator_flag and self._accelerator_flag not in ("auto", "cpu"):
                             raise MisconfigurationException(
                                 f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                                 f" but accelerator set to {self._accelerator_flag}, please choose one device type"
                             )
                         self._accelerator_flag = "cpu"
                     if self._strategy_flag.parallel_devices[0].type == "cuda":
-                        if self._accelerator_flag and self._accelerator_flag not in ["auto", "gpu"]:
+                        if self._accelerator_flag and self._accelerator_flag not in ("auto", "gpu"):
                             raise MisconfigurationException(
                                 f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
                                 f" but accelerator set to {self._accelerator_flag}, please choose one device type"
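
A minimal usage sketch of the check these patches add (assuming pytorch_lightning with the
patches applied, and the usual import locations for Trainer, DDPStrategy, and
MisconfigurationException at that point in the codebase): passing CPU parallel_devices
through DDPStrategy while also requesting accelerator="gpu" now raises a
MisconfigurationException instead of silently overriding the accelerator flag, while "auto"
or the matching device type is still accepted.

    import torch

    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPStrategy
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        # parallel_devices says "cpu" but the accelerator flag says "gpu"; the new check in
        # _check_config_and_set_final_flags rejects this combination.
        Trainer(strategy=DDPStrategy(parallel_devices=[torch.device("cpu")]), accelerator="gpu")
    except MisconfigurationException as err:
        print(err)

    # accelerator="auto" (or "cpu") remains compatible with CPU parallel_devices; the
    # accelerator flag is resolved to "cpu" as before.
    Trainer(strategy=DDPStrategy(parallel_devices=[torch.device("cpu")]), accelerator="auto")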