[WIP] Accelerator registry follow-up #12461
```diff
@@ -24,7 +24,7 @@
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.accelerators.hpu import HPUAccelerator
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
-from pytorch_lightning.accelerators.registry import AcceleratorRegistry
+from pytorch_lightning.accelerators.registry import ACCELERATOR_REGISTRY
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
 from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
@@ -80,6 +80,7 @@
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
+from pytorch_lightning.utilities.meta import get_all_subclasses


 log = logging.getLogger(__name__)
@@ -156,7 +157,8 @@ def __init__(
         # 1. Parsing flags
         # Get registered strategies, built-in accelerators and precision plugins
         self._registered_strategies = StrategyRegistry.available_strategies()
-        self._accelerator_types = AcceleratorRegistry.available_accelerators()
+        _populate_registries()
+        self._accelerator_types = ACCELERATOR_REGISTRY.names
         self._precision_types = ("16", "32", "64", "bf16", "mixed")

         # Raise an exception if there are conflicts between flags
@@ -484,16 +486,16 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
         else:
             assert self._accelerator_flag is not None
             self._accelerator_flag = self._accelerator_flag.lower()
-            if self._accelerator_flag not in AcceleratorRegistry:
+            if self._accelerator_flag not in ACCELERATOR_REGISTRY:
                 raise MisconfigurationException(
                     "When passing string value for the `accelerator` argument of `Trainer`,"
                     f" it can only be one of {self._accelerator_types}."
                 )
-            self.accelerator = AcceleratorRegistry.get(self._accelerator_flag)
+            self.accelerator = ACCELERATOR_REGISTRY.get(self._accelerator_flag)

         if not self.accelerator.is_available():
             available_accelerator = [
-                acc_str for acc_str in self._accelerator_types if AcceleratorRegistry.get(acc_str).is_available()
+                acc_str for acc_str in self._accelerator_types if ACCELERATOR_REGISTRY.get(acc_str).is_available()
             ]
             raise MisconfigurationException(
                 f"{self.accelerator.__class__.__qualname__} can not run on your system"
@@ -820,3 +822,9 @@ def is_distributed(self) -> bool:
         if isinstance(self.accelerator, TPUAccelerator):
             is_distributed |= self.strategy.is_distributed
         return is_distributed
+
+
+def _populate_registries() -> None:
```
**Contributor:** Why have it in […]?

**Contributor (Author):** Because we don't need to do it at import time, so it is better to delay it if possible.

**Contributor:** Why not? What's the benefit of delaying it? If I do `from pytorch_lightning.accelerators.registry import ACCELERATOR_REGISTRY` and then `print(ACCELERATOR_REGISTRY.names)`, I would expect a list of accelerator names, rather than having to call one more function first.

**Contributor (Author):** Doing things at import time is usually problematic. Besides slowing down startup, it means users are very limited when they want to override or customize behaviour.

Yes, but we would eventually do the same changes for strategies.

**Contributor:** For strategies, we really can't get rid of […]. We can't have an inconsistent API for registering strategies and accelerators.
```diff
+    # automatically register accelerators
+    for cls in get_all_subclasses(Accelerator):
+        ACCELERATOR_REGISTRY(cls)
```
The discussion for removing the `name` method: #12180 (comment)

I think having `name` is useful and clearly separates the accelerator from the registry. Otherwise the registry logic (a protected class) leaks inside the `Accelerator` (a stable class). It also allows registering the accelerators automatically, as shown in this PR.

One drawback I see is that one class couldn't register multiple instances of itself. But we could create subclasses with those defaults.

cc @justusschock @ananthsub as participants from the previous discussion
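As a rough sketch of the argument above: a `name` hook on the accelerator class is enough for a registry to pick up subclasses automatically, and the "one class can only be registered once" drawback can be worked around with a subclass that bakes in different defaults. All names here (`BaseAccelerator`, `XPUAccelerator`, `auto_register`) are hypothetical and not PyTorch Lightning code.

```python
# Hypothetical example; not PyTorch Lightning code.
from typing import Dict, Type


class BaseAccelerator:
    @classmethod
    def name(cls) -> str:
        # Default registry key derived from the class name.
        return cls.__name__.replace("Accelerator", "").lower()


class XPUAccelerator(BaseAccelerator):
    def __init__(self, precision: str = "32") -> None:
        self.precision = precision


class XPUHalfPrecision(XPUAccelerator):
    # Workaround for the drawback mentioned above: a subclass with its own
    # name and different constructor defaults acts like a second registration.
    @classmethod
    def name(cls) -> str:
        return "xpu_16"

    def __init__(self) -> None:
        super().__init__(precision="16")


REGISTRY: Dict[str, Type[BaseAccelerator]] = {}


def auto_register(base: Type[BaseAccelerator]) -> None:
    # Because every class knows its own `name`, the registry never needs to be
    # told what to call it; walking the subclass tree is enough.
    for cls in base.__subclasses__():
        REGISTRY[cls.name()] = cls
        auto_register(cls)  # recurse to pick up subclasses of subclasses


auto_register(BaseAccelerator)
print(sorted(REGISTRY))  # ['xpu', 'xpu_16']
```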