From 021dae9b0eb85593530678039e3e9bf89dd07bde Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 21 Feb 2022 15:20:35 +0530 Subject: [PATCH 01/37] Add support for pluggable Accelerators --- pytorch_lightning/accelerators/accelerator.py | 12 ++++- pytorch_lightning/accelerators/cpu.py | 4 ++ pytorch_lightning/accelerators/gpu.py | 8 +++ pytorch_lightning/accelerators/ipu.py | 4 ++ pytorch_lightning/accelerators/tpu.py | 6 +++ .../connectors/accelerator_connector.py | 50 ++++++------------- 6 files changed, 48 insertions(+), 36 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 06f82fb8d4b96..8a8058a49d20f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union import torch @@ -55,11 +55,21 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """ raise NotImplementedError + @staticmethod + # @abstractmethod + def parse_devices(devices) -> int: + """Accelerator Parsing logic.""" + @staticmethod @abstractmethod def auto_device_count() -> int: """Get the device count when set to auto.""" + @staticmethod + @abstractmethod + def get_parallel_devices(devices: Union[List[int], str, int]) -> List[torch.device]: + """Gets parallel devices for the given Accelerator.""" + @staticmethod @abstractmethod def is_available() -> bool: diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 2fbe3bf18b079..beb8b2314b0b3 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -39,6 +39,10 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """CPU device stats aren't supported yet.""" return {} + @staticmethod + def get_parallel_devices(devices): + return [torch.device("cpu")] * devices + @staticmethod def auto_device_count() -> int: """Get the devices when set to auto.""" diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index aa8b0d56dbf63..15b4c5e83c77b 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -23,6 +23,7 @@ import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.types import _DEVICE @@ -75,6 +76,13 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return torch.cuda.memory_stats(device) return get_nvidia_gpu_stats(device) + @staticmethod + def get_parallel_devices(devices): + if isinstance(devices, int) or isinstance(devices, str): + devices = int(devices) + return [torch.device("cuda", i) for i in device_parser.parse_gpu_ids(devices)] if devices != 0 else [] + return [torch.device("cuda", i) for i in devices] + @staticmethod def auto_device_count() -> int: """Get the devices when set to auto.""" diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index 6928546cf8c50..147665841a480 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -26,6 +26,10 @@ def 
get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """IPU device stats aren't supported yet.""" return {} + @staticmethod + def get_parallel_devices(devices): + return list(range(devices)) + @staticmethod def auto_device_count() -> int: """Get the devices when set to auto.""" diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index f1f598c3f1b3c..0c9c23a62e3ef 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -43,6 +43,12 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: } return device_stats + @staticmethod + def get_parallel_devices(devices): + if isinstance(devices, int): + return list(range(devices)) + return devices + @staticmethod def auto_device_count() -> int: """Get the devices when set to auto.""" diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 8d451f97249fc..c287c6a4b4800 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -455,43 +455,23 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: if isinstance(self._accelerator_flag, Accelerator): self.accelerator: Accelerator = self._accelerator_flag - elif self._accelerator_flag == "tpu": - self.accelerator = TPUAccelerator() - self._set_devices_flag_if_auto_passed() - if isinstance(self._devices_flag, int): - self._parallel_devices = list(range(self._devices_flag)) - else: - self._parallel_devices = self._devices_flag # type: ignore[assignment] - - elif self._accelerator_flag == "ipu": - self.accelerator = IPUAccelerator() - self._set_devices_flag_if_auto_passed() - if isinstance(self._devices_flag, int): - self._parallel_devices = list(range(self._devices_flag)) - - elif self._accelerator_flag == "gpu": - self.accelerator = GPUAccelerator() - self._set_devices_flag_if_auto_passed() - if isinstance(self._devices_flag, int) or isinstance(self._devices_flag, str): - self._devices_flag = int(self._devices_flag) - self._parallel_devices = ( - [torch.device("cuda", i) for i in device_parser.parse_gpu_ids(self._devices_flag)] # type: ignore - if self._devices_flag != 0 - else [] + else: + ACCELERATORS = { + "cpu": CPUAccelerator, + "gpu": GPUAccelerator, + "tpu": TPUAccelerator, + "ipu": IPUAccelerator, + } + self._accelerator_flag = self._accelerator_flag.lower() + if self._accelerator_flag not in ACCELERATORS: + raise MisconfigurationException( + "When passing string value for the `accelerator` argument of `Trainer`," + f" it can only be one of {list(ACCELERATORS.keys())}." ) - else: - self._parallel_devices = [torch.device("cuda", i) for i in self._devices_flag] # type: ignore - - elif self._accelerator_flag == "cpu": - self.accelerator = CPUAccelerator() + accelerator_class = ACCELERATORS[self._accelerator_flag] + self.accelerator = accelerator_class() self._set_devices_flag_if_auto_passed() - if isinstance(self._devices_flag, int): - self._parallel_devices = [torch.device("cpu")] * self._devices_flag - else: - rank_zero_warn( - "The flag `devices` must be an int with `accelerator='cpu'`," - f" got `devices={self._devices_flag}` instead." 
- ) + self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) self._gpus = self._devices_flag if not self._gpus else self._gpus self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores From 7320c662c59c1624bb010b467efbba83bfb61520 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 21 Feb 2022 17:34:24 +0530 Subject: [PATCH 02/37] Add parse_devices method to Accelerators --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/accelerators/cpu.py | 5 +++++ pytorch_lightning/accelerators/gpu.py | 8 +++++--- pytorch_lightning/accelerators/ipu.py | 5 +++++ pytorch_lightning/accelerators/tpu.py | 6 ++++++ 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 8a8058a49d20f..1f0f2ff23cbff 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -56,7 +56,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: raise NotImplementedError @staticmethod - # @abstractmethod + @abstractmethod def parse_devices(devices) -> int: """Accelerator Parsing logic.""" diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index beb8b2314b0b3..adbaf71437286 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -39,6 +39,11 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """CPU device stats aren't supported yet.""" return {} + @staticmethod + def parse_devices(devices) -> int: + """Accelerator Parsing logic.""" + return devices + @staticmethod def get_parallel_devices(devices): return [torch.device("cpu")] * devices diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 15b4c5e83c77b..c5e3f886c7b3b 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -76,11 +76,13 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return torch.cuda.memory_stats(device) return get_nvidia_gpu_stats(device) + @staticmethod + def parse_devices(devices) -> int: + """Accelerator Parsing logic.""" + return device_parser.parse_gpu_ids(devices) + @staticmethod def get_parallel_devices(devices): - if isinstance(devices, int) or isinstance(devices, str): - devices = int(devices) - return [torch.device("cuda", i) for i in device_parser.parse_gpu_ids(devices)] if devices != 0 else [] return [torch.device("cuda", i) for i in devices] @staticmethod diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index 147665841a480..cc625d06a2e8f 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -26,6 +26,11 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """IPU device stats aren't supported yet.""" return {} + @staticmethod + def parse_devices(devices) -> int: + """Accelerator Parsing logic.""" + return devices + @staticmethod def get_parallel_devices(devices): return list(range(devices)) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 0c9c23a62e3ef..6c749940fff4b 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -16,6 +16,7 @@ import torch from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities import device_parser from 
pytorch_lightning.utilities.imports import _TPU_AVAILABLE, _XLA_AVAILABLE if _XLA_AVAILABLE: @@ -43,6 +44,11 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: } return device_stats + @staticmethod + def parse_devices(devices) -> int: + """Accelerator Parsing logic.""" + return device_parser.parse_tpu_cores(devices) + @staticmethod def get_parallel_devices(devices): if isinstance(devices, int): From ac2f5c051fa8201e03308b46b2a76272e922410d Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 21 Feb 2022 17:43:51 +0530 Subject: [PATCH 03/37] Refactor device parsing logic --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index c287c6a4b4800..dceab972e4ad1 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -403,8 +403,6 @@ def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag( """Sets the `devices_flag` and `accelerator_flag` based on num_processes, gpus, ipus, tpu_cores.""" self._gpus: Optional[Union[List[int], str, int]] = gpus self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores - gpus = device_parser.parse_gpu_ids(gpus) - tpu_cores = device_parser.parse_tpu_cores(tpu_cores) deprecated_devices_specific_flag = num_processes or gpus or ipus or tpu_cores if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in (0, "0"): if devices: @@ -470,6 +468,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: ) accelerator_class = ACCELERATORS[self._accelerator_flag] self.accelerator = accelerator_class() + self._devices_flag = self.accelerator.parse_devices(self._devices_flag) self._set_devices_flag_if_auto_passed() self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) From cfece86b5e695fe7d6ab368e8053f45392060c0b Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 21 Feb 2022 17:59:08 +0530 Subject: [PATCH 04/37] Fix passing Accelerator instances --- .../trainer/connectors/accelerator_connector.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index dceab972e4ad1..35640e1078406 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -468,9 +468,10 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: ) accelerator_class = ACCELERATORS[self._accelerator_flag] self.accelerator = accelerator_class() - self._devices_flag = self.accelerator.parse_devices(self._devices_flag) - self._set_devices_flag_if_auto_passed() - self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) + + self._devices_flag = self.accelerator.parse_devices(self._devices_flag) + self._set_devices_flag_if_auto_passed() + self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) self._gpus = self._devices_flag if not self._gpus else self._gpus self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores From c83717ef51d5c493a1ac4969a588bb2facc10712 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 21 Feb 2022 18:17:04 +0530 Subject: [PATCH 05/37] Fix devices auto --- 
pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 35640e1078406..a1eaebb888c85 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -469,8 +469,8 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: accelerator_class = ACCELERATORS[self._accelerator_flag] self.accelerator = accelerator_class() - self._devices_flag = self.accelerator.parse_devices(self._devices_flag) self._set_devices_flag_if_auto_passed() + self._devices_flag = self.accelerator.parse_devices(self._devices_flag) self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) self._gpus = self._devices_flag if not self._gpus else self._gpus From e82feaf53420b836910195a5cd4dc228434b1da4 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 21 Feb 2022 18:39:57 +0530 Subject: [PATCH 06/37] Add tests --- tests/accelerators/test_common.py | 52 ++++++++++++++++++++++++++++++- tests/conftest.py | 1 + 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index ecdcc743ea822..7c3904507b75e 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -18,7 +18,8 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator +from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator +from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.utilities.seed import seed_everything from tests.accelerators.test_dp import CustomClassificationModelDP from tests.helpers.boring_model import BoringModel @@ -85,3 +86,52 @@ def test_auto_device_count(device_count_mock): assert GPUAccelerator.auto_device_count() == 2 assert TPUAccelerator.auto_device_count() == 8 assert IPUAccelerator.auto_device_count() == 4 + + +def test_pluggable_accelerator(tmpdir): + class TestAccelerator(Accelerator): + @staticmethod + def parse_devices(devices) -> int: + """Accelerator Parsing logic.""" + return devices + + @staticmethod + def get_parallel_devices(devices): + return [torch.device("cpu")] * devices + + @staticmethod + def auto_device_count() -> int: + """Get the devices when set to auto.""" + return 1 + + @staticmethod + def is_available() -> bool: + return True + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + accelerator=TestAccelerator(), + devices=2, + strategy="ddp", + ) + trainer.fit(model) + + assert isinstance(trainer.accelerator, TestAccelerator) + assert trainer._accelerator_connector.parallel_devices == [torch.device("cpu")] * 2 + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + strategy=DDPStrategy(TestAccelerator()), + devices=2, + ) + trainer.fit(model) + + assert isinstance(trainer.accelerator, TestAccelerator) + assert trainer._accelerator_connector.parallel_devices == [torch.device("cpu")] * 2 diff --git a/tests/conftest.py b/tests/conftest.py index 8ad7faa3cd769..4440548fa09bc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,6 +77,7 @@ def 
restore_env_variables(): "XRT_HOST_WORLD_SIZE", "XRT_SHARD_ORDINAL", "XRT_SHARD_LOCAL_ORDINAL", + "TF_CPP_MIN_LOG_LEVEL", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" From 24fa9a4868f6902f81c15b377ef79bed44b6a212 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 21 Feb 2022 18:44:25 +0530 Subject: [PATCH 07/37] Update changelog --- CHANGELOG.md | 3 +++ pytorch_lightning/accelerators/accelerator.py | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c88e4fa0e9564..7129698289352 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -123,6 +123,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added utility functions for moving optimizers to devices ([#11758](https://github.com/PyTorchLightning/pytorch-lightning/pull/11758)) +- Added support for pluggable Accelerators ([#12030](https://github.com/PyTorchLightning/pytorch-lightning/pull/12030)) + + ### Changed - Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 1f0f2ff23cbff..a52445e7d3a14 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -62,13 +62,13 @@ def parse_devices(devices) -> int: @staticmethod @abstractmethod - def auto_device_count() -> int: - """Get the device count when set to auto.""" + def get_parallel_devices(devices: Union[List[int], str, int]) -> List[torch.device]: + """Gets parallel devices for the given Accelerator.""" @staticmethod @abstractmethod - def get_parallel_devices(devices: Union[List[int], str, int]) -> List[torch.device]: - """Gets parallel devices for the given Accelerator.""" + def auto_device_count() -> int: + """Get the device count when set to auto.""" @staticmethod @abstractmethod From ae08e2ab0ccd76b407f0d80af27ea59b1d34d349 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 21 Feb 2022 19:02:45 +0530 Subject: [PATCH 08/37] Update accelerators --- pytorch_lightning/accelerators/accelerator.py | 6 +++--- pytorch_lightning/accelerators/cpu.py | 5 +++-- pytorch_lightning/accelerators/gpu.py | 5 +++-- pytorch_lightning/accelerators/ipu.py | 7 ++++--- pytorch_lightning/accelerators/tpu.py | 7 ++++--- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index a52445e7d3a14..7669ee3b5fbbf 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch @@ -57,12 +57,12 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @staticmethod @abstractmethod - def parse_devices(devices) -> int: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]: """Accelerator Parsing logic.""" @staticmethod @abstractmethod - def get_parallel_devices(devices: Union[List[int], str, int]) -> List[torch.device]: + def get_parallel_devices(devices: Union[List[int], str, int]) -> Union[List[torch.device], List[int]]: """Gets parallel devices for the given Accelerator.""" @staticmethod diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index adbaf71437286..913ee1fc703ca 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -40,12 +40,13 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return {} @staticmethod - def parse_devices(devices) -> int: + def parse_devices(devices: int | str | list[int]) -> int | list[int] | None: """Accelerator Parsing logic.""" return devices @staticmethod - def get_parallel_devices(devices): + def get_parallel_devices(devices: list[int] | str | int) -> list[torch.device] | list[int]: + """Gets parallel devices for the given Accelerator.""" return [torch.device("cpu")] * devices @staticmethod diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index c5e3f886c7b3b..954598e8dd08c 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -77,12 +77,13 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return get_nvidia_gpu_stats(device) @staticmethod - def parse_devices(devices) -> int: + def parse_devices(devices: int | str | list[int]) -> int | list[int] | None: """Accelerator Parsing logic.""" return device_parser.parse_gpu_ids(devices) @staticmethod - def get_parallel_devices(devices): + def get_parallel_devices(devices: list[int] | str | int) -> list[torch.device] | list[int]: + """Gets parallel devices for the given Accelerator.""" return [torch.device("cuda", i) for i in devices] @staticmethod diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index cc625d06a2e8f..64f819f7c48bb 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Union +from typing import Any, Dict, List, Optional, Union import torch @@ -27,12 +27,13 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: return {} @staticmethod - def parse_devices(devices) -> int: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]: """Accelerator Parsing logic.""" return devices @staticmethod - def get_parallel_devices(devices): + def get_parallel_devices(devices: Union[List[int], str, int]) -> Union[List[torch.device], List[int]]: + """Gets parallel devices for the given Accelerator.""" return list(range(devices)) @staticmethod diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 6c749940fff4b..57abec431fb4a 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Union +from typing import Any, Dict, List, Optional, Union import torch @@ -45,12 +45,13 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: return device_stats @staticmethod - def parse_devices(devices) -> int: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]: """Accelerator Parsing logic.""" return device_parser.parse_tpu_cores(devices) @staticmethod - def get_parallel_devices(devices): + def get_parallel_devices(devices: Union[List[int], str, int]) -> Union[List[torch.device], List[int]]: + """Gets parallel devices for the given Accelerator.""" if isinstance(devices, int): return list(range(devices)) return devices From a6c8bc40ec2b44c198454a508fc3d4ce74ea48d1 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 21 Feb 2022 19:12:04 +0530 Subject: [PATCH 09/37] Update accelerator doc --- docs/source/extensions/accelerator.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/extensions/accelerator.rst b/docs/source/extensions/accelerator.rst index 03c547f49a0f8..9a8fb13512dd7 100644 --- a/docs/source/extensions/accelerator.rst +++ b/docs/source/extensions/accelerator.rst @@ -28,8 +28,8 @@ One to handle differences from the training routine and one to handle different accelerator = GPUAccelerator() precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda") - training_type_plugin = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin) - trainer = Trainer(strategy=training_type_plugin) + training_strategy = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin) + trainer = Trainer(strategy=training_strategy, devices=2) We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new From 19f2e8d74a5cab9f0c80943d783c317393f0802f Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 22 Feb 2022 12:32:27 +0530 Subject: [PATCH 10/37] Fix acc connector tests --- pytorch_lightning/accelerators/cpu.py | 8 +++++++- .../trainer/connectors/accelerator_connector.py | 5 +++-- tests/accelerators/test_accelerator_connector.py | 14 ++++++++++---- tests/accelerators/test_common.py | 5 ++++- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 913ee1fc703ca..bb4b461bb199e 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ 
b/pytorch_lightning/accelerators/cpu.py @@ -19,6 +19,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.types import _DEVICE @@ -47,7 +48,12 @@ def parse_devices(devices: int | str | list[int]) -> int | list[int] | None: @staticmethod def get_parallel_devices(devices: list[int] | str | int) -> list[torch.device] | list[int]: """Gets parallel devices for the given Accelerator.""" - return [torch.device("cpu")] * devices + if isinstance(devices, int): + return [torch.device("cpu")] * devices + rank_zero_warn( + "The flag `devices` must be an int with `accelerator='cpu'`," f" got `devices={devices}` instead." + ) + return [] @staticmethod def auto_device_count() -> int: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index f6bb6c42ee892..6a3d6b79b33f1 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -469,12 +469,13 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: self.accelerator = accelerator_class() self._set_devices_flag_if_auto_passed() - self._devices_flag = self.accelerator.parse_devices(self._devices_flag) - self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) self._gpus = self._devices_flag if not self._gpus else self._gpus self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores + self._devices_flag = self.accelerator.parse_devices(self._devices_flag) + self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) + def _set_devices_flag_if_auto_passed(self) -> None: if self._devices_flag == "auto" or not self._devices_flag: self._devices_flag = self.accelerator.auto_device_count() diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 088c5ecf4b731..0dd3e602ac549 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -341,6 +341,15 @@ def creates_processes_externally(self) -> bool: @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_custom_accelerator(device_count_mock, setup_distributed_mock): class Accel(Accelerator): + @staticmethod + def parse_devices(devices) -> int: + """Accelerator Parsing logic.""" + return devices + + @staticmethod + def get_parallel_devices(devices): + return [torch.device("cpu")] * devices + @staticmethod def auto_device_count() -> int: return 1 @@ -889,12 +898,9 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock assert trainer.strategy.local_rank == 0 +@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8) def test_unsupported_tpu_choice(monkeypatch): - import pytorch_lightning.utilities.imports as imports - from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector - monkeypatch.setattr(imports, "_XLA_AVAILABLE", True) - monkeypatch.setattr(AcceleratorConnector, "has_tpu", True) with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"): Trainer(accelerator="tpu", precision=64) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py 
index fb42b91c6c2bc..03f3b7f56f948 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -13,8 +13,11 @@ # limitations under the License. from unittest import mock +import torch + from pytorch_lightning import Trainer -from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator +from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator +from pytorch_lightning.strategies import DDPStrategy from tests.helpers.boring_model import BoringModel From b373fa6c04e0cda8672ad5866358c8555fc98156 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Wed, 23 Feb 2022 11:40:59 +0530 Subject: [PATCH 11/37] Fix parallel devices being passed to strategy --- .../trainer/connectors/accelerator_connector.py | 7 ++++--- tests/accelerators/test_tpu.py | 10 ++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 6a3d6b79b33f1..0d1501d04f5b6 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -162,6 +162,7 @@ def __init__( self._precision_flag: Optional[Union[int, str]] = None self._precision_plugin_flag: Optional[PrecisionPlugin] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None + self._parallel_devices: List[Union[int, torch.device]] = [] self.checkpoint_io: Optional[CheckpointIO] = None self._amp_type_flag: Optional[LightningEnum] = None self._amp_level_flag: Optional[str] = amp_level @@ -355,6 +356,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cpu" if self._strategy_flag.parallel_devices[0].type == "cuda": self._accelerator_flag = "gpu" + self._parallel_devices = self._strategy_flag.parallel_devices amp_type = amp_type if isinstance(amp_type, str) else None self._amp_type_flag = AMPType.from_str(amp_type) @@ -448,8 +450,6 @@ def _choose_accelerator(self) -> str: def _set_parallel_devices_and_init_accelerator(self) -> None: # TODO add device availability check - self._parallel_devices: List[Union[int, torch.device]] = [] - if isinstance(self._accelerator_flag, Accelerator): self.accelerator: Accelerator = self._accelerator_flag else: @@ -474,7 +474,8 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores self._devices_flag = self.accelerator.parse_devices(self._devices_flag) - self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) + if not self._parallel_devices: + self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) def _set_devices_flag_if_auto_passed(self) -> None: if self._devices_flag == "auto" or not self._devices_flag: diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index d8f99ec4dcedb..a065fb6e5aae3 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -292,28 +292,30 @@ def forward(self, x): assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight)) +@RunIf(tpu=True) def test_tpu_invalid_raises(): strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): - Trainer(strategy=strategy) + Trainer(strategy=strategy, 
accelerator="tpu", devices=8) strategy = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"): - Trainer(strategy=strategy) + Trainer(strategy=strategy, accelerator="tpu", devices=8) +@RunIf(tpu=True) def test_tpu_invalid_raises_set_precision_with_strategy(): accelerator = TPUAccelerator() strategy = TPUSpawnStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): - Trainer(strategy=strategy) + Trainer(strategy=strategy, accelerator="tpu", devices=8) accelerator = TPUAccelerator() strategy = DDPStrategy(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin()) with pytest.raises( ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy" ): - Trainer(strategy=strategy) + Trainer(strategy=strategy, accelerator="tpu", devices=8) @RunIf(tpu=True) From af76a72292c8a0d6409505054d4c2203a0353582 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Wed, 23 Feb 2022 12:06:32 +0530 Subject: [PATCH 12/37] Fix gpu test --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 0d1501d04f5b6..625db1f38c3f2 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -383,7 +383,7 @@ def _check_device_config_and_set_final_flags( devices, num_processes, gpus, ipus, tpu_cores ) - if self._devices_flag in ([], 0, "0", "0,"): + if self._devices_flag in ([], 0, "0"): rank_zero_warn(f"You passed `devices={devices}`, switching to `cpu` accelerator") self._accelerator_flag = "cpu" From 0f4f387683bfcc9065ecd4df2fc91678503930f3 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Wed, 23 Feb 2022 13:58:35 +0530 Subject: [PATCH 13/37] Update tests --- pytorch_lightning/accelerators/cpu.py | 4 +--- tests/accelerators/test_accelerator_connector.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index bb4b461bb199e..45f510ba136b0 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -50,9 +50,7 @@ def get_parallel_devices(devices: list[int] | str | int) -> list[torch.device] | """Gets parallel devices for the given Accelerator.""" if isinstance(devices, int): return [torch.device("cpu")] * devices - rank_zero_warn( - "The flag `devices` must be an int with `accelerator='cpu'`," f" got `devices={devices}` instead." 
- ) + rank_zero_warn(f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices}` instead.") return [] @staticmethod diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index fcf3c2ab9bfe9..3bb97aa0be58c 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -422,10 +422,17 @@ def test_ipython_incompatible_backend_error(_, monkeypatch): Trainer(strategy="dp") -@pytest.mark.parametrize("trainer_kwargs", [{}, dict(strategy="dp", accelerator="gpu"), dict(accelerator="tpu")]) -def test_ipython_compatible_backend(trainer_kwargs, monkeypatch): +@mock.patch("torch.cuda.device_count", return_value=2) +def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch): + monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) + trainer = Trainer(strategy="dp", accelerator="gpu") + assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible + + +@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8) +def test_ipython_compatible_strategy_tpu(_, monkeypatch): monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) - trainer = Trainer(**trainer_kwargs) + trainer = Trainer(accelerator="tpu") assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible From 0fc482b52f57e8df43e862724c0ad575448b27e7 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 24 Feb 2022 14:48:59 +0530 Subject: [PATCH 14/37] Update tests --- .../trainer/connectors/accelerator_connector.py | 2 +- tests/accelerators/test_tpu.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 9518608230774..61cfdbcc51113 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -384,7 +384,7 @@ def _check_device_config_and_set_final_flags( devices, num_processes, gpus, ipus, tpu_cores ) - if self._devices_flag in ([], 0, "0"): + if self._devices_flag in ([], 0, "0", "0,"): rank_zero_warn(f"You passed `devices={devices}`, switching to `cpu` accelerator") self._accelerator_flag = "cpu" diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index d8af00ca44c44..f7b031105278e 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -270,11 +270,11 @@ def forward(self, x): def test_tpu_invalid_raises(): strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): - Trainer(strategy=strategy, accelerator="tpu", devices=8) + Trainer(strategy=strategy, devices=8) strategy = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"): - Trainer(strategy=strategy, accelerator="tpu", devices=8) + Trainer(strategy=strategy, devices=8) @RunIf(tpu=True) @@ -282,14 +282,14 @@ def test_tpu_invalid_raises_set_precision_with_strategy(): accelerator = TPUAccelerator() strategy = TPUSpawnStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a 
`TPUPrecisionPlugin`"): - Trainer(strategy=strategy, accelerator="tpu", devices=8) + Trainer(strategy=strategy, devices=8) accelerator = TPUAccelerator() strategy = DDPStrategy(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin()) with pytest.raises( ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy" ): - Trainer(strategy=strategy, accelerator="tpu", devices=8) + Trainer(strategy=strategy, devices=8) @RunIf(tpu=True) From 10c3d4f4eda85822e854b9c5ceedffc4b8ce9d08 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 24 Feb 2022 15:04:40 +0530 Subject: [PATCH 15/37] Update tests --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 61cfdbcc51113..7858fad71743e 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -406,7 +406,7 @@ def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag( self._gpus: Optional[Union[List[int], str, int]] = gpus self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores deprecated_devices_specific_flag = num_processes or gpus or ipus or tpu_cores - if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in (0, "0"): + if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0", "0,"): if devices: # TODO: @awaelchli improve error message rank_zero_warn( From b97c88cf9109890e8a86d735488a5c3d2e83156f Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 24 Feb 2022 22:35:49 +0530 Subject: [PATCH 16/37] Update --- pytorch_lightning/accelerators/accelerator.py | 6 +++--- pytorch_lightning/accelerators/cpu.py | 4 ++-- pytorch_lightning/accelerators/gpu.py | 4 ++-- pytorch_lightning/accelerators/ipu.py | 6 +++--- pytorch_lightning/accelerators/tpu.py | 6 +++--- .../trainer/connectors/accelerator_connector.py | 4 ++-- tests/trainer/test_trainer_cli.py | 11 +++++++---- 7 files changed, 22 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 7669ee3b5fbbf..610b8f9668af8 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Union import torch @@ -57,12 +57,12 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @staticmethod @abstractmethod - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]: + def parse_devices(devices: Any) -> Any: """Accelerator Parsing logic.""" @staticmethod @abstractmethod - def get_parallel_devices(devices: Union[List[int], str, int]) -> Union[List[torch.device], List[int]]: + def get_parallel_devices(devices: Any) -> Any: """Gets parallel devices for the given Accelerator.""" @staticmethod diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 45f510ba136b0..f3ddde8303c29 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -41,12 +41,12 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return {} @staticmethod - def parse_devices(devices: int | str | list[int]) -> int | list[int] | None: + def parse_devices(devices: int | str | list[int]) -> int | list[int]: """Accelerator Parsing logic.""" return devices @staticmethod - def get_parallel_devices(devices: list[int] | str | int) -> list[torch.device] | list[int]: + def get_parallel_devices(devices: list[int] | str | int) -> list[torch.device]: """Gets parallel devices for the given Accelerator.""" if isinstance(devices, int): return [torch.device("cpu")] * devices diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 954598e8dd08c..10daf0764b5c1 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -77,12 +77,12 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return get_nvidia_gpu_stats(device) @staticmethod - def parse_devices(devices: int | str | list[int]) -> int | list[int] | None: + def parse_devices(devices: int | str | list[int]) -> list[int]: """Accelerator Parsing logic.""" return device_parser.parse_gpu_ids(devices) @staticmethod - def get_parallel_devices(devices: list[int] | str | int) -> list[torch.device] | list[int]: + def get_parallel_devices(devices: list[int]) -> list[torch.device]: """Gets parallel devices for the given Accelerator.""" return [torch.device("cuda", i) for i in devices] diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index 64f819f7c48bb..7c0b62ab72b3f 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union import torch @@ -27,12 +27,12 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: return {} @staticmethod - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]: + def parse_devices(devices: int) -> int: """Accelerator Parsing logic.""" return devices @staticmethod - def get_parallel_devices(devices: Union[List[int], str, int]) -> Union[List[torch.device], List[int]]: + def get_parallel_devices(devices: int) -> List[int]: """Gets parallel devices for the given Accelerator.""" return list(range(devices)) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 57abec431fb4a..a6e58dd293c42 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union import torch @@ -45,12 +45,12 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: return device_stats @staticmethod - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]: + def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]: """Accelerator Parsing logic.""" return device_parser.parse_tpu_cores(devices) @staticmethod - def get_parallel_devices(devices: Union[List[int], str, int]) -> Union[List[torch.device], List[int]]: + def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]: """Gets parallel devices for the given Accelerator.""" if isinstance(devices, int): return list(range(devices)) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 7858fad71743e..a50342534a0be 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -384,7 +384,7 @@ def _check_device_config_and_set_final_flags( devices, num_processes, gpus, ipus, tpu_cores ) - if self._devices_flag in ([], 0, "0", "0,"): + if self._devices_flag in ([], 0, "0"): rank_zero_warn(f"You passed `devices={devices}`, switching to `cpu` accelerator") self._accelerator_flag = "cpu" @@ -406,7 +406,7 @@ def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag( self._gpus: Optional[Union[List[int], str, int]] = gpus self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores deprecated_devices_specific_flag = num_processes or gpus or ipus or tpu_cores - if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0", "0,"): + if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"): if devices: # TODO: @awaelchli improve error message rank_zero_warn( diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index b5713893f769b..8a074e20ee055 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -17,11 +17,11 @@ from unittest import mock import pytest +import torch import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.utilities import argparse -from tests.helpers.runif import RunIf 
@mock.patch("argparse.ArgumentParser.parse_args") @@ -163,11 +163,14 @@ def test_argparse_args_parsing_fast_dev_run(cli_args, expected): @pytest.mark.parametrize( ["cli_args", "expected_parsed", "expected_device_ids"], - [("", None, None), ("--accelerator gpu --devices 1", "1", [0]), ("--accelerator gpu --devices 0,", "0,", None)], + [("", None, None), ("--accelerator gpu --devices 1", "1", [0]), ("--accelerator gpu --devices 0,", "0,", [0])], ) -@RunIf(min_gpus=1) -def test_argparse_args_parsing_devices(cli_args, expected_parsed, expected_device_ids): +def test_argparse_args_parsing_devices(cli_args, expected_parsed, expected_device_ids, monkeypatch): """Test multi type argument with bool.""" + + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) + cli_args = cli_args.split(" ") if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): parser = ArgumentParser(add_help=False) From 1c12097c1d9506f8c88228470712f609991a0768 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 25 Feb 2022 01:17:28 +0530 Subject: [PATCH 17/37] Fix typing --- docs/source/extensions/accelerator.rst | 1 + pytorch_lightning/accelerators/cpu.py | 2 +- pytorch_lightning/accelerators/gpu.py | 2 +- pytorch_lightning/accelerators/tpu.py | 4 ++-- pytorch_lightning/trainer/connectors/accelerator_connector.py | 4 ++-- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/extensions/accelerator.rst b/docs/source/extensions/accelerator.rst index 9a8fb13512dd7..762f2b6e57a90 100644 --- a/docs/source/extensions/accelerator.rst +++ b/docs/source/extensions/accelerator.rst @@ -20,6 +20,7 @@ Each Accelerator gets two plugins upon initialization: One to handle differences from the training routine and one to handle different precisions. .. testcode:: + :skipif: torch.cuda.device_count() < 2 from pytorch_lightning import Trainer from pytorch_lightning.accelerators import GPUAccelerator diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index f3ddde8303c29..da14dcf2116a0 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -41,7 +41,7 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return {} @staticmethod - def parse_devices(devices: int | str | list[int]) -> int | list[int]: + def parse_devices(devices: int | str | list[int]) -> int | str | list[int]: """Accelerator Parsing logic.""" return devices diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 10daf0764b5c1..000e0d2d0792e 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -77,7 +77,7 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return get_nvidia_gpu_stats(device) @staticmethod - def parse_devices(devices: int | str | list[int]) -> list[int]: + def parse_devices(devices: int | str | list[int]) -> list[int] | None: """Accelerator Parsing logic.""" return device_parser.parse_gpu_ids(devices) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index a6e58dd293c42..422865d9a98ec 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch @@ -45,7 +45,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: return device_stats @staticmethod - def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]: """Accelerator Parsing logic.""" return device_parser.parse_tpu_cores(devices) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index a50342534a0be..3585346d373f0 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -460,14 +460,14 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: "tpu": TPUAccelerator, "ipu": IPUAccelerator, } - self._accelerator_flag = self._accelerator_flag.lower() + self._accelerator_flag = self._accelerator_flag.lower() # type: ignore if self._accelerator_flag not in ACCELERATORS: raise MisconfigurationException( "When passing string value for the `accelerator` argument of `Trainer`," f" it can only be one of {list(ACCELERATORS.keys())}." ) accelerator_class = ACCELERATORS[self._accelerator_flag] - self.accelerator = accelerator_class() + self.accelerator = accelerator_class() # type: ignore self._set_devices_flag_if_auto_passed() From b75cf38d100060d11550124b2837be9d8e6d885d Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 25 Feb 2022 14:59:32 +0530 Subject: [PATCH 18/37] Update pytorch_lightning/accelerators/cpu.py Co-authored-by: Jirka Borovec --- pytorch_lightning/accelerators/cpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index da14dcf2116a0..0b5eaa28c0ed4 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -41,7 +41,7 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return {} @staticmethod - def parse_devices(devices: int | str | list[int]) -> int | str | list[int]: + def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, str, List[int]]: """Accelerator Parsing logic.""" return devices From d8ac58e93bdbee56122064cc0dad10f56edddbf6 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 25 Feb 2022 14:59:46 +0530 Subject: [PATCH 19/37] Update pytorch_lightning/accelerators/gpu.py Co-authored-by: Jirka Borovec --- pytorch_lightning/accelerators/gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 000e0d2d0792e..df5263ba0890c 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -77,7 +77,7 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return get_nvidia_gpu_stats(device) @staticmethod - def parse_devices(devices: int | str | list[int]) -> list[int] | None: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: """Accelerator Parsing logic.""" return device_parser.parse_gpu_ids(devices) From 8b21da3c88b13abb1ca294502b79de7f0cd222ca Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 25 Feb 2022 15:03:48 +0530 Subject: [PATCH 20/37] Fix typing --- pytorch_lightning/accelerators/cpu.py | 4 ++-- 
pytorch_lightning/accelerators/gpu.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 0b5eaa28c0ed4..7c7b2aca3bc09 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any +from typing import Any, List, Union import torch @@ -46,7 +46,7 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, str, List[i return devices @staticmethod - def get_parallel_devices(devices: list[int] | str | int) -> list[torch.device]: + def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: """Gets parallel devices for the given Accelerator.""" if isinstance(devices, int): return [torch.device("cpu")] * devices diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index df5263ba0890c..0a07208ff0ee8 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -17,7 +17,7 @@ import os import shutil import subprocess -from typing import Any +from typing import Any, List, Optional, Union import torch @@ -82,7 +82,7 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: return device_parser.parse_gpu_ids(devices) @staticmethod - def get_parallel_devices(devices: list[int]) -> list[torch.device]: + def get_parallel_devices(devices: List[int]) -> List[torch.device]: """Gets parallel devices for the given Accelerator.""" return [torch.device("cuda", i) for i in devices] From 30ea6783d852cdb57f23c18d9718102eb5d23861 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 25 Feb 2022 09:35:13 +0000 Subject: [PATCH 21/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/accelerators/cpu.py | 4 ++-- pytorch_lightning/accelerators/gpu.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 7c7b2aca3bc09..913ed018a2e6d 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -41,12 +41,12 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return {} @staticmethod - def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, str, List[int]]: + def parse_devices(devices: int | str | list[int]) -> int | str | list[int]: """Accelerator Parsing logic.""" return devices @staticmethod - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: int | str | list[int]) -> list[torch.device]: """Gets parallel devices for the given Accelerator.""" if isinstance(devices, int): return [torch.device("cpu")] * devices diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 0a07208ff0ee8..8faf59e6b878d 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -77,12 +77,12 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return get_nvidia_gpu_stats(device) @staticmethod - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: int | str | list[int]) -> list[int] | None: """Accelerator Parsing logic.""" return 
device_parser.parse_gpu_ids(devices) @staticmethod - def get_parallel_devices(devices: List[int]) -> List[torch.device]: + def get_parallel_devices(devices: list[int]) -> list[torch.device]: """Gets parallel devices for the given Accelerator.""" return [torch.device("cuda", i) for i in devices] From 6da57dd1432fa9a55febca9b4a5902ed9ee6bb3d Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 25 Feb 2022 16:20:09 +0530 Subject: [PATCH 22/37] Taking control over pre-commit --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/accelerators/cpu.py | 8 +++----- pytorch_lightning/accelerators/gpu.py | 8 +++----- pytorch_lightning/accelerators/ipu.py | 2 +- pytorch_lightning/accelerators/tpu.py | 2 +- 5 files changed, 9 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 610b8f9668af8..9355afdf114f6 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -63,7 +63,7 @@ def parse_devices(devices: Any) -> Any: @staticmethod @abstractmethod def get_parallel_devices(devices: Any) -> Any: - """Gets parallel devices for the given Accelerator.""" + """Gets parallel devices for the Accelerator.""" @staticmethod @abstractmethod diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 913ed018a2e6d..09635b8d5a4ba 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from typing import Any, List, Union import torch @@ -41,13 +39,13 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return {} @staticmethod - def parse_devices(devices: int | str | list[int]) -> int | str | list[int]: + def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, str, List[int]]: """Accelerator Parsing logic.""" return devices @staticmethod - def get_parallel_devices(devices: int | str | list[int]) -> list[torch.device]: - """Gets parallel devices for the given Accelerator.""" + def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + """Gets parallel devices for the Accelerator.""" if isinstance(devices, int): return [torch.device("cpu")] * devices rank_zero_warn(f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices}` instead.") diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 8faf59e6b878d..2dcc29a0d113b 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from __future__ import annotations - import logging import os import shutil @@ -77,13 +75,13 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return get_nvidia_gpu_stats(device) @staticmethod - def parse_devices(devices: int | str | list[int]) -> list[int] | None: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: """Accelerator Parsing logic.""" return device_parser.parse_gpu_ids(devices) @staticmethod - def get_parallel_devices(devices: list[int]) -> list[torch.device]: - """Gets parallel devices for the given Accelerator.""" + def get_parallel_devices(devices: List[int]) -> List[torch.device]: + """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] @staticmethod diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index 7c0b62ab72b3f..ce7118b31bd0f 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -33,7 +33,7 @@ def parse_devices(devices: int) -> int: @staticmethod def get_parallel_devices(devices: int) -> List[int]: - """Gets parallel devices for the given Accelerator.""" + """Gets parallel devices for the Accelerator.""" return list(range(devices)) @staticmethod diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 422865d9a98ec..30eb9e5734e3e 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -51,7 +51,7 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, Li @staticmethod def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]: - """Gets parallel devices for the given Accelerator.""" + """Gets parallel devices for the Accelerator.""" if isinstance(devices, int): return list(range(devices)) return devices From 9eb70eb73c0708f03a8f6b4d2c47149918a5310e Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 25 Feb 2022 16:21:35 +0530 Subject: [PATCH 23/37] Update tests/accelerators/test_common.py Co-authored-by: Rohit Gupta --- tests/accelerators/test_common.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 03f3b7f56f948..d2771b15fe8cf 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -32,8 +32,7 @@ def test_auto_device_count(device_count_mock): def test_pluggable_accelerator(tmpdir): class TestAccelerator(Accelerator): @staticmethod - def parse_devices(devices) -> int: - """Accelerator Parsing logic.""" + def parse_devices(devices): return devices @staticmethod @@ -41,12 +40,11 @@ def get_parallel_devices(devices): return [torch.device("cpu")] * devices @staticmethod - def auto_device_count() -> int: - """Get the devices when set to auto.""" + def auto_device_count(): return 1 @staticmethod - def is_available() -> bool: + def is_available(): return True model = BoringModel() From b55af6b7e815fac93ecb18734f13cf4c456b2862 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 25 Feb 2022 16:21:46 +0530 Subject: [PATCH 24/37] Update tests/accelerators/test_accelerator_connector.py Co-authored-by: Rohit Gupta --- tests/accelerators/test_accelerator_connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 
3bb97aa0be58c..ad065b4c0c9d1 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -342,8 +342,7 @@ def creates_processes_externally(self) -> bool: def test_custom_accelerator(device_count_mock, setup_distributed_mock): class Accel(Accelerator): @staticmethod - def parse_devices(devices) -> int: - """Accelerator Parsing logic.""" + def parse_devices(devices): return devices @staticmethod From 2b99d022abfda0e20fb8e63d9d428e7e90b3cc7c Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 25 Feb 2022 16:23:11 +0530 Subject: [PATCH 25/37] Address reviews --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/accelerators/cpu.py | 2 +- pytorch_lightning/accelerators/gpu.py | 2 +- pytorch_lightning/accelerators/ipu.py | 2 +- pytorch_lightning/accelerators/tpu.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 9355afdf114f6..cbd0e2309e311 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -58,7 +58,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @staticmethod @abstractmethod def parse_devices(devices: Any) -> Any: - """Accelerator Parsing logic.""" + """Accelerator device parsing logic.""" @staticmethod @abstractmethod diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 09635b8d5a4ba..7ab1ebf67c0c2 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -40,7 +40,7 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: @staticmethod def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, str, List[int]]: - """Accelerator Parsing logic.""" + """Accelerator device parsing logic.""" return devices @staticmethod diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 2dcc29a0d113b..6b767b2d29e23 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -76,7 +76,7 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: @staticmethod def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: - """Accelerator Parsing logic.""" + """Accelerator device parsing logic.""" return device_parser.parse_gpu_ids(devices) @staticmethod diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index ce7118b31bd0f..2ac1c794610d8 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -28,7 +28,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @staticmethod def parse_devices(devices: int) -> int: - """Accelerator Parsing logic.""" + """Accelerator device parsing logic.""" return devices @staticmethod diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 30eb9e5734e3e..cd84cccd8b493 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -46,7 +46,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: @staticmethod def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]: - """Accelerator Parsing logic.""" + """Accelerator device parsing logic.""" return device_parser.parse_tpu_cores(devices) @staticmethod From 
258f58786a95d37659296f7985bc7fa7e98002c3 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 25 Feb 2022 16:23:42 +0530 Subject: [PATCH 26/37] Update pytorch_lightning/trainer/connectors/accelerator_connector.py Co-authored-by: Rohit Gupta --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 3585346d373f0..5a49bec28e453 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -464,7 +464,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: if self._accelerator_flag not in ACCELERATORS: raise MisconfigurationException( "When passing string value for the `accelerator` argument of `Trainer`," - f" it can only be one of {list(ACCELERATORS.keys())}." + f" it can only be one of {list(ACCELERATORS)}." ) accelerator_class = ACCELERATORS[self._accelerator_flag] self.accelerator = accelerator_class() # type: ignore From 438769f0b6b8528b0e0e4d5ccfae976b9abce68b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Feb 2022 12:54:15 +0100 Subject: [PATCH 27/37] Mypy --- .../trainer/connectors/accelerator_connector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 5a49bec28e453..e7f7683b76957 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -460,14 +460,15 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: "tpu": TPUAccelerator, "ipu": IPUAccelerator, } - self._accelerator_flag = self._accelerator_flag.lower() # type: ignore + assert self._accelerator_flag is not None + self._accelerator_flag = self._accelerator_flag.lower() if self._accelerator_flag not in ACCELERATORS: raise MisconfigurationException( "When passing string value for the `accelerator` argument of `Trainer`," f" it can only be one of {list(ACCELERATORS)}." 
                )
             accelerator_class = ACCELERATORS[self._accelerator_flag]
-            self.accelerator = accelerator_class()  # type: ignore
+            self.accelerator = accelerator_class()  # type: ignore[abstract]
 
         self._set_devices_flag_if_auto_passed()


From 6e3584118dab7a8f3c3efc713a0d20869b6caa80 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Fri, 25 Feb 2022 17:38:56 +0530
Subject: [PATCH 28/37] Update pytorch_lightning/accelerators/cpu.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/accelerators/cpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index 7ab1ebf67c0c2..b6d5a97bc49ce 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -48,7 +48,7 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi
         """Gets parallel devices for the Accelerator."""
         if isinstance(devices, int):
             return [torch.device("cpu")] * devices
-        rank_zero_warn(f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices}` instead.")
+        rank_zero_warn(f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices!r}` instead.")
         return []
 
     @staticmethod

From 4bdcaff46ea3d7e8918f8cea4633ecf3ed30f86e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 25 Feb 2022 12:10:22 +0000
Subject: [PATCH 29/37] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/accelerators/cpu.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index b6d5a97bc49ce..30cf7c29ec959 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -48,7 +48,9 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi
         """Gets parallel devices for the Accelerator."""
         if isinstance(devices, int):
             return [torch.device("cpu")] * devices
-        rank_zero_warn(f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices!r}` instead.")
+        rank_zero_warn(
+            f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices!r}` instead."
+ ) return [] @staticmethod From 1bdd7a837222797407ba3d39cfdbdaf1d0436be6 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 25 Feb 2022 17:46:11 +0530 Subject: [PATCH 30/37] Update tests --- tests/accelerators/test_common.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index d2771b15fe8cf..9ecebbbc86281 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -18,7 +18,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator from pytorch_lightning.strategies import DDPStrategy -from tests.helpers.boring_model import BoringModel @mock.patch("torch.cuda.device_count", return_value=2) @@ -47,7 +46,6 @@ def auto_device_count(): def is_available(): return True - model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, @@ -57,7 +55,6 @@ def is_available(): devices=2, strategy="ddp", ) - trainer.fit(model) assert isinstance(trainer.accelerator, TestAccelerator) assert trainer._accelerator_connector.parallel_devices == [torch.device("cpu")] * 2 @@ -70,7 +67,6 @@ def is_available(): strategy=DDPStrategy(TestAccelerator()), devices=2, ) - trainer.fit(model) assert isinstance(trainer.accelerator, TestAccelerator) assert trainer._accelerator_connector.parallel_devices == [torch.device("cpu")] * 2 From 82fbc02e8c7a0e5b21d84b6f6ed3b8e098228960 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 25 Feb 2022 18:40:57 +0530 Subject: [PATCH 31/37] Fix typing --- pytorch_lightning/accelerators/cpu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 30cf7c29ec959..d586478619c05 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, List, Union +from typing import Any, Dict, List, Union import torch @@ -34,7 +34,7 @@ def setup_environment(self, root_device: torch.device) -> None: if root_device.type != "cpu": raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.") - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """CPU device stats aren't supported yet.""" return {} From ec845a5c1b75ee9a97fe860f546129b78440463f Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 25 Feb 2022 18:46:11 +0530 Subject: [PATCH 32/37] fix typing --- pytorch_lightning/accelerators/gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 6b767b2d29e23..dbb0463ac9120 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -43,7 +43,7 @@ def setup_environment(self, root_device: torch.device) -> None: raise MisconfigurationException(f"Device should be GPU, got {root_device} instead") torch.cuda.set_device(root_device) - def setup(self, trainer: pl.Trainer) -> None: + def setup(self, trainer: "pl.Trainer") -> None: # TODO refactor input from trainer to local_rank @four4fish self.set_nvidia_flags(trainer.local_rank) # clear cache before training From a99e4ea2c786e3cdbff9051ca08c34981f368546 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 25 Feb 2022 18:49:34 +0530 Subject: [PATCH 33/37] Fix typing --- pytorch_lightning/accelerators/gpu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index dbb0463ac9120..909991aed84f9 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -15,7 +15,7 @@ import os import shutil import subprocess -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch @@ -57,7 +57,7 @@ def set_nvidia_flags(local_rank: int) -> None: devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """Gets stats for the given GPU device. Args: From 9869db1fa2884d8c4ed801989e438c93a7d7b18b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 25 Feb 2022 18:50:32 +0530 Subject: [PATCH 34/37] fix typing --- pytorch_lightning/accelerators/gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 909991aed84f9..f9181e8802e21 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -94,7 +94,7 @@ def is_available() -> bool: return torch.cuda.device_count() > 0 -def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: +def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. 
Args: From 1fa5d1a62388a52866db0b080baa58c795d7db85 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Feb 2022 15:41:16 +0100 Subject: [PATCH 35/37] Simplify test --- tests/accelerators/test_common.py | 35 ++++++++----------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 9ecebbbc86281..473546696e1e3 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -13,8 +13,6 @@ # limitations under the License. from unittest import mock -import torch - from pytorch_lightning import Trainer from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator from pytorch_lightning.strategies import DDPStrategy @@ -28,7 +26,7 @@ def test_auto_device_count(device_count_mock): assert IPUAccelerator.auto_device_count() == 4 -def test_pluggable_accelerator(tmpdir): +def test_pluggable_accelerator(): class TestAccelerator(Accelerator): @staticmethod def parse_devices(devices): @@ -36,37 +34,22 @@ def parse_devices(devices): @staticmethod def get_parallel_devices(devices): - return [torch.device("cpu")] * devices + return ["foo"] * devices @staticmethod def auto_device_count(): - return 1 + return 3 @staticmethod def is_available(): return True - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - accelerator=TestAccelerator(), - devices=2, - strategy="ddp", - ) - + trainer = Trainer(accelerator=TestAccelerator(), devices=2, strategy="ddp") assert isinstance(trainer.accelerator, TestAccelerator) - assert trainer._accelerator_connector.parallel_devices == [torch.device("cpu")] * 2 - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - strategy=DDPStrategy(TestAccelerator()), - devices=2, - ) + assert isinstance(trainer.strategy, DDPStrategy) + assert trainer._accelerator_connector.parallel_devices == ["foo"] * 2 + trainer = Trainer(strategy=DDPStrategy(TestAccelerator()), devices="auto") assert isinstance(trainer.accelerator, TestAccelerator) - assert trainer._accelerator_connector.parallel_devices == [torch.device("cpu")] * 2 + assert isinstance(trainer.strategy, DDPStrategy) + assert trainer._accelerator_connector.parallel_devices == ["foo"] * 3 From e50d1d10bc369bf3aced06ec8babd421902adf92 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 28 Feb 2022 18:50:45 +0530 Subject: [PATCH 36/37] Fix tests --- pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b79b095feca94..8773e8cee3ee8 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -21,7 +21,7 @@ from torch.utils.data.distributed import DistributedSampler import pytorch_lightning as pl -from pytorch_lightning.accelerators import IPUAccelerator +from pytorch_lightning.accelerators.ipu import IPUAccelerator from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.trainer.states import RunningStage, TrainerFn From 190f7a8df8e76875a801f4172fd71e21f7eea474 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 28 Feb 2022 20:14:57 +0530 Subject: [PATCH 37/37] Fix tests --- 
tests/accelerators/test_accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 3c201c565e18c..24104d4073460 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -899,7 +899,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8) -def test_unsupported_tpu_choice(): +def test_unsupported_tpu_choice(mock_devices): with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"): Trainer(accelerator="tpu", precision=64)
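
Taken together, the series leaves `Accelerator` with four static hooks that a custom accelerator must supply: `parse_devices`, `get_parallel_devices`, `auto_device_count`, and `is_available`. Below is a minimal sketch of a plug-in accelerator built on that contract, mirroring the usage exercised in `tests/accelerators/test_common.py` above; the class name, device count, and returned devices are illustrative assumptions rather than API shipped by these patches.

    import torch

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import Accelerator


    class LoopbackCPUAccelerator(Accelerator):
        """Hypothetical third-party accelerator implementing the new static hooks."""

        @staticmethod
        def parse_devices(devices):
            # Accelerator device parsing logic: accept the `devices` flag unchanged.
            return devices

        @staticmethod
        def get_parallel_devices(devices):
            # One torch.device per requested device, as CPUAccelerator does.
            return [torch.device("cpu")] * devices

        @staticmethod
        def auto_device_count():
            # Device count used when the Trainer is created with `devices="auto"`.
            return 2

        @staticmethod
        def is_available():
            return True


    # Passing an Accelerator *instance* bypasses the string lookup in the
    # connector's ACCELERATORS dict; `strategy="ddp"` mirrors the test above.
    trainer = Trainer(accelerator=LoopbackCPUAccelerator(), devices=2, strategy="ddp")
    assert isinstance(trainer.accelerator, LoopbackCPUAccelerator)

Strings such as `accelerator="gpu"` still resolve through the `ACCELERATORS` dict added to the accelerator connector, so built-in and plug-in accelerators share the same device-parsing and device-count path.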