🚀 Feature
With the present Lightning accelerator design, new accelerators cannot be passed to the Trainer unless they are part of the built-in Lightning accelerators. For example, the following is currently not possible:

```python
trainer = Trainer(accelerator=NewSOTAAccelerator(), devices=4)
```

There is a lot of innovation happening in the space of ML accelerators, and the list will continue to grow. We should enable support for this functionality and make it easier for users to experiment with new accelerators using Lightning.

This proposal also aims at cleaning up and moving hardware-specific logic from the accelerator connector to the accelerators.
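As a rough sketch of what that could look like (the method names and signatures below are assumptions mirroring the HPU example further down, not a finalized Lightning API), a custom accelerator would override a small set of hardware-specific hooks:

```python
# Illustrative sketch only; method names/signatures are assumptions, not a final API.
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union

import torch


class Accelerator(ABC):
    """Base interface a custom accelerator would implement."""

    @property
    @abstractmethod
    def accelerator_type(self) -> str:
        """Short string identifying the hardware, e.g. "hpu"."""

    @staticmethod
    @abstractmethod
    def parse_devices(devices: Any) -> Any:
        """Validate and normalize the `devices` flag for this hardware."""

    @staticmethod
    @abstractmethod
    def auto_device_count() -> int:
        """Number of available devices, used when `devices="auto"`."""

    @staticmethod
    @abstractmethod
    def get_parallel_devices(devices: Any) -> List[torch.device]:
        """Concrete `torch.device` objects to run on."""

    @abstractmethod
    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Hardware stats for logging and monitoring."""
```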
For example, the HPUAccelerator PR, which is still in development, adds support for Habana's Gaudi accelerator. Based on the above points, the Accelerator interface would look like this:
```python
from typing import Any, Dict, List, Union

import torch

from pytorch_lightning.accelerators import Accelerator


class HPUAccelerator(Accelerator):
    """Accelerator for HPU devices."""

    @property
    def accelerator_type(self) -> str:
        """Accelerator type."""
        return "hpu"

    @staticmethod
    def parse_devices(devices: int) -> int:
        # Include the HPU device parsing logic here
        return devices

    @staticmethod
    def auto_device_count() -> int:
        """Get the number of HPU devices when set to auto."""
        return habana.device_count()

    @staticmethod
    def get_parallel_devices(devices: int) -> List[torch.device]:
        """Gets parallel devices for the given number of HPU devices."""
        # Logic moved here from the accelerator connector
        return [torch.device("hpu")] * devices

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Gets stats for the given HPU device."""
        return {}
```

After defining HPUAccelerator, the user could provide it to the Trainer without it being part of the built-in Lightning accelerators:
```python
trainer = Trainer(accelerator=HPUAccelerator(), devices=4, strategy=HPUPlugin())
```
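To illustrate the cleanup this enables (a hypothetical sketch, not actual Lightning internals), the accelerator connector could then delegate all hardware-specific decisions to whichever accelerator instance it receives, with no per-hardware branches:

```python
from typing import List, Union

import torch


def _configure_devices(accelerator, devices: Union[int, str]) -> List[torch.device]:
    # Hypothetical helper: every hardware-specific question is answered
    # by the accelerator itself, not by the connector.
    if devices == "auto":
        devices = accelerator.auto_device_count()
    devices = accelerator.parse_devices(devices)
    return accelerator.get_parallel_devices(devices)


# e.g. _configure_devices(HPUAccelerator(), "auto") -> [device("hpu"), ...]
```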