From 4f22ae1feefa83eb54dae6f089f7920bc7b7a000 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 2 Mar 2022 17:31:53 +0530
Subject: [PATCH 01/21] Add AcceleratorRegistry

---
 pytorch_lightning/accelerators/__init__.py |  8 +++
 pytorch_lightning/accelerators/registry.py | 78 ++++++++++++++++++++++
 2 files changed, 86 insertions(+)
 create mode 100644 pytorch_lightning/accelerators/registry.py

diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py
index 1c9e0024f39bd..d58defbfeb460 100644
--- a/pytorch_lightning/accelerators/__init__.py
+++ b/pytorch_lightning/accelerators/__init__.py
@@ -10,8 +10,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pathlib import Path
+
 from pytorch_lightning.accelerators.accelerator import Accelerator  # noqa: F401
 from pytorch_lightning.accelerators.cpu import CPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.gpu import GPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.ipu import IPUAccelerator  # noqa: F401
+from pytorch_lightning.accelerators.registry import AcceleratorRegistry, register_accelerators  # noqa: F401
 from pytorch_lightning.accelerators.tpu import TPUAccelerator  # noqa: F401
+
+FILE_ROOT = Path(__file__).parent
+ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators"
+
+register_accelerators(FILE_ROOT, ACCELERATORS_BASE_MODULE)
diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
new file mode 100644
index 0000000000000..62dd1d4e8a381
--- /dev/null
+++ b/pytorch_lightning/accelerators/registry.py
@@ -0,0 +1,78 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+from inspect import getmembers, isclass
+from pathlib import Path
+from typing import Any, List, Optional
+
+from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+class _AcceleratorRegistry(dict):
+    def register(self, accelerator: Accelerator, name: Optional[str] = None, override: bool = False) -> Accelerator:
+        """Registers an accelerator mapped to a name.
+
+        Args:
+            accelerator: the accelerator to be mapped.
+            name: the name that identifies the provided accelerator.
+            override: Whether to override an existing key.
+        """
+        if name is None:
+            name = accelerator.name()
+        elif not isinstance(name, str):
+            raise TypeError(f"`name` must be a str, found {name}")
+
+        if name in self and not override:
+            raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")
+        self[name] = accelerator
+        return accelerator
+
+    def get(self, name: str, default: Optional[Any] = None) -> Any:
+        """Calls the registered Accelerator and returns the Accelerator object.
+
+        Args:
+            name (str): the name that identifies an Accelerator, e.g. "tpu"
"tpu" + """ + if name in self: + accelerator = self[name] + return accelerator() + + if default is not None: + return default + + err_msg = "'{}' not found in registry. Available names: {}" + available_names = ", ".join(sorted(self.keys())) or "none" + raise KeyError(err_msg.format(name, available_names)) + + def remove(self, name: str) -> None: + """Removes the registered accelerator by name.""" + self.pop(name) + + def available_accelerators(self) -> List: + """Returns a list of registered accelerators.""" + return list(self.keys()) + + def __str__(self) -> str: + return "Registered Accelerators: {}".format(", ".join(self.keys())) + + +AcceleratorRegistry = _AcceleratorRegistry() + + +def register_accelerators(root: Path, base_module: str) -> None: + module = importlib.import_module(base_module) + for _, mod in getmembers(module, isclass): + if issubclass(mod, Accelerator) and mod is not Accelerator: + AcceleratorRegistry.register(mod) From ec17fbed9e6f0cab4a6042fe7c0a877fac91d8f4 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Wed, 2 Mar 2022 17:51:09 +0530 Subject: [PATCH 02/21] Add tests & update changelog --- CHANGELOG.md | 3 + .../accelerators/test_accelerator_registry.py | 57 +++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 tests/accelerators/test_accelerator_registry.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b90cf41b0736b..031b573461302 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `Callback.state_dict()` and `Callback.load_state_dict()` methods ([#12232](https://github.com/PyTorchLightning/pytorch-lightning/pull/12232)) +- Added `AcceleratorRegistry` ([#12180](https://github.com/PyTorchLightning/pytorch-lightning/pull/12180)) + + ### Changed - Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191)) diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py new file mode 100644 index 0000000000000..4ecb945a46c3b --- /dev/null +++ b/tests/accelerators/test_accelerator_registry.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

From ec17fbed9e6f0cab4a6042fe7c0a877fac91d8f4 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 2 Mar 2022 17:51:09 +0530
Subject: [PATCH 02/21] Add tests & update changelog

---
 CHANGELOG.md                                    |  3 +
 .../accelerators/test_accelerator_registry.py   | 57 +++++++++++++++++++
 2 files changed, 60 insertions(+)
 create mode 100644 tests/accelerators/test_accelerator_registry.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b90cf41b0736b..031b573461302 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Callback.state_dict()` and `Callback.load_state_dict()` methods ([#12232](https://github.com/PyTorchLightning/pytorch-lightning/pull/12232))
 
 
+- Added `AcceleratorRegistry` ([#12180](https://github.com/PyTorchLightning/pytorch-lightning/pull/12180))
+
+
 ### Changed
 
 - Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191))
diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py
new file mode 100644
index 0000000000000..4ecb945a46c3b
--- /dev/null
+++ b/tests/accelerators/test_accelerator_registry.py
@@ -0,0 +1,57 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Trainer
+from pytorch_lightning.accelerators import Accelerator, AcceleratorRegistry
+
+
+def test_accelerator_registry_with_new_accelerator():
+
+    accelerator_name = "custom_accelerator"
+
+    class TestAccelerator(Accelerator):
+        @staticmethod
+        def parse_devices(devices):
+            return devices
+
+        @staticmethod
+        def get_parallel_devices(devices):
+            return ["foo"] * devices
+
+        @staticmethod
+        def auto_device_count():
+            return 3
+
+        @staticmethod
+        def is_available():
+            return True
+
+        @staticmethod
+        def name():
+            return accelerator_name
+
+    AcceleratorRegistry.register(TestAccelerator)
+
+    assert accelerator_name in AcceleratorRegistry
+    assert isinstance(AcceleratorRegistry.get(accelerator_name), TestAccelerator)
+
+    trainer = Trainer(accelerator=TestAccelerator(), devices="auto")
+    assert isinstance(trainer.accelerator, TestAccelerator)
+    assert trainer._accelerator_connector.parallel_devices == ["foo"] * 3
+
+    AcceleratorRegistry.remove(accelerator_name)
+    assert accelerator_name not in AcceleratorRegistry
+
+
+def test_available_accelerators_in_registry():
+    assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "ipu", "tpu"]

From 25aca8ad28b5b63c696af7895322ca8fdf05d134 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 2 Mar 2022 17:58:25 +0530
Subject: [PATCH 03/21] Update accelerator connector

---
 .../connectors/accelerator_connector.py | 20 +++++++++----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 0d2013c1606cf..8415c615e0b8f 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -23,6 +23,7 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
+from pytorch_lightning.accelerators.registry import AcceleratorRegistry
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
 from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
@@ -486,27 +487,24 @@ def _choose_accelerator(self) -> str:
         return "cpu"
 
     def _set_parallel_devices_and_init_accelerator(self) -> None:
-        ACCELERATORS = {
-            "cpu": CPUAccelerator,
-            "gpu": GPUAccelerator,
-            "tpu": TPUAccelerator,
-            "ipu": IPUAccelerator,
-        }
         if isinstance(self._accelerator_flag, Accelerator):
             self.accelerator: Accelerator = self._accelerator_flag
         else:
             assert self._accelerator_flag is not None
             self._accelerator_flag = self._accelerator_flag.lower()
-            if self._accelerator_flag not in ACCELERATORS:
+            if self._accelerator_flag not in AcceleratorRegistry:
                 raise MisconfigurationException(
                     "When passing string value for the `accelerator` argument of `Trainer`,"
-                    f" it can only be one of {list(ACCELERATORS)}."
+                    f" it can only be one of {AcceleratorRegistry.available_accelerators()}."
                 )
-            accelerator_class = ACCELERATORS[self._accelerator_flag]
-            self.accelerator = accelerator_class()  # type: ignore[abstract]
+            self.accelerator = AcceleratorRegistry.get(self._accelerator_flag)
 
         if not self.accelerator.is_available():
-            available_accelerator = [acc_str for acc_str in list(ACCELERATORS) if ACCELERATORS[acc_str].is_available()]
+            available_accelerator = [
+                acc_str
+                for acc_str in AcceleratorRegistry.available_accelerators()
+                if AcceleratorRegistry[acc_str].is_available()
+            ]
             raise MisconfigurationException(
                 f"{self.accelerator.__class__.__qualname__} can not run on your system"
                 f" since {self.accelerator.name().upper()}s are not available."
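
With the connector routing string flags through the registry, any registered alias becomes a valid `accelerator` value. A sketch of the user-facing effect (assuming `MyAccelerator` from the sketch under PATCH 01 has been registered)::

    from pytorch_lightning import Trainer

    trainer = Trainer(accelerator="my_accelerator", devices="auto")
    # equivalent to Trainer(accelerator=AcceleratorRegistry.get("my_accelerator"), devices="auto")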

From ad81525f7e2b55850268c9d23ff56f9265b20873 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 2 Mar 2022 18:07:24 +0530
Subject: [PATCH 04/21] Update test

---
 .../trainer/connectors/accelerator_connector.py | 8 +++-----
 tests/accelerators/test_accelerator_registry.py | 2 +-
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 8415c615e0b8f..aadd8ef92d65c 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -158,7 +158,7 @@ def __init__(
         # 1. Parsing flags
         # Get registered strategies, built-in accelerators and precision plugins
         self._registered_strategies = StrategyRegistry.available_strategies()
-        self._accelerator_types = ("tpu", "ipu", "gpu", "cpu")
+        self._accelerator_types = AcceleratorRegistry.available_accelerators()
         self._precision_types = ("16", "32", "64", "bf16", "mixed")
 
         # Raise an exception if there are conflicts between flags
@@ -495,15 +495,13 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
             if self._accelerator_flag not in AcceleratorRegistry:
                 raise MisconfigurationException(
                     "When passing string value for the `accelerator` argument of `Trainer`,"
-                    f" it can only be one of {AcceleratorRegistry.available_accelerators()}."
+                    f" it can only be one of {self._accelerator_types}."
                 )
             self.accelerator = AcceleratorRegistry.get(self._accelerator_flag)
 
         if not self.accelerator.is_available():
             available_accelerator = [
-                acc_str
-                for acc_str in AcceleratorRegistry.available_accelerators()
-                if AcceleratorRegistry[acc_str].is_available()
+                acc_str for acc_str in self._accelerator_types if AcceleratorRegistry[acc_str].is_available()
             ]
diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py
index 4ecb945a46c3b..6d7617cc2ef41 100644
--- a/tests/accelerators/test_accelerator_registry.py
+++ b/tests/accelerators/test_accelerator_registry.py
@@ -45,7 +45,7 @@ def name():
     assert accelerator_name in AcceleratorRegistry
     assert isinstance(AcceleratorRegistry.get(accelerator_name), TestAccelerator)
 
-    trainer = Trainer(accelerator=TestAccelerator(), devices="auto")
+    trainer = Trainer(accelerator="custom_accelerator", devices="auto")
     assert isinstance(trainer.accelerator, TestAccelerator)
     assert trainer._accelerator_connector.parallel_devices == ["foo"] * 3

From ad3c87e0814446f89ed79dc34d2e53cfd417b661 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Tue, 22 Mar 2022 15:33:26 +0530
Subject: [PATCH 05/21] Update tests/accelerators/test_accelerator_registry.py

Co-authored-by: Rohit Gupta
---
 tests/accelerators/test_accelerator_registry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py
index 6d7617cc2ef41..5cd3577cfe8d0 100644
--- a/tests/accelerators/test_accelerator_registry.py
+++ b/tests/accelerators/test_accelerator_registry.py
@@ -45,7 +45,7 @@ def name():
     assert accelerator_name in AcceleratorRegistry
     assert isinstance(AcceleratorRegistry.get(accelerator_name), TestAccelerator)
 
-    trainer = Trainer(accelerator="custom_accelerator", devices="auto")
+    trainer = Trainer(accelerator=accelerator_name, devices="auto")
     assert isinstance(trainer.accelerator, TestAccelerator)
     assert trainer._accelerator_connector.parallel_devices == ["foo"] * 3

From 84379ef7d7a53bd114128a1210c82dab21d6b818 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 23 Mar 2022 02:25:02 +0530
Subject: [PATCH 06/21] Implement logic for register_accelerators

---
 pytorch_lightning/accelerators/__init__.py    |  7 +-
 pytorch_lightning/accelerators/accelerator.py |  7 +-
 pytorch_lightning/accelerators/cpu.py         | 11 ++-
 pytorch_lightning/accelerators/gpu.py         | 11 ++-
 pytorch_lightning/accelerators/ipu.py         | 11 ++-
 pytorch_lightning/accelerators/registry.py    | 82 ++++++++++++++-----
 pytorch_lightning/accelerators/tpu.py         | 11 ++-
 pytorch_lightning/strategies/__init__.py      |  5 +-
 .../strategies/strategy_registry.py           | 21 +----
 pytorch_lightning/utilities/registry.py       | 27 ++++++
 10 files changed, 127 insertions(+), 66 deletions(-)
 create mode 100644 pytorch_lightning/utilities/registry.py

diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py
index d58defbfeb460..fe9ae0a120cfb 100644
--- a/pytorch_lightning/accelerators/__init__.py
+++ b/pytorch_lightning/accelerators/__init__.py
@@ -10,16 +10,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pathlib import Path
-
 from pytorch_lightning.accelerators.accelerator import Accelerator  # noqa: F401
 from pytorch_lightning.accelerators.cpu import CPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.gpu import GPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.ipu import IPUAccelerator  # noqa: F401
-from pytorch_lightning.accelerators.registry import AcceleratorRegistry, register_accelerators  # noqa: F401
+from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators  # noqa: F401
 from pytorch_lightning.accelerators.tpu import TPUAccelerator  # noqa: F401
 
-FILE_ROOT = Path(__file__).parent
 ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators"
 
-register_accelerators(FILE_ROOT, ACCELERATORS_BASE_MODULE)
+call_register_accelerators(ACCELERATORS_BASE_MODULE)
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index ad0779d88b96c..f63545a1e42e1 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -75,7 +75,6 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         """Detect if the hardware is available."""
 
-    @staticmethod
-    @abstractmethod
-    def name() -> str:
-        """Name of the Accelerator."""
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        raise NotImplementedError(f"`register_accelerators` is not implemented for {cls.__name__}")
diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index a027e7db6e209..3d28a4d80f682 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -63,7 +63,10 @@ def is_available() -> bool:
         """CPU is always available for execution."""
         return True
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "cpu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "cpu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 529d067025f97..1f74da7da3f4e 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -93,10 +93,13 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return torch.cuda.device_count() > 0
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "gpu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "gpu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
 
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:
diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py
index 1e8b2bc27fe57..b5110e58028a5 100644
--- a/pytorch_lightning/accelerators/ipu.py
+++ b/pytorch_lightning/accelerators/ipu.py
@@ -47,7 +47,10 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return _IPU_AVAILABLE
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "ipu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "ipu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
index 62dd1d4e8a381..80a8784e1ddb2 100644
--- a/pytorch_lightning/accelerators/registry.py
+++ b/pytorch_lightning/accelerators/registry.py
@@ -13,41 +13,85 @@
 # limitations under the License.
 import importlib
 from inspect import getmembers, isclass
-from pathlib import Path
-from typing import Any, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.registry import _is_register_method_overridden
 
 
 class _AcceleratorRegistry(dict):
-    def register(self, accelerator: Accelerator, name: Optional[str] = None, override: bool = False) -> Accelerator:
-        """Registers an accelerator mapped to a name.
+    """This class is a Registry that stores information about the Accelerators.
+
+    The Accelerators are mapped to strings. These strings are names that identify
+    an accelerator, e.g., "gpu". It also returns an optional description and
+    parameters to initialize the Accelerator, which were defined during the
+    registration.
+
+    The motivation for having an AcceleratorRegistry is to make it convenient
+    for users to try different accelerators by passing mapped aliases
+    to the accelerator flag of the Trainer.
+
+    Example::
+
+        @AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True)
+        class SOTAAccelerator(Accelerator):
+            def __init__(self, a, b):
+                ...
+
+        or
+
+        AcceleratorRegistry.register("sota", SOTAAAccelerator, description="Custom sota accelerator", a=1, b=True)
+    """
+
+    def register(
+        self,
+        name: str,
+        accelerator: Optional[Callable] = None,
+        description: Optional[str] = None,
+        override: bool = False,
+        **init_params: Any,
+    ) -> Callable:
+        """Registers an accelerator mapped to a name and with required metadata.
 
         Args:
-            accelerator: the accelerator to be mapped.
-            name: the name that identifies the provided accelerator.
-            override: Whether to override an existing key.
+            name : the name that identifies an accelerator, e.g. "gpu"
+            accelerator : accelerator class
+            description : accelerator description
+            override : overrides the registered accelerator, if True
+            init_params: parameters to initialize the accelerator
         """
-        if name is None:
-            name = accelerator.name()
-        elif not isinstance(name, str):
+        if not (name is None or isinstance(name, str)):
             raise TypeError(f"`name` must be a str, found {name}")
 
         if name in self and not override:
             raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")
-        self[name] = accelerator
-        return accelerator
+
+        data: Dict[str, Any] = {}
+        data["description"] = description if description is not None else ""
+
+        data["init_params"] = init_params
+
+        def do_register(name: str, accelerator: Callable) -> Callable:
+            data["accelerator"] = accelerator
+            data["accelerator_name"] = name
+            self[name] = data
+            return accelerator
+
+        if accelerator is not None:
+            return do_register(name, accelerator)
+
+        return do_register
 
     def get(self, name: str, default: Optional[Any] = None) -> Any:
-        """Calls the registered Accelerator and returns the Accelerator object.
+        """Calls the registered accelerator with the required parameters and returns the accelerator object.
 
         Args:
-            name (str): the name that identifies an Accelerator, e.g. "tpu"
+            name (str): the name that identifies an accelerator, e.g. "gpu"
"gpu" """ if name in self: - accelerator = self[name] - return accelerator() + data = self[name] + return data["accelerator"](**data["init_params"]) if default is not None: return default @@ -71,8 +115,8 @@ def __str__(self) -> str: AcceleratorRegistry = _AcceleratorRegistry() -def register_accelerators(root: Path, base_module: str) -> None: +def call_register_accelerators(base_module: str) -> None: module = importlib.import_module(base_module) for _, mod in getmembers(module, isclass): - if issubclass(mod, Accelerator) and mod is not Accelerator: - AcceleratorRegistry.register(mod) + if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"): + mod.register_accelerators(AcceleratorRegistry) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index dfdc950e70124..fa8bd007cb25f 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -65,7 +65,10 @@ def auto_device_count() -> int: def is_available() -> bool: return _TPU_AVAILABLE - @staticmethod - def name() -> str: - """Name of the Accelerator.""" - return "tpu" + @classmethod + def register_accelerators(cls, accelerator_registry: Dict) -> None: + accelerator_registry.register( + "tpu", + cls, + description=f"{cls.__class__.__name__}", + ) diff --git a/pytorch_lightning/strategies/__init__.py b/pytorch_lightning/strategies/__init__.py index f06edfa53ec7a..a4cd57a50ac1d 100644 --- a/pytorch_lightning/strategies/__init__.py +++ b/pytorch_lightning/strategies/__init__.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path - from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401 from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401 from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401 @@ -31,7 +29,6 @@ from pytorch_lightning.strategies.strategy_registry import call_register_strategies, StrategyRegistry # noqa: F401 from pytorch_lightning.strategies.tpu_spawn import TPUSpawnStrategy # noqa: F401 -FILE_ROOT = Path(__file__).parent STRATEGIES_BASE_MODULE = "pytorch_lightning.strategies" -call_register_strategies(FILE_ROOT, STRATEGIES_BASE_MODULE) +call_register_strategies(STRATEGIES_BASE_MODULE) diff --git a/pytorch_lightning/strategies/strategy_registry.py b/pytorch_lightning/strategies/strategy_registry.py index 17e08acb23bcc..7dee7146d415d 100644 --- a/pytorch_lightning/strategies/strategy_registry.py +++ b/pytorch_lightning/strategies/strategy_registry.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import importlib
-import inspect
 from inspect import getmembers, isclass
-from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional
 
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.registry import _is_register_method_overridden
 
 
 class _StrategyRegistry(dict):
@@ -116,22 +115,8 @@ def __str__(self) -> str:
 StrategyRegistry = _StrategyRegistry()
 
 
-def is_register_strategies_overridden(strategy: type) -> bool:
-
-    method_name = "register_strategies"
-    strategy_attr = getattr(strategy, method_name)
-    previous_super_cls = inspect.getmro(strategy)[1]
-
-    if issubclass(previous_super_cls, Strategy):
-        super_attr = getattr(previous_super_cls, method_name)
-    else:
-        return False
-
-    return strategy_attr.__code__ is not super_attr.__code__
-
-
-def call_register_strategies(root: Path, base_module: str) -> None:
+def call_register_strategies(base_module: str) -> None:
     module = importlib.import_module(base_module)
     for _, mod in getmembers(module, isclass):
-        if issubclass(mod, Strategy) and is_register_strategies_overridden(mod):
+        if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
         mod.register_strategies(StrategyRegistry)
diff --git a/pytorch_lightning/utilities/registry.py b/pytorch_lightning/utilities/registry.py
new file mode 100644
index 0000000000000..ff0d5e8a94413
--- /dev/null
+++ b/pytorch_lightning/utilities/registry.py
@@ -0,0 +1,27 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+from typing import Callable
+
+
+def _is_register_method_overridden(mod: type, base_cls: Callable, method: str) -> bool:
+    mod_attr = getattr(mod, method)
+    previous_super_cls = inspect.getmro(mod)[1]
+
+    if issubclass(previous_super_cls, base_cls):
+        super_attr = getattr(previous_super_cls, method)
+    else:
+        return False
+
+    return mod_attr.__code__ is not super_attr.__code__
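
The scan-and-register flow after this patch: `call_register_accelerators` imports the base module, walks its classes, and invokes `register_accelerators` only on `Accelerator` subclasses that actually override it, which `_is_register_method_overridden` detects by comparing `__code__` objects. A sketch of the hook for a hypothetical third-party accelerator (illustrative; not part of the diff)::

    from typing import Dict

    from pytorch_lightning.accelerators import Accelerator

    class QuantumAccelerator(Accelerator):
        # stub out the abstract static methods
        @staticmethod
        def parse_devices(devices):
            return devices

        @staticmethod
        def get_parallel_devices(devices):
            return list(range(devices))

        @staticmethod
        def auto_device_count():
            return 1

        @staticmethod
        def is_available():
            return True

        @classmethod
        def register_accelerators(cls, accelerator_registry: Dict) -> None:
            # invoked by call_register_accelerators() at import time, because this
            # override makes _is_register_method_overridden(...) return True
            accelerator_registry.register("quantum", cls, description="Example accelerator")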

From 5d736d9a06f8119dedfded774626f34077c3bff2 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 23 Mar 2022 02:34:55 +0530
Subject: [PATCH 07/21] Update tests

---
 .../accelerators/test_accelerator_registry.py | 31 ++++++++++++++-----
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py
index 5cd3577cfe8d0..57eeb907b1c3b 100644
--- a/tests/accelerators/test_accelerator_registry.py
+++ b/tests/accelerators/test_accelerator_registry.py
@@ -18,8 +18,14 @@ def test_accelerator_registry_with_new_accelerator():
 
     accelerator_name = "custom_accelerator"
+    accelerator_description = "Custom Accelerator"
+
+    class CustomAccelerator(Accelerator):
+        def __init__(self, param1, param2):
+            self.param1 = param1
+            self.param2 = param2
+            super().__init__()
 
-    class TestAccelerator(Accelerator):
         @staticmethod
         def parse_devices(devices):
             return devices
@@ -36,17 +42,28 @@ def auto_device_count():
             return 3
 
         @staticmethod
         def is_available():
             return True
 
-        @staticmethod
-        def name():
-            return accelerator_name
+        @classmethod
+        def register_accelerators(cls, accelerator_registry) -> None:
+            accelerator_registry.register(
+                accelerator_name,
+                cls,
+                description=f"{cls.__class__.__name__}",
+            )
 
-    AcceleratorRegistry.register(TestAccelerator)
+    AcceleratorRegistry.register(
+        accelerator_name, CustomAccelerator, description=accelerator_description, param1="abc", param2=123
+    )
 
     assert accelerator_name in AcceleratorRegistry
-    assert isinstance(AcceleratorRegistry.get(accelerator_name), TestAccelerator)
+
+    assert AcceleratorRegistry[accelerator_name]["description"] == accelerator_description
+    assert AcceleratorRegistry[accelerator_name]["init_params"] == {"param1": "abc", "param2": 123}
+    assert AcceleratorRegistry[accelerator_name]["accelerator_name"] == accelerator_name
+
+    assert isinstance(AcceleratorRegistry.get(accelerator_name), CustomAccelerator)
 
     trainer = Trainer(accelerator=accelerator_name, devices="auto")
-    assert isinstance(trainer.accelerator, TestAccelerator)
+    assert isinstance(trainer.accelerator, CustomAccelerator)
     assert trainer._accelerator_connector.parallel_devices == ["foo"] * 3
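
For reference, what `get` does with the stored metadata after this rework, unrolled (names taken from the test above)::

    data = AcceleratorRegistry["custom_accelerator"]          # the metadata dict
    accelerator = data["accelerator"](**data["init_params"])  # CustomAccelerator(param1="abc", param2=123)
    # AcceleratorRegistry.get("custom_accelerator") performs exactly these two steps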

From 42061984f78fc2bce41dde83a4abdfa124c3fa95 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Wed, 23 Mar 2022 02:26:01 +0530
Subject: [PATCH 08/21] Update pytorch_lightning/accelerators/accelerator.py

---
 pytorch_lightning/accelerators/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index f63545a1e42e1..2e2414cd71c10 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -77,4 +77,4 @@ def is_available() -> bool:
 
     @classmethod
     def register_accelerators(cls, accelerator_registry: Dict) -> None:
-        raise NotImplementedError(f"`register_accelerators` is not implemented for {cls.__name__}")
+        raise NotImplementedError(f"`register_accelerators` is not implemented for {cls.__class__.__name__}")

From f7f5eda4addd91e4d4b2cb41b77704cee2bcff4b Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Wed, 23 Mar 2022 02:26:26 +0530
Subject: [PATCH 09/21] Update pytorch_lightning/accelerators/registry.py

---
 pytorch_lightning/accelerators/registry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
index 80a8784e1ddb2..090ae653b0eeb 100644
--- a/pytorch_lightning/accelerators/registry.py
+++ b/pytorch_lightning/accelerators/registry.py
@@ -41,7 +41,7 @@ def __init__(self, a, b):
 
     or
 
-    AcceleratorRegistry.register("sota", SOTAAAccelerator, description="Custom sota accelerator", a=1, b=True)
+    AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)

From ef188a14591b0d98c9ac74cbd17a9ae5db6a7858 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 23 Mar 2022 02:58:26 +0530
Subject: [PATCH 10/21] Fix tests

---
 pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index aadd8ef92d65c..15e7e06571aa1 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -501,7 +501,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
 
         if not self.accelerator.is_available():
             available_accelerator = [
-                acc_str for acc_str in self._accelerator_types if AcceleratorRegistry[acc_str].is_available()
+                acc_str for acc_str in self._accelerator_types if AcceleratorRegistry.get(acc_str).is_available()
             ]

From eb22e015681c428ae2f60b870d959be685d4b377 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 23 Mar 2022 14:29:49 +0530
Subject: [PATCH 11/21] Update exception message

---
 .../trainer/connectors/accelerator_connector.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 15e7e06571aa1..e1cf4c6232f90 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -505,9 +505,9 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
             ]
             raise MisconfigurationException(
                 f"{self.accelerator.__class__.__qualname__} can not run on your system"
-                f" since {self.accelerator.name().upper()}s are not available."
-                " The following accelerator(s) is available and can be passed into"
-                f" `accelerator` argument of `Trainer`: {available_accelerator}."
+                " since the accelerator is not available. The following accelerator(s)"
+                " is available and can be passed into `accelerator` argument of"
+                f" `Trainer`: {available_accelerator}."
             )
 
         self._set_devices_flag_if_auto_passed()
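
The reworded message no longer depends on the `name()` static method removed earlier in this series. What a user now sees when requesting unavailable hardware, roughly (illustrative; the exact list depends on the machine)::

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        Trainer(accelerator="gpu")  # on a CPU-only machine
    except MisconfigurationException as err:
        print(err)
        # GPUAccelerator can not run on your system since the accelerator is not
        # available. The following accelerator(s) is available and can be passed
        # into `accelerator` argument of `Trainer`: ['cpu']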

From cdb0572ca741943256b0cc07826542c4bd97f8d3 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 23 Mar 2022 16:16:56 +0530
Subject: [PATCH 12/21] Update test

---
 tests/accelerators/test_accelerator_connector.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index fce6a7fae2502..794cb9b2922cd 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -498,7 +498,8 @@ def test_accelerator_cpu(_):
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):
         trainer = Trainer(gpus=1)
     with pytest.raises(
-        MisconfigurationException, match="GPUAccelerator can not run on your system since GPUs are not available."
+        MisconfigurationException,
+        match="GPUAccelerator can not run on your system since the accelerator is not available.",
     ):
         trainer = Trainer(accelerator="gpu")
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):

From c8ea7f21b783e776d7c1fbab5f0a05a16c52bef9 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 23 Mar 2022 16:26:56 +0530
Subject: [PATCH 13/21] Remove NotImplementedError

---
 pytorch_lightning/accelerators/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 2e2414cd71c10..526cec3e47319 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -77,4 +77,4 @@ def is_available() -> bool:
 
     @classmethod
     def register_accelerators(cls, accelerator_registry: Dict) -> None:
-        raise NotImplementedError(f"`register_accelerators` is not implemented for {cls.__class__.__name__}")
+        pass

From 8c96d9c440fe54f29a97950c9e2fb05cc9e0bfa0 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Wed, 23 Mar 2022 16:28:14 +0530
Subject: [PATCH 14/21] Update test

---
 tests/accelerators/test_accelerator_registry.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py
index 57eeb907b1c3b..b21cd95e33cbd 100644
--- a/tests/accelerators/test_accelerator_registry.py
+++ b/tests/accelerators/test_accelerator_registry.py
@@ -42,14 +42,6 @@ def auto_device_count():
         def is_available():
             return True
 
-        @classmethod
-        def register_accelerators(cls, accelerator_registry) -> None:
-            accelerator_registry.register(
-                accelerator_name,
-                cls,
-                description=f"{cls.__class__.__name__}",
-            )
-
     AcceleratorRegistry.register(
         accelerator_name, CustomAccelerator, description=accelerator_description, param1="abc", param2=123
     )
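
Patches 13 and 14 make the base-class hook a silent no-op: subclasses that do not override `register_accelerators` are simply skipped during discovery. A sketch of the resulting opt-in behavior (hypothetical class; illustrative)::

    from pytorch_lightning.accelerators import Accelerator

    class PrivateAccelerator(Accelerator):
        ...  # abstract methods elided; no register_accelerators override

    # During call_register_accelerators(), _is_register_method_overridden(
    #     PrivateAccelerator, Accelerator, "register_accelerators") returns False,
    # so PrivateAccelerator never lands in AcceleratorRegistry.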

From d18091d0a8e37f34fc44242101573c0f86fcb322 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Thu, 24 Mar 2022 00:32:10 +0530
Subject: [PATCH 15/21] Update pytorch_lightning/accelerators/registry.py

Co-authored-by: ananthsub
---
 pytorch_lightning/accelerators/registry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
index 090ae653b0eeb..ba5fc71f925a0 100644
--- a/pytorch_lightning/accelerators/registry.py
+++ b/pytorch_lightning/accelerators/registry.py
@@ -48,7 +48,7 @@ def register(
         self,
         name: str,
         accelerator: Optional[Callable] = None,
-        description: Optional[str] = None,
+        description: str = "",
         override: bool = False,
         **init_params: Any,
     ) -> Callable:

From a5925d9effa507804d96936aa3e8e2e48676c7a1 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Thu, 24 Mar 2022 00:32:15 +0530
Subject: [PATCH 16/21] Update pytorch_lightning/accelerators/registry.py

Co-authored-by: ananthsub
---
 pytorch_lightning/accelerators/registry.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
index ba5fc71f925a0..a9b5ac5cf9f0e 100644
--- a/pytorch_lightning/accelerators/registry.py
+++ b/pytorch_lightning/accelerators/registry.py
@@ -68,7 +68,6 @@ def register(
             raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")
 
         data: Dict[str, Any] = {}
-        data["description"] = description if description is not None else ""
 
         data["init_params"] = init_params

From 75884c6e1d02f72d1e5b75313a67ee194c8ca7ff Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Thu, 24 Mar 2022 00:32:45 +0530
Subject: [PATCH 17/21] Update pytorch_lightning/accelerators/registry.py

Co-authored-by: ananthsub
---
 pytorch_lightning/accelerators/registry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
index a9b5ac5cf9f0e..4a16e6ca545eb 100644
--- a/pytorch_lightning/accelerators/registry.py
+++ b/pytorch_lightning/accelerators/registry.py
@@ -103,7 +103,7 @@ def remove(self, name: str) -> None:
         """Removes the registered accelerator by name."""
         self.pop(name)
 
-    def available_accelerators(self) -> List:
+    def available_accelerators(self) -> List[str]:
         """Returns a list of registered accelerators."""
         return list(self.keys())

From 197dbaa3d53796c52ddaf6b95cc0cf96662a7fac Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Thu, 24 Mar 2022 00:32:57 +0530
Subject: [PATCH 18/21] Update pytorch_lightning/accelerators/registry.py

Co-authored-by: ananthsub
---
 pytorch_lightning/accelerators/registry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
index 4a16e6ca545eb..b7e305c7eeedd 100644
--- a/pytorch_lightning/accelerators/registry.py
+++ b/pytorch_lightning/accelerators/registry.py
@@ -96,7 +96,7 @@ def get(self, name: str, default: Optional[Any] = None) -> Any:
             return default
 
         err_msg = "'{}' not found in registry. Available names: {}"
-        available_names = ", ".join(sorted(self.keys())) or "none"
+        available_names = self.available_accelerators()
         raise KeyError(err_msg.format(name, available_names))
 
     def remove(self, name: str) -> None:

From b0e11200a938fbeb14e8ecd6357e8f58157c7c7e Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Thu, 24 Mar 2022 00:33:03 +0530
Subject: [PATCH 19/21] Update pytorch_lightning/accelerators/registry.py

Co-authored-by: ananthsub
---
 pytorch_lightning/accelerators/registry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
index b7e305c7eeedd..f6bd3aa83328b 100644
--- a/pytorch_lightning/accelerators/registry.py
+++ b/pytorch_lightning/accelerators/registry.py
@@ -108,7 +108,7 @@ def available_accelerators(self) -> List[str]:
         return list(self.keys())
 
     def __str__(self) -> str:
-        return "Registered Accelerators: {}".format(", ".join(self.keys()))
+        return "Registered Accelerators: {}".format(", ".join(self.available_accelerators()))

From d1de49ec2c56dbb5877aa5b17d2137d8ee31f0a7 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Thu, 24 Mar 2022 02:03:29 +0530
Subject: [PATCH 20/21] Fix tests

---
 pytorch_lightning/accelerators/registry.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
index f6bd3aa83328b..992fa34b02aee 100644
--- a/pytorch_lightning/accelerators/registry.py
+++ b/pytorch_lightning/accelerators/registry.py
@@ -69,6 +69,7 @@ def register(
 
         data: Dict[str, Any] = {}
 
+        data["description"] = description
         data["init_params"] = init_params

From 096afb937ce96f4068c0b926f26d89f72f075b19 Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Thu, 24 Mar 2022 02:08:13 +0530
Subject: [PATCH 21/21] Fix mypy

---
 pytorch_lightning/utilities/registry.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/utilities/registry.py b/pytorch_lightning/utilities/registry.py
index ff0d5e8a94413..83970e885bdcd 100644
--- a/pytorch_lightning/utilities/registry.py
+++ b/pytorch_lightning/utilities/registry.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import Callable
+from typing import Any
 
 
-def _is_register_method_overridden(mod: type, base_cls: Callable, method: str) -> bool:
+def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
     mod_attr = getattr(mod, method)
     previous_super_cls = inspect.getmro(mod)[1]
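
End-to-end sketch of the finished API across this series (hypothetical accelerator; device logic is illustrative)::

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import Accelerator, AcceleratorRegistry

    class MyAccelerator(Accelerator):
        @staticmethod
        def parse_devices(devices):
            return devices

        @staticmethod
        def get_parallel_devices(devices):
            return list(range(devices))

        @staticmethod
        def auto_device_count():
            return 1

        @staticmethod
        def is_available():
            return True

    AcceleratorRegistry.register("my_accelerator", MyAccelerator, description="Example accelerator")

    print(AcceleratorRegistry.available_accelerators())  # built-in names plus "my_accelerator"
    print(AcceleratorRegistry["my_accelerator"]["description"])  # "Example accelerator"

    trainer = Trainer(accelerator="my_accelerator", devices="auto")  # resolved through the registry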