
[RFC] Support passing pluggable Accelerators to Trainer #10687

@kaushikb11


🚀 Feature

With the current Lightning accelerator design, a new accelerator cannot be passed to the `Trainer` unless it is one of the accelerators built into Lightning.

For example, the following is currently not possible:

```python
trainer = Trainer(accelerator=NewSOTAAccelerator(), devices=4)
```

ML accelerators are an area of rapid innovation, and the number of available devices will keep growing. Lightning should support passing custom accelerators to the `Trainer` so that users can easily experiment with new hardware.

This proposal also aims to clean up the accelerator connector by moving hardware-specific logic out of it and into the individual accelerators.
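A minimal sketch of how the connector could accept either an accelerator name or a user-provided accelerator instance. Everything here is hypothetical and stdlib-only: the `Accelerator` base class is a stand-in rather than Lightning's real class, and `_BUILTIN_ACCELERATORS` and `resolve_accelerator` are illustrative names, not part of Lightning's API.

```python
from abc import ABC, abstractmethod
from typing import Dict, Type, Union


class Accelerator(ABC):
    """Hypothetical stand-in for Lightning's Accelerator base class."""

    @property
    @abstractmethod
    def accelerator_type(self) -> str:
        ...


class CPUAccelerator(Accelerator):
    @property
    def accelerator_type(self) -> str:
        return "cpu"


class NewSOTAAccelerator(Accelerator):
    """A third-party accelerator that the framework knows nothing about."""

    @property
    def accelerator_type(self) -> str:
        return "new_sota"


# Hypothetical registry of built-in accelerators selectable by name.
_BUILTIN_ACCELERATORS: Dict[str, Type[Accelerator]] = {"cpu": CPUAccelerator}


def resolve_accelerator(accelerator: Union[str, Accelerator]) -> Accelerator:
    """Return an Accelerator for either a built-in name or a user instance."""
    if isinstance(accelerator, Accelerator):
        # Pluggable path: use the user's accelerator as-is.
        return accelerator
    if isinstance(accelerator, str):
        try:
            return _BUILTIN_ACCELERATORS[accelerator]()
        except KeyError:
            raise ValueError(f"Unknown accelerator name: {accelerator!r}") from None
    raise TypeError(f"Expected str or Accelerator, got {type(accelerator).__name__}")


print(resolve_accelerator("cpu").accelerator_type)  # cpu
print(resolve_accelerator(NewSOTAAccelerator()).accelerator_type)  # new_sota
```

The design choice here is that the connector only checks the type of the argument; any subclass of the base class is accepted without being registered first.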

For example, the in-progress HPUAccelerator PR adds support for Habana's Gaudi accelerator. Following the points above, its Accelerator implementation would look like this:

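```python
from typing import Any, Dict, List, Union

import torch
from pytorch_lightning.accelerators import Accelerator

# `habana` below is a placeholder for the Habana PyTorch bridge module
# that exposes the number of available HPU devices.


class HPUAccelerator(Accelerator):
    """Accelerator for HPU devices."""

    @property
    def accelerator_type(self) -> str:
        """Accelerator type."""
        return "hpu"

    @staticmethod
    def parse_devices(devices: Any) -> int:
        """Parses the ``devices`` flag passed to the Trainer."""
        # Include the HPU device parsing logic here
        return devices

    @staticmethod
    def auto_device_count() -> int:
        """Returns the number of HPU devices when ``devices="auto"``."""
        return habana.device_count()

    @staticmethod
    def get_parallel_devices(devices: int) -> List[torch.device]:
        """Gets parallel devices for the given number of HPU devices."""
        # Logic moved here from the accelerator connector
        return [torch.device("hpu")] * devices

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Gets stats for the given HPU device."""
        # Placeholder: no HPU stats are collected yet
        return {}
```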

Once `HPUAccelerator` is defined, the user could pass it to the `Trainer` even though it is not one of Lightning's built-in accelerators:

```python
trainer = Trainer(accelerator=HPUAccelerator(), devices=4, strategy=HPUPlugin())
```
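With the device logic living on the accelerator, the connector only needs to call the accelerator's static methods. The sketch below is stdlib-only and assumes a hypothetical `FakeHPUAccelerator` that mirrors the interface above, reports a hard-coded count of 8 devices, and returns device strings instead of `torch.device` objects so it runs without PyTorch or Habana software; `resolve_devices` is likewise an illustrative name, not Lightning's API.

```python
from typing import List, Union


class FakeHPUAccelerator:
    """Hypothetical stand-in mirroring the HPUAccelerator interface."""

    @staticmethod
    def parse_devices(devices: Union[int, str]) -> int:
        # Accept ints or numeric strings such as "4".
        count = int(devices)
        if count < 1:
            raise ValueError(f"devices must be >= 1, got {count}")
        return count

    @staticmethod
    def auto_device_count() -> int:
        # Hard-coded for the sketch; a real accelerator would query the hardware.
        return 8

    @staticmethod
    def get_parallel_devices(devices: int) -> List[str]:
        # Strings stand in for torch.device("hpu") entries.
        return ["hpu"] * devices


def resolve_devices(accelerator, devices: Union[int, str]) -> List[str]:
    """Connector-side logic: no hardware-specific branches, only accelerator calls."""
    if devices == "auto":
        count = accelerator.auto_device_count()
    else:
        count = accelerator.parse_devices(devices)
    return accelerator.get_parallel_devices(count)


print(resolve_devices(FakeHPUAccelerator(), 4))  # ['hpu', 'hpu', 'hpu', 'hpu']
print(len(resolve_devices(FakeHPUAccelerator(), "auto")))  # 8
```

Because `resolve_devices` never inspects the accelerator type, adding support for new hardware would not require touching the connector.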

cc @Borda @tchaton @rohitgr7 @akihironitta
