From 4879e6602eae6d03c63225492871ba1451158fb8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 25 Mar 2022 19:11:42 +0100
Subject: [PATCH 1/2] Accelerator registry

---
 pytorch_lightning/accelerators/__init__.py    |   5 -
 pytorch_lightning/accelerators/accelerator.py |   8 +-
 pytorch_lightning/accelerators/cpu.py         |  11 +-
 pytorch_lightning/accelerators/gpu.py         |  11 +-
 pytorch_lightning/accelerators/hpu.py         |  11 +-
 pytorch_lightning/accelerators/ipu.py         |  11 +-
 pytorch_lightning/accelerators/registry.py    | 119 ++++++++----------
 pytorch_lightning/accelerators/tpu.py         |  11 +-
 .../strategies/strategy_registry.py           |   4 +-
 .../connectors/accelerator_connector.py       |  18 ++-
 .../accelerators/test_accelerator_registry.py |  58 ++++++---
 11 files changed, 132 insertions(+), 135 deletions(-)

diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py
index 1ab90e025b087..27e580fa5b496 100644
--- a/pytorch_lightning/accelerators/__init__.py
+++ b/pytorch_lightning/accelerators/__init__.py
@@ -15,9 +15,4 @@
 from pytorch_lightning.accelerators.gpu import GPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.hpu import HPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.ipu import IPUAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators  # noqa: F401
 from pytorch_lightning.accelerators.tpu import TPUAccelerator  # noqa: F401
-
-ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators"
-
-call_register_accelerators(ACCELERATORS_BASE_MODULE)
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 5fe6b53dd54b5..6db4ba2f8f3c1 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -76,6 +76,8 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         """Detect if the hardware is available."""

-    @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
-        pass
+    @staticmethod
+    @abstractmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        raise NotImplementedError
diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index 3d28a4d80f682..a027e7db6e209 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -63,10 +63,7 @@ def is_available() -> bool:
         """CPU is always available for execution."""
         return True

-    @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
-        accelerator_registry.register(
-            "cpu",
-            cls,
-            description=f"{cls.__class__.__name__}",
-        )
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "cpu"
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 1f74da7da3f4e..529d067025f97 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -93,13 +93,10 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return torch.cuda.device_count() > 0

-    @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
-        accelerator_registry.register(
-            "gpu",
-            cls,
-            description=f"{cls.__class__.__name__}",
-        )
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "gpu"


 def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:
diff --git a/pytorch_lightning/accelerators/hpu.py b/pytorch_lightning/accelerators/hpu.py
index 76fdb02b307b8..2214334d53d0d 100644
--- a/pytorch_lightning/accelerators/hpu.py
+++ b/pytorch_lightning/accelerators/hpu.py
@@ -60,10 +60,7 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return _HPU_AVAILABLE

-    @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
-        accelerator_registry.register(
-            "hpu",
-            cls,
-            description=f"{cls.__class__.__name__}",
-        )
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "hpu"
diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py
index b5110e58028a5..1e8b2bc27fe57 100644
--- a/pytorch_lightning/accelerators/ipu.py
+++ b/pytorch_lightning/accelerators/ipu.py
@@ -47,10 +47,7 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return _IPU_AVAILABLE

-    @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
-        accelerator_registry.register(
-            "ipu",
-            cls,
-            description=f"{cls.__class__.__name__}",
-        )
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "ipu"
diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
index 992fa34b02aee..0565431419876 100644
--- a/pytorch_lightning/accelerators/registry.py
+++ b/pytorch_lightning/accelerators/registry.py
@@ -11,20 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib
-from inspect import getmembers, isclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, List, Optional, Type

 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.registry import _is_register_method_overridden


 class _AcceleratorRegistry(dict):
-    """This class is a Registry that stores information about the Accelerators.
+    """This class is a dictionary that stores information about the Accelerators.

     The Accelerators are mapped to strings. These strings are names that identify
-    an accelerator, e.g., "gpu". It also returns Optional description and
+    an accelerator, e.g., "gpu". It also includes an optional description and
     any parameters to initialize the Accelerator, which were defined during the
     registration.

@@ -34,89 +30,76 @@ class _AcceleratorRegistry(dict):

     Example::

-        @AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True)
+        @ACCELERATOR_REGISTRY
         class SOTAAccelerator(Accelerator):
-            def __init__(self, a, b):
+            def __init__(self, a):
                 ...

-    or
+            @staticmethod
+            def name():
+                return "sota"

-    AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)
+        # or to pass parameters
+        ACCELERATOR_REGISTRY.register(SOTAAccelerator, description="My SoTA accelerator", a=1)
     """

+    def __call__(self, *args: Any, **kwargs: Any) -> Type:
+        return self.register(*args, **kwargs)
+
     def register(
         self,
-        name: str,
-        accelerator: Optional[Callable] = None,
-        description: str = "",
+        accelerator: Type[Accelerator],
+        name: Optional[str] = None,
+        description: Optional[str] = None,
         override: bool = False,
-        **init_params: Any,
-    ) -> Callable:
-        """Registers a accelerator mapped to a name and with required metadata.
+        **kwargs: Any,
+    ) -> Type:
+        """Registers an accelerator mapped to a name and with optional metadata.

         Args:
-            name : the name that identifies a accelerator, e.g. "gpu"
-            accelerator : accelerator class
-            description : accelerator description
-            override : overrides the registered accelerator, if True
-            init_params: parameters to initialize the accelerator
+            accelerator: The accelerator class.
+            name: The alias for the accelerator, e.g. ``"gpu"``.
+            description: An optional description.
+            override: Whether to override the registered accelerator.
+            **kwargs: parameters to initialize the accelerator.
         """
-        if not (name is None or isinstance(name, str)):
-            raise TypeError(f"`name` must be a str, found {name}")
-
-        if name in self and not override:
-            raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")
-
-        data: Dict[str, Any] = {}
-
-        data["description"] = description
-        data["init_params"] = init_params
-
-        def do_register(name: str, accelerator: Callable) -> Callable:
-            data["accelerator"] = accelerator
-            data["accelerator_name"] = name
-            self[name] = data
-            return accelerator
-
-        if accelerator is not None:
-            return do_register(name, accelerator)
-
-        return do_register
-
-    def get(self, name: str, default: Optional[Any] = None) -> Any:
+        if name is None:
+            name = accelerator.name()
+        if not isinstance(name, str):
+            raise TypeError(f"`name` for {accelerator} must be a str, found {name!r}")
+
+        if name not in self or override:
+            self[name] = {
+                "accelerator": accelerator,
+                "description": description if description is not None else accelerator.__class__.__name__,
+                "kwargs": kwargs,
+            }
+        return accelerator
+
+    def get(self, name: str, default: Optional[Accelerator] = None) -> Accelerator:
         """Calls the registered accelerator with the required parameters and returns the accelerator object.

         Args:
-            name (str): the name that identifies a accelerator, e.g. "gpu"
+            name: The name that identifies a accelerator, e.g. "gpu".
+            default: A default value.
+
+        Raises:
+            KeyError: If the key does not exist.
         """
         if name in self:
             data = self[name]
-            return data["accelerator"](**data["init_params"])
-
+            return data["accelerator"](**data["kwargs"])
         if default is not None:
             return default
+        raise KeyError(f"{name!r} not found in registry. {self!s}")

-        err_msg = "'{}' not found in registry. Available names: {}"
-        available_names = self.available_accelerators()
-        raise KeyError(err_msg.format(name, available_names))
-
-    def remove(self, name: str) -> None:
-        """Removes the registered accelerator by name."""
-        self.pop(name)
-
-    def available_accelerators(self) -> List[str]:
-        """Returns a list of registered accelerators."""
-        return list(self.keys())
+    @property
+    def names(self) -> List[str]:
+        """Returns the registered names."""
+        return sorted(list(self))

     def __str__(self) -> str:
-        return "Registered Accelerators: {}".format(", ".join(self.available_accelerators()))
-
-
-AcceleratorRegistry = _AcceleratorRegistry()
+        return f"Registered Accelerators: {self.names}"


-def call_register_accelerators(base_module: str) -> None:
-    module = importlib.import_module(base_module)
-    for _, mod in getmembers(module, isclass):
-        if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"):
-            mod.register_accelerators(AcceleratorRegistry)
+ACCELERATOR_REGISTRY = _AcceleratorRegistry()
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index fa8bd007cb25f..dfdc950e70124 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -65,10 +65,7 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return _TPU_AVAILABLE

-    @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
-        accelerator_registry.register(
-            "tpu",
-            cls,
-            description=f"{cls.__class__.__name__}",
-        )
+    @staticmethod
+    def name() -> str:
+        """Name of the Accelerator."""
+        return "tpu"
diff --git a/pytorch_lightning/strategies/strategy_registry.py b/pytorch_lightning/strategies/strategy_registry.py
index 7dee7146d415d..84b52036abbd4 100644
--- a/pytorch_lightning/strategies/strategy_registry.py
+++ b/pytorch_lightning/strategies/strategy_registry.py
@@ -17,7 +17,7 @@

 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.registry import _is_register_method_overridden
+from pytorch_lightning.utilities.model_helpers import is_overridden


 class _StrategyRegistry(dict):
@@ -118,5 +118,5 @@ def __str__(self) -> str:
 def call_register_strategies(base_module: str) -> None:
     module = importlib.import_module(base_module)
     for _, mod in getmembers(module, isclass):
-        if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
+        if issubclass(mod, Strategy) and is_overridden("register_strategies", mod, Strategy):
             mod.register_strategies(StrategyRegistry)
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 673964b13e5ab..ee8a281f82f77 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -24,7 +24,7 @@
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.accelerators.hpu import HPUAccelerator
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
-from pytorch_lightning.accelerators.registry import AcceleratorRegistry
+from pytorch_lightning.accelerators.registry import ACCELERATOR_REGISTRY
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
 from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
@@ -86,6 +86,7 @@
     _TORCH_GREATER_EQUAL_1_8,
     _TPU_AVAILABLE,
 )
+from pytorch_lightning.utilities.meta import get_all_subclasses

 log = logging.getLogger(__name__)

@@ -162,7 +163,8 @@ def __init__(
         # 1. Parsing flags
         # Get registered strategies, built-in accelerators and precision plugins
         self._registered_strategies = StrategyRegistry.available_strategies()
-        self._accelerator_types = AcceleratorRegistry.available_accelerators()
+        _populate_registries()
+        self._accelerator_types = ACCELERATOR_REGISTRY.names
         self._precision_types = ("16", "32", "64", "bf16", "mixed")

         # Raise an exception if there are conflicts between flags
@@ -493,16 +495,16 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
         else:
             assert self._accelerator_flag is not None
             self._accelerator_flag = self._accelerator_flag.lower()
-            if self._accelerator_flag not in AcceleratorRegistry:
+            if self._accelerator_flag not in ACCELERATOR_REGISTRY:
                 raise MisconfigurationException(
                     "When passing string value for the `accelerator` argument of `Trainer`,"
                     f" it can only be one of {self._accelerator_types}."
                 )
-            self.accelerator = AcceleratorRegistry.get(self._accelerator_flag)
+            self.accelerator = ACCELERATOR_REGISTRY.get(self._accelerator_flag)

         if not self.accelerator.is_available():
             available_accelerator = [
-                acc_str for acc_str in self._accelerator_types if AcceleratorRegistry.get(acc_str).is_available()
+                acc_str for acc_str in self._accelerator_types if ACCELERATOR_REGISTRY.get(acc_str).is_available()
             ]
             raise MisconfigurationException(
                 f"{self.accelerator.__class__.__qualname__} can not run on your system"
@@ -829,3 +831,9 @@ def is_distributed(self) -> bool:
         if isinstance(self.accelerator, TPUAccelerator):
             is_distributed |= self.strategy.is_distributed
         return is_distributed
+
+
+def _populate_registries() -> None:
+    # automatically register accelerators
+    for cls in get_all_subclasses(Accelerator):
+        ACCELERATOR_REGISTRY(cls)
diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py
index 4e2b521873408..b77c41e44860c 100644
--- a/tests/accelerators/test_accelerator_registry.py
+++ b/tests/accelerators/test_accelerator_registry.py
@@ -11,17 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
+
 from pytorch_lightning import Trainer
-from pytorch_lightning.accelerators import Accelerator, AcceleratorRegistry
+from pytorch_lightning.accelerators import Accelerator
+from pytorch_lightning.accelerators.registry import ACCELERATOR_REGISTRY
+from pytorch_lightning.trainer.connectors.accelerator_connector import _populate_registries


-def test_accelerator_registry_with_new_accelerator():
+@pytest.fixture(autouse=True)
+def clear_registries():
+    # since the registries are global, it's good to clear them after each test to avoid unwanted interactions
+    yield
+    ACCELERATOR_REGISTRY.clear()

-    accelerator_name = "custom_accelerator"
-    accelerator_description = "Custom Accelerator"
+
+def test_accelerator_registry_with_new_accelerator():
+    name = "custom"
+    description = "My custom Accelerator"

     class CustomAccelerator(Accelerator):
-        def __init__(self, param1, param2):
+        def __init__(self, param1=None, param2=None):
             self.param1 = param1
             self.param2 = param2
             super().__init__()
@@ -42,25 +52,39 @@ def auto_device_count():
         def is_available():
             return True

-    AcceleratorRegistry.register(
-        accelerator_name, CustomAccelerator, description=accelerator_description, param1="abc", param2=123
-    )
+        @staticmethod
+        def name():
+            return "custom"
+
+    ACCELERATOR_REGISTRY.register(CustomAccelerator, name=name, description=description, param1="abc")

-    assert accelerator_name in AcceleratorRegistry
+    assert name in ACCELERATOR_REGISTRY
+    assert ACCELERATOR_REGISTRY[name] == {
+        "accelerator": CustomAccelerator,
+        "description": description,
+        "kwargs": {"param1": "abc"},
+    }
+    instance = ACCELERATOR_REGISTRY.get(name)
+    assert isinstance(instance, CustomAccelerator)
+    assert instance.param1 == "abc"

-    assert AcceleratorRegistry[accelerator_name]["description"] == accelerator_description
-    assert AcceleratorRegistry[accelerator_name]["init_params"] == {"param1": "abc", "param2": 123}
-    assert AcceleratorRegistry[accelerator_name]["accelerator_name"] == accelerator_name
+    assert ACCELERATOR_REGISTRY.get("foo", 123) == 123

-    assert isinstance(AcceleratorRegistry.get(accelerator_name), CustomAccelerator)
+    ACCELERATOR_REGISTRY.clear()

-    trainer = Trainer(accelerator=accelerator_name, devices="auto")
+    trainer = Trainer(accelerator=name, devices="auto")
     assert isinstance(trainer.accelerator, CustomAccelerator)
     assert trainer.strategy.parallel_devices == ["foo"] * 3

-    AcceleratorRegistry.remove(accelerator_name)
-    assert accelerator_name not in AcceleratorRegistry
+    @ACCELERATOR_REGISTRY
+    class NewAccelerator(CustomAccelerator):
+        @staticmethod
+        def name():
+            return "new"
+
+    assert "new" in ACCELERATOR_REGISTRY


 def test_available_accelerators_in_registry():
-    assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "hpu", "ipu", "tpu"]
+    _populate_registries()
+    assert ACCELERATOR_REGISTRY.names == ["cpu", "gpu", "hpu", "ipu", "tpu"]

From 07703693f59ece39819262bd8fa8a50cef565333 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 29 Mar 2022 02:13:15 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../trainer/connectors/accelerator_connector.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 86d60f5b6e7f0..8364a0c95b383 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -79,12 +79,7 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import (
-    _HOROVOD_AVAILABLE,
-    _HPU_AVAILABLE,
-    _IPU_AVAILABLE,
-    _TPU_AVAILABLE,
-)
+from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
 from pytorch_lightning.utilities.meta import get_all_subclasses

 log = logging.getLogger(__name__)
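
Usage sketch of the registry these patches introduce. The ``SOTAAccelerator`` class follows the
docstring example in ``registry.py``; its ``device_count`` parameter and the concrete method bodies
are illustrative assumptions, not part of the patches::

    from pytorch_lightning.accelerators import Accelerator
    from pytorch_lightning.accelerators.registry import ACCELERATOR_REGISTRY

    class SOTAAccelerator(Accelerator):
        def __init__(self, device_count=1):
            self.device_count = device_count
            super().__init__()

        @staticmethod
        def name():
            # the key used when no explicit ``name=`` is passed to ``register`` or the decorator
            return "sota"

        ...  # the remaining abstract ``Accelerator`` methods still need implementations

    # decorator form: ``@ACCELERATOR_REGISTRY`` registers the class under ``cls.name()``;
    # the explicit form below also attaches a description and the kwargs used to instantiate it later
    ACCELERATOR_REGISTRY.register(SOTAAccelerator, description="My SoTA accelerator", device_count=2)

    accelerator = ACCELERATOR_REGISTRY.get("sota")  # returns SOTAAccelerator(device_count=2)
    assert "sota" in ACCELERATOR_REGISTRY.names     # ``names`` is the sorted list of registered keys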