diff --git a/CHANGELOG.md b/CHANGELOG.md index 7034e9b0373f0..56ff447628c15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -137,6 +137,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `estimated_stepping_batches` property to `Trainer` ([#11599](https://github.com/PyTorchLightning/pytorch-lightning/pull/11599)) +- Added support for pluggable Accelerators ([#12030](https://github.com/PyTorchLightning/pytorch-lightning/pull/12030)) + + ### Changed - Make `benchmark` flag optional and set its value based on the deterministic flag ([#11944](https://github.com/PyTorchLightning/pytorch-lightning/pull/11944)) diff --git a/docs/source/extensions/accelerator.rst b/docs/source/extensions/accelerator.rst index 03c547f49a0f8..762f2b6e57a90 100644 --- a/docs/source/extensions/accelerator.rst +++ b/docs/source/extensions/accelerator.rst @@ -20,6 +20,7 @@ Each Accelerator gets two plugins upon initialization: One to handle differences from the training routine and one to handle different precisions. .. testcode:: + :skipif: torch.cuda.device_count() < 2 from pytorch_lightning import Trainer from pytorch_lightning.accelerators import GPUAccelerator @@ -28,8 +29,8 @@ One to handle differences from the training routine and one to handle different accelerator = GPUAccelerator() precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda") - training_type_plugin = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin) - trainer = Trainer(strategy=training_type_plugin) + training_strategy = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin) + trainer = Trainer(strategy=training_strategy, devices=2) We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 06f82fb8d4b96..cbd0e2309e311 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -55,6 +55,16 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """ raise NotImplementedError + @staticmethod + @abstractmethod + def parse_devices(devices: Any) -> Any: + """Accelerator device parsing logic.""" + + @staticmethod + @abstractmethod + def get_parallel_devices(devices: Any) -> Any: + """Gets parallel devices for the Accelerator.""" + @staticmethod @abstractmethod def auto_device_count() -> int: diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index 2fbe3bf18b079..d586478619c05 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from __future__ import annotations - -from typing import Any +from typing import Any, Dict, List, Union import torch from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.types import _DEVICE @@ -35,10 +34,25 @@ def setup_environment(self, root_device: torch.device) -> None: if root_device.type != "cpu": raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.") - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """CPU device stats aren't supported yet.""" return {} + @staticmethod + def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, str, List[int]]: + """Accelerator device parsing logic.""" + return devices + + @staticmethod + def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + """Gets parallel devices for the Accelerator.""" + if isinstance(devices, int): + return [torch.device("cpu")] * devices + rank_zero_warn( + f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices!r}` instead." + ) + return [] + @staticmethod def auto_device_count() -> int: """Get the devices when set to auto.""" diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index aa8b0d56dbf63..f9181e8802e21 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -11,18 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import logging import os import shutil import subprocess -from typing import Any +from typing import Any, Dict, List, Optional, Union import torch import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.types import _DEVICE @@ -44,7 +43,7 @@ def setup_environment(self, root_device: torch.device) -> None: raise MisconfigurationException(f"Device should be GPU, got {root_device} instead") torch.cuda.set_device(root_device) - def setup(self, trainer: pl.Trainer) -> None: + def setup(self, trainer: "pl.Trainer") -> None: # TODO refactor input from trainer to local_rank @four4fish self.set_nvidia_flags(trainer.local_rank) # clear cache before training @@ -58,7 +57,7 @@ def set_nvidia_flags(local_rank: int) -> None: devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """Gets stats for the given GPU device. 
Args: @@ -75,6 +74,16 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: return torch.cuda.memory_stats(device) return get_nvidia_gpu_stats(device) + @staticmethod + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + """Accelerator device parsing logic.""" + return device_parser.parse_gpu_ids(devices) + + @staticmethod + def get_parallel_devices(devices: List[int]) -> List[torch.device]: + """Gets parallel devices for the Accelerator.""" + return [torch.device("cuda", i) for i in devices] + @staticmethod def auto_device_count() -> int: """Get the devices when set to auto.""" @@ -85,7 +94,7 @@ def is_available() -> bool: return torch.cuda.device_count() > 0 -def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: +def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. Args: diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py index 6928546cf8c50..2ac1c794610d8 100644 --- a/pytorch_lightning/accelerators/ipu.py +++ b/pytorch_lightning/accelerators/ipu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union import torch @@ -26,6 +26,16 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """IPU device stats aren't supported yet.""" return {} + @staticmethod + def parse_devices(devices: int) -> int: + """Accelerator device parsing logic.""" + return devices + + @staticmethod + def get_parallel_devices(devices: int) -> List[int]: + """Gets parallel devices for the Accelerator.""" + return list(range(devices)) + @staticmethod def auto_device_count() -> int: """Get the devices when set to auto.""" diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index f1f598c3f1b3c..cd84cccd8b493 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Union +from typing import Any, Dict, List, Optional, Union import torch from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.imports import _TPU_AVAILABLE, _XLA_AVAILABLE if _XLA_AVAILABLE: @@ -43,6 +44,18 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: } return device_stats + @staticmethod + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]: + """Accelerator device parsing logic.""" + return device_parser.parse_tpu_cores(devices) + + @staticmethod + def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]: + """Gets parallel devices for the Accelerator.""" + if isinstance(devices, int): + return list(range(devices)) + return devices + @staticmethod def auto_device_count() -> int: """Get the devices when set to auto.""" diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 6360d093d5781..0d36ff1384f93 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -168,6 +168,7 @@ def __init__( self._precision_flag: Optional[Union[int, str]] = None self._precision_plugin_flag: Optional[PrecisionPlugin] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None + self._parallel_devices: List[Union[int, torch.device]] = [] self.checkpoint_io: Optional[CheckpointIO] = None self._amp_type_flag: Optional[LightningEnum] = None self._amp_level_flag: Optional[str] = amp_level @@ -361,6 +362,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cpu" if self._strategy_flag.parallel_devices[0].type == "cuda": self._accelerator_flag = "gpu" + self._parallel_devices = self._strategy_flag.parallel_devices amp_type = amp_type if isinstance(amp_type, str) else None self._amp_type_flag = AMPType.from_str(amp_type) @@ -387,7 +389,7 @@ def _check_device_config_and_set_final_flags( devices, num_processes, gpus, ipus, tpu_cores ) - if self._devices_flag in ([], 0, "0", "0,"): + if self._devices_flag in ([], 0, "0"): rank_zero_warn(f"You passed `devices={devices}`, switching to `cpu` accelerator") self._accelerator_flag = "cpu" @@ -408,10 +410,8 @@ def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag( """Sets the `devices_flag` and `accelerator_flag` based on num_processes, gpus, ipus, tpu_cores.""" self._gpus: Optional[Union[List[int], str, int]] = gpus self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores - gpus = device_parser.parse_gpu_ids(gpus) - tpu_cores = device_parser.parse_tpu_cores(tpu_cores) deprecated_devices_specific_flag = num_processes or gpus or ipus or tpu_cores - if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in (0, "0"): + if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"): if devices: # TODO: @awaelchli improve error message rank_zero_warn( @@ -456,51 +456,34 @@ def _choose_accelerator(self) -> str: def _set_parallel_devices_and_init_accelerator(self) -> None: # TODO add device availability check - self._parallel_devices: List[Union[int, torch.device]] = [] - if isinstance(self._accelerator_flag, Accelerator): self.accelerator: Accelerator = self._accelerator_flag - elif self._accelerator_flag == "tpu": - self.accelerator = TPUAccelerator() - 
self._set_devices_flag_if_auto_passed() - if isinstance(self._devices_flag, int): - self._parallel_devices = list(range(self._devices_flag)) - else: - self._parallel_devices = self._devices_flag # type: ignore[assignment] - - elif self._accelerator_flag == "ipu": - self.accelerator = IPUAccelerator() - self._set_devices_flag_if_auto_passed() - if isinstance(self._devices_flag, int): - self._parallel_devices = list(range(self._devices_flag)) - - elif self._accelerator_flag == "gpu": - self.accelerator = GPUAccelerator() - self._set_devices_flag_if_auto_passed() - if isinstance(self._devices_flag, int) or isinstance(self._devices_flag, str): - self._devices_flag = int(self._devices_flag) - self._parallel_devices = ( - [torch.device("cuda", i) for i in device_parser.parse_gpu_ids(self._devices_flag)] # type: ignore - if self._devices_flag != 0 - else [] + else: + ACCELERATORS = { + "cpu": CPUAccelerator, + "gpu": GPUAccelerator, + "tpu": TPUAccelerator, + "ipu": IPUAccelerator, + } + assert self._accelerator_flag is not None + self._accelerator_flag = self._accelerator_flag.lower() + if self._accelerator_flag not in ACCELERATORS: + raise MisconfigurationException( + "When passing string value for the `accelerator` argument of `Trainer`," + f" it can only be one of {list(ACCELERATORS)}." ) - else: - self._parallel_devices = [torch.device("cuda", i) for i in self._devices_flag] # type: ignore + accelerator_class = ACCELERATORS[self._accelerator_flag] + self.accelerator = accelerator_class() # type: ignore[abstract] - elif self._accelerator_flag == "cpu": - self.accelerator = CPUAccelerator() - self._set_devices_flag_if_auto_passed() - if isinstance(self._devices_flag, int): - self._parallel_devices = [torch.device("cpu")] * self._devices_flag - else: - rank_zero_warn( - "The flag `devices` must be an int with `accelerator='cpu'`," - f" got `devices={self._devices_flag}` instead." 
- ) + self._set_devices_flag_if_auto_passed() self._gpus = self._devices_flag if not self._gpus else self._gpus self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores + self._devices_flag = self.accelerator.parse_devices(self._devices_flag) + if not self._parallel_devices: + self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) + def _set_devices_flag_if_auto_passed(self) -> None: if self._devices_flag == "auto" or not self._devices_flag: self._devices_flag = self.accelerator.auto_device_count() diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b0cf6a95fac35..e8932b3e56feb 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -21,7 +21,7 @@ from torch.utils.data.distributed import DistributedSampler import pytorch_lightning as pl -from pytorch_lightning.accelerators import IPUAccelerator +from pytorch_lightning.accelerators.ipu import IPUAccelerator from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.trainer.states import RunningStage, TrainerFn diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index cf745fbad4c8d..24104d4073460 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -341,6 +341,14 @@ def creates_processes_externally(self) -> bool: @mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) def test_custom_accelerator(device_count_mock, setup_distributed_mock): class Accel(Accelerator): + @staticmethod + def parse_devices(devices): + return devices + + @staticmethod + def get_parallel_devices(devices): + return [torch.device("cpu")] * devices + @staticmethod def auto_device_count() -> int: return 1 @@ -413,10 +421,17 @@ def test_ipython_incompatible_backend_error(_, monkeypatch): Trainer(strategy="dp") -@pytest.mark.parametrize("trainer_kwargs", [{}, dict(strategy="dp", accelerator="gpu"), dict(accelerator="tpu")]) -def test_ipython_compatible_backend(trainer_kwargs, monkeypatch): +@mock.patch("torch.cuda.device_count", return_value=2) +def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch): + monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) + trainer = Trainer(strategy="dp", accelerator="gpu") + assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible + + +@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8) +def test_ipython_compatible_strategy_tpu(_, monkeypatch): monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) - trainer = Trainer(**trainer_kwargs) + trainer = Trainer(accelerator="tpu") assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible @@ -883,10 +898,9 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock assert trainer.strategy.local_rank == 0 -def test_unsupported_tpu_choice(monkeypatch): - import pytorch_lightning.utilities.imports as imports +@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8) +def test_unsupported_tpu_choice(mock_devices): - monkeypatch.setattr(imports, "_XLA_AVAILABLE", True) with pytest.raises(MisconfigurationException, 
match=r"accelerator='tpu', precision=64\)` is not implemented"): Trainer(accelerator="tpu", precision=64) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 4190553d90115..473546696e1e3 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -13,7 +13,9 @@ # limitations under the License. from unittest import mock -from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator +from pytorch_lightning import Trainer +from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator +from pytorch_lightning.strategies import DDPStrategy @mock.patch("torch.cuda.device_count", return_value=2) @@ -22,3 +24,32 @@ def test_auto_device_count(device_count_mock): assert GPUAccelerator.auto_device_count() == 2 assert TPUAccelerator.auto_device_count() == 8 assert IPUAccelerator.auto_device_count() == 4 + + +def test_pluggable_accelerator(): + class TestAccelerator(Accelerator): + @staticmethod + def parse_devices(devices): + return devices + + @staticmethod + def get_parallel_devices(devices): + return ["foo"] * devices + + @staticmethod + def auto_device_count(): + return 3 + + @staticmethod + def is_available(): + return True + + trainer = Trainer(accelerator=TestAccelerator(), devices=2, strategy="ddp") + assert isinstance(trainer.accelerator, TestAccelerator) + assert isinstance(trainer.strategy, DDPStrategy) + assert trainer._accelerator_connector.parallel_devices == ["foo"] * 2 + + trainer = Trainer(strategy=DDPStrategy(TestAccelerator()), devices="auto") + assert isinstance(trainer.accelerator, TestAccelerator) + assert isinstance(trainer.strategy, DDPStrategy) + assert trainer._accelerator_connector.parallel_devices == ["foo"] * 3 diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index 3b54bcae74c7d..1e74cde1f70c6 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -266,28 +266,30 @@ def forward(self, x): assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight)) +@RunIf(tpu=True) def test_tpu_invalid_raises(): strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): - Trainer(strategy=strategy) + Trainer(strategy=strategy, devices=8) strategy = DDPStrategy(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"): - Trainer(strategy=strategy) + Trainer(strategy=strategy, devices=8) +@RunIf(tpu=True) def test_tpu_invalid_raises_set_precision_with_strategy(): accelerator = TPUAccelerator() strategy = TPUSpawnStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): - Trainer(strategy=strategy) + Trainer(strategy=strategy, devices=8) accelerator = TPUAccelerator() strategy = DDPStrategy(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin()) with pytest.raises( ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy" ): - Trainer(strategy=strategy) + Trainer(strategy=strategy, devices=8) @RunIf(tpu=True) diff --git a/tests/conftest.py b/tests/conftest.py index 8ad7faa3cd769..4440548fa09bc 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -77,6 +77,7 @@ def restore_env_variables(): "XRT_HOST_WORLD_SIZE", "XRT_SHARD_ORDINAL", "XRT_SHARD_LOCAL_ORDINAL", + "TF_CPP_MIN_LOG_LEVEL", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index b5713893f769b..8a074e20ee055 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -17,11 +17,11 @@ from unittest import mock import pytest +import torch import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.utilities import argparse -from tests.helpers.runif import RunIf @mock.patch("argparse.ArgumentParser.parse_args") @@ -163,11 +163,14 @@ def test_argparse_args_parsing_fast_dev_run(cli_args, expected): @pytest.mark.parametrize( ["cli_args", "expected_parsed", "expected_device_ids"], - [("", None, None), ("--accelerator gpu --devices 1", "1", [0]), ("--accelerator gpu --devices 0,", "0,", None)], + [("", None, None), ("--accelerator gpu --devices 1", "1", [0]), ("--accelerator gpu --devices 0,", "0,", [0])], ) -@RunIf(min_gpus=1) -def test_argparse_args_parsing_devices(cli_args, expected_parsed, expected_device_ids): +def test_argparse_args_parsing_devices(cli_args, expected_parsed, expected_device_ids, monkeypatch): """Test multi type argument with bool.""" + + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) + cli_args = cli_args.split(" ") if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): parser = ArgumentParser(add_help=False)
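
Note for reviewers: a minimal end-to-end sketch of the pluggable-Accelerator surface this diff introduces, mirroring `test_pluggable_accelerator` in `tests/accelerators/test_common.py` above. `MyAccelerator` and its `"foo"` device placeholders are illustrative only and not part of the PR; the sketch assumes the remaining `Accelerator` hooks (e.g. `setup`, `get_device_stats`, `teardown`) keep their base-class defaults, as they do in that test.

    # Hypothetical user code, not part of this diff.
    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import Accelerator
    from pytorch_lightning.strategies import DDPStrategy


    class MyAccelerator(Accelerator):
        @staticmethod
        def parse_devices(devices):
            # Accelerator-specific parsing of the Trainer's `devices` flag;
            # the accelerator connector now delegates here instead of
            # hard-coding per-accelerator branches.
            return devices

        @staticmethod
        def get_parallel_devices(devices):
            # One "device" handle per requested device; "foo" is a stand-in
            # for whatever handle a real backend would return.
            return ["foo"] * devices

        @staticmethod
        def auto_device_count():
            # Used when `devices="auto"` (or no devices flag) is passed.
            return 3

        @staticmethod
        def is_available():
            return True


    # The custom accelerator can be passed directly to the Trainer...
    trainer = Trainer(accelerator=MyAccelerator(), devices=2, strategy="ddp")

    # ...or attached to a strategy, in which case `devices="auto"` falls
    # back to `MyAccelerator.auto_device_count()`.
    trainer = Trainer(strategy=DDPStrategy(MyAccelerator()), devices="auto")

The design choice this illustrates: device parsing and parallel-device construction move out of `_set_parallel_devices_and_init_accelerator` in the accelerator connector and into per-accelerator static hooks, so third-party accelerators only need to implement the interface rather than patch the connector.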