5 changes: 0 additions & 5 deletions pytorch_lightning/accelerators/__init__.py
@@ -15,9 +15,4 @@
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators # noqa: F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401

ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators"

call_register_accelerators(ACCELERATORS_BASE_MODULE)
8 changes: 5 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -76,6 +76,8 @@ def auto_device_count() -> int:
    def is_available() -> bool:
        """Detect if the hardware is available."""

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        pass
    @staticmethod
    @abstractmethod
    def name() -> str:
Contributor

The discussion about removing the `name` method: #12180 (comment)

Contributor Author

@carmocca carmocca Mar 28, 2022

I think having `name` is useful and clearly separates the accelerator from the registry. Otherwise the registry logic (a protected class) leaks into the `Accelerator` (a stable class).

It also allows registering the accelerators automatically as shown in this PR.

One drawback I see is that one class couldn't register multiple instances of itself. But we could create subclasses with those defaults.

class MyAccel(Accelerator):
    def __init__(self, a=None):
        ...

    @staticmethod
    def name():
        return "my_accel"


class MyAccelA2(MyAccel):
    def __init__(self, a=2):
        super().__init__(a=a)

    @staticmethod
    def name():
        return "my_accel_a_2"

cc @justusschock @ananthsub as participants from the previous discussion

"""Name of the Accelerator."""
raise NotImplementedError
11 changes: 4 additions & 7 deletions pytorch_lightning/accelerators/cpu.py
@@ -60,10 +60,7 @@ def is_available() -> bool:
"""CPU is always available for execution."""
return True

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
accelerator_registry.register(
"cpu",
cls,
description=f"{cls.__class__.__name__}",
)
@staticmethod
def name() -> str:
"""Name of the Accelerator."""
return "cpu"
11 changes: 4 additions & 7 deletions pytorch_lightning/accelerators/gpu.py
@@ -90,13 +90,10 @@ def auto_device_count() -> int:
    def is_available() -> bool:
        return torch.cuda.device_count() > 0

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "gpu",
            cls,
            description=f"{cls.__class__.__name__}",
        )
    @staticmethod
    def name() -> str:
        """Name of the Accelerator."""
        return "gpu"


def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:  # pragma: no-cover
11 changes: 4 additions & 7 deletions pytorch_lightning/accelerators/hpu.py
@@ -60,10 +60,7 @@ def auto_device_count() -> int:
    def is_available() -> bool:
        return _HPU_AVAILABLE

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "hpu",
            cls,
            description=f"{cls.__class__.__name__}",
        )
    @staticmethod
    def name() -> str:
        """Name of the Accelerator."""
        return "hpu"
11 changes: 4 additions & 7 deletions pytorch_lightning/accelerators/ipu.py
@@ -47,10 +47,7 @@ def auto_device_count() -> int:
    def is_available() -> bool:
        return _IPU_AVAILABLE

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "ipu",
            cls,
            description=f"{cls.__class__.__name__}",
        )
    @staticmethod
    def name() -> str:
        """Name of the Accelerator."""
        return "ipu"
119 changes: 51 additions & 68 deletions pytorch_lightning/accelerators/registry.py
@@ -11,20 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from inspect import getmembers, isclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, List, Optional, Type

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.registry import _is_register_method_overridden


class _AcceleratorRegistry(dict):
"""This class is a Registry that stores information about the Accelerators.
"""This class is a dictionary that stores information about the Accelerators.

The Accelerators are mapped to strings. These strings are names that identify
an accelerator, e.g., "gpu". It also returns Optional description and
an accelerator, e.g., "gpu". It also includes an optional description and any
parameters to initialize the Accelerator, which were defined during the
registration.

@@ -34,89 +30,76 @@ class _AcceleratorRegistry(dict):

    Example::

        @AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True)
        @ACCELERATOR_REGISTRY
        class SOTAAccelerator(Accelerator):
            def __init__(self, a, b):
            def __init__(self, a):
                ...

        or
            @staticmethod
            def name():
                return "sota"

        AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)
        # or to pass parameters
        ACCELERATOR_REGISTRY.register(SOTAAccelerator, description="My SoTA accelerator", a=1)
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Type:
        return self.register(*args, **kwargs)

    def register(
        self,
        name: str,
        accelerator: Optional[Callable] = None,
        description: str = "",
        accelerator: Type[Accelerator],
        name: Optional[str] = None,
        description: Optional[str] = None,
        override: bool = False,
        **init_params: Any,
    ) -> Callable:
        """Registers a accelerator mapped to a name and with required metadata.
        **kwargs: Any,
    ) -> Type:
        """Registers an accelerator mapped to a name and with optional metadata.

        Args:
            name : the name that identifies a accelerator, e.g. "gpu"
            accelerator : accelerator class
            description : accelerator description
            override : overrides the registered accelerator, if True
            init_params: parameters to initialize the accelerator
            accelerator: The accelerator class.
            name: The alias for the accelerator, e.g. ``"gpu"``.
            description: An optional description.
            override: Whether to override the registered accelerator.
            **kwargs: parameters to initialize the accelerator.
        """
        if not (name is None or isinstance(name, str)):
            raise TypeError(f"`name` must be a str, found {name}")

        if name in self and not override:
            raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")

        data: Dict[str, Any] = {}

        data["description"] = description
        data["init_params"] = init_params

        def do_register(name: str, accelerator: Callable) -> Callable:
            data["accelerator"] = accelerator
            data["accelerator_name"] = name
            self[name] = data
            return accelerator

        if accelerator is not None:
            return do_register(name, accelerator)

        return do_register

    def get(self, name: str, default: Optional[Any] = None) -> Any:
        if name is None:
            name = accelerator.name()
        if not isinstance(name, str):
            raise TypeError(f"`name` for {accelerator} must be a str, found {name!r}")

        if name not in self or override:
            self[name] = {
                "accelerator": accelerator,
                "description": description if description is not None else accelerator.__class__.__name__,
                "kwargs": kwargs,
            }
        return accelerator

    def get(self, name: str, default: Optional[Accelerator] = None) -> Accelerator:
        """Calls the registered accelerator with the required parameters and returns the accelerator object.

        Args:
            name (str): the name that identifies a accelerator, e.g. "gpu"
            name: The name that identifies an accelerator, e.g. "gpu".
            default: A default value.

        Raises:
            KeyError: If the key does not exist.
        """
        if name in self:
            data = self[name]
            return data["accelerator"](**data["init_params"])

            return data["accelerator"](**data["kwargs"])
        if default is not None:
            return default
        raise KeyError(f"{name!r} not found in registry. {self!s}")

        err_msg = "'{}' not found in registry. Available names: {}"
        available_names = self.available_accelerators()
        raise KeyError(err_msg.format(name, available_names))

    def remove(self, name: str) -> None:
        """Removes the registered accelerator by name."""
        self.pop(name)

    def available_accelerators(self) -> List[str]:
        """Returns a list of registered accelerators."""
        return list(self.keys())
    @property
    def names(self) -> List[str]:
        """Returns the registered names."""
        return sorted(list(self))

    def __str__(self) -> str:
        return "Registered Accelerators: {}".format(", ".join(self.available_accelerators()))


AcceleratorRegistry = _AcceleratorRegistry()
return f"Registered Accelerators: {self.names}"


def call_register_accelerators(base_module: str) -> None:
    module = importlib.import_module(base_module)
    for _, mod in getmembers(module, isclass):
        if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"):
            mod.register_accelerators(AcceleratorRegistry)
ACCELERATOR_REGISTRY = _AcceleratorRegistry()
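
For reference, a rough end-to-end sketch of how the reworked registry could be used (the subclass and alias below are illustrative and not part of this diff; it extends the concrete CPUAccelerator so that get() can actually instantiate it):

    from pytorch_lightning.accelerators.cpu import CPUAccelerator
    from pytorch_lightning.accelerators.registry import ACCELERATOR_REGISTRY


    @ACCELERATOR_REGISTRY  # __call__ forwards to register(); the alias comes from name()
    class FancyCPUAccelerator(CPUAccelerator):
        @staticmethod
        def name() -> str:
            return "fancy_cpu"


    print(ACCELERATOR_REGISTRY.names)                     # includes "fancy_cpu" once registered
    accelerator = ACCELERATOR_REGISTRY.get("fancy_cpu")   # instantiates the class with the stored kwargs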
11 changes: 4 additions & 7 deletions pytorch_lightning/accelerators/tpu.py
@@ -65,10 +65,7 @@ def auto_device_count() -> int:
    def is_available() -> bool:
        return _TPU_AVAILABLE

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
        accelerator_registry.register(
            "tpu",
            cls,
            description=f"{cls.__class__.__name__}",
        )
    @staticmethod
    def name() -> str:
        """Name of the Accelerator."""
        return "tpu"
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/strategy_registry.py
@@ -17,7 +17,7 @@

from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.registry import _is_register_method_overridden
from pytorch_lightning.utilities.model_helpers import is_overridden


class _StrategyRegistry(dict):
@@ -118,5 +118,5 @@ def __str__(self) -> str:
def call_register_strategies(base_module: str) -> None:
    module = importlib.import_module(base_module)
    for _, mod in getmembers(module, isclass):
        if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
        if issubclass(mod, Strategy) and is_overridden("register_strategies", mod, Strategy):
            mod.register_strategies(StrategyRegistry)
18 changes: 13 additions & 5 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -24,7 +24,7 @@
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.accelerators.hpu import HPUAccelerator
from pytorch_lightning.accelerators.ipu import IPUAccelerator
from pytorch_lightning.accelerators.registry import AcceleratorRegistry
from pytorch_lightning.accelerators.registry import ACCELERATOR_REGISTRY
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import (
ApexMixedPrecisionPlugin,
@@ -80,6 +80,7 @@
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.meta import get_all_subclasses

log = logging.getLogger(__name__)

@@ -156,7 +157,8 @@ def __init__(
        # 1. Parsing flags
        # Get registered strategies, built-in accelerators and precision plugins
        self._registered_strategies = StrategyRegistry.available_strategies()
        self._accelerator_types = AcceleratorRegistry.available_accelerators()
        _populate_registries()
        self._accelerator_types = ACCELERATOR_REGISTRY.names
        self._precision_types = ("16", "32", "64", "bf16", "mixed")

        # Raise an exception if there are conflicts between flags
@@ -484,16 +486,16 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
        else:
            assert self._accelerator_flag is not None
            self._accelerator_flag = self._accelerator_flag.lower()
            if self._accelerator_flag not in AcceleratorRegistry:
            if self._accelerator_flag not in ACCELERATOR_REGISTRY:
                raise MisconfigurationException(
                    "When passing string value for the `accelerator` argument of `Trainer`,"
                    f" it can only be one of {self._accelerator_types}."
                )
            self.accelerator = AcceleratorRegistry.get(self._accelerator_flag)
            self.accelerator = ACCELERATOR_REGISTRY.get(self._accelerator_flag)

        if not self.accelerator.is_available():
            available_accelerator = [
                acc_str for acc_str in self._accelerator_types if AcceleratorRegistry.get(acc_str).is_available()
                acc_str for acc_str in self._accelerator_types if ACCELERATOR_REGISTRY.get(acc_str).is_available()
            ]
            raise MisconfigurationException(
                f"{self.accelerator.__class__.__qualname__} can not run on your system"
@@ -820,3 +822,9 @@ def is_distributed(self) -> bool:
        if isinstance(self.accelerator, TPUAccelerator):
            is_distributed |= self.strategy.is_distributed
        return is_distributed


def _populate_registries() -> None:
Contributor

Why have it in accelerator_connector rather than at import time in __init__?

Contributor Author

Because we don't need to do it at import time, it's better to delay it when possible.

Contributor

Why not? What's the benefit of delaying it?

If I do this

    from pytorch_lightning.accelerators.registry import ACCELERATOR_REGISTRY
    print(ACCELERATOR_REGISTRY.names)

I would expect a list of accelerator names, rather than having to call one more function, _populate_registries, before it. Also, registries is plural, and we are populating only one registry.

Contributor Author

Doing things at import time is usually problematic. In addition to slowing down start-up, it means users are very limited when they want to override or customize behaviour.

Also, registries is plural, and we are populating only one registry.

Yes, but we would eventually do the same changes for strategies.
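
(A rough illustration of the customization point raised above, with a hypothetical subclass: because nothing is registered at import time, a user can re-bind an alias with a plain register() call, and the automatic population in this PR skips names that are already present.)

    from pytorch_lightning.accelerators.cpu import CPUAccelerator
    from pytorch_lightning.accelerators.registry import ACCELERATOR_REGISTRY


    class PatchedCPUAccelerator(CPUAccelerator):
        @staticmethod
        def name() -> str:
            return "cpu"


    # override=True re-binds the "cpu" alias even if it is already registered
    ACCELERATOR_REGISTRY.register(PatchedCPUAccelerator, override=True)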

Contributor

For strategies, we really can't get rid of register_strategies, as we have multiple configurations for certain strategies.
That's the major con of this approach: we can't have multiple subclasses with different defaults in our code.

We can't have an inconsistent API for registering strategies and accelerators.
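
(For context, a sketch of the multi-configuration pattern that register_strategies enables; the class, aliases, and parameter below are illustrative, and the abstract Strategy hooks are elided.)

    from pytorch_lightning.strategies.strategy import Strategy


    class MyDDPLikeStrategy(Strategy):
        def __init__(self, find_unused_parameters: bool = True):
            super().__init__()
            self.find_unused_parameters = find_unused_parameters

        @classmethod
        def register_strategies(cls, strategy_registry) -> None:
            # one class, two registered entries with different default parameters
            strategy_registry.register("my_ddp", cls, description="default configuration")
            strategy_registry.register(
                "my_ddp_no_unused_parameters",
                cls,
                description="disables unused-parameter detection",
                find_unused_parameters=False,
            )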

    # automatically register accelerators
    for cls in get_all_subclasses(Accelerator):
        ACCELERATOR_REGISTRY(cls)
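
(A hedged sketch of what the automatic registration above implies for user code, with a hypothetical subclass and alias: any Accelerator subclass that exists when the Trainer is constructed should be picked up by get_all_subclasses and become addressable by its name() alias.)

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators.cpu import CPUAccelerator


    class MyAccelerator(CPUAccelerator):
        @staticmethod
        def name() -> str:
            return "my_accelerator"


    # _populate_registries() runs inside the accelerator connector, so the subclass
    # is registered automatically and the string alias resolves to it:
    trainer = Trainer(accelerator="my_accelerator", devices=1)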