
Remove AcceleratorConnector.devices and Deprecate Trainer.devices in favor of Trainer.device_ids and Trainer.num_devices  #12126

@DuYicong515


Proposed refactor

Remove AcceleratorConnector.devices:
https://github.com/PyTorchLightning/pytorch-lightning/blob/7e2f9fbad555242b0ceb2a24e5e4c004f0701bae/pytorch_lightning/trainer/connectors/accelerator_connector.py#L788-L794

Introduce Trainer.device_ids which returns a list of device indexes, and Trainer.num_devices which returns len(Trainer.device_ids).

Deprecate Trainer.devices in favor of Trainer.device_ids and Trainer.num_devices
https://github.com/PyTorchLightning/pytorch-lightning/blob/7e2f9fbad555242b0ceb2a24e5e4c004f0701bae/pytorch_lightning/trainer/connectors/accelerator_connector.py#L788-L794
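A minimal sketch of what such a deprecation shim could look like, assuming the old property simply forwards to the new one. `DeprecatingTrainer` is a hypothetical stand-in, not the real Trainer, and the standard-library `warnings` module is used here in place of whatever warning helper Lightning itself would use:

```python
import warnings
from typing import List


class DeprecatingTrainer:
    """Hypothetical stand-in for Trainer, for illustration only."""

    def __init__(self, device_ids: List[int]) -> None:
        # Assumption: the device indexes are already known at construction time.
        self._device_ids = list(device_ids)

    @property
    def device_ids(self) -> List[int]:
        # Return a copy so callers cannot mutate internal state.
        return list(self._device_ids)

    @property
    def num_devices(self) -> int:
        return len(self.device_ids)

    @property
    def devices(self) -> int:
        # Old name: keeps returning a count, but warns and points to the new API.
        warnings.warn(
            "`Trainer.devices` is deprecated; use `Trainer.num_devices` "
            "or `Trainer.device_ids` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.num_devices
```

With this shape, existing callers of `devices` keep getting the same integer while being told which property to move to.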

Motivation

AcceleratorConnector.devices is not used within PyTorch Lightning other than by Trainer.devices, and its implementation behaves more like num_devices. Since AcceleratorConnector.devices is not meant to be a public property, we can remove it in favor of using Trainer.num_devices externally.

Introduce Trainer.device_ids, which returns a list of device indexes, to replace the deprecated Trainer.devices. It can be used to get the device indexes, and also the list of devices by combining the indexes with Trainer.accelerator.
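As a rough illustration of combining the indexes with the accelerator type, the sketch below builds device names from a device-type string and a list of indexes. The helper `device_names` and the idea of passing the type as a plain string are assumptions for this example; how the type would actually be derived from Trainer.accelerator is not specified here:

```python
from typing import List


def device_names(device_type: str, device_ids: List[int]) -> List[str]:
    """Hypothetical helper: pair a device type with each device index."""
    return [f"{device_type}:{index}" for index in device_ids]


# Example: two GPU indexes become two device names.
print(device_names("cuda", [0, 1]))  # ['cuda:0', 'cuda:1']
```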

Currently Trainer.devices is not used within PyTorch Lightning other than in tests, and its name is confusing: it returns the number of devices rather than a list of devices or device indexes.

Trainer.num_devices, which is derived from Trainer.device_ids, can also help with the migration of several Trainer properties, since we plan to remove a number of unused properties in AcceleratorConnector (part of #11449), including num_processes, tpu_cores, ipus and num_gpus.

Pitch

@property
def device_ids(self) -> List[int]:
    if isinstance(self.strategy, ParallelStrategy):
        # Parallel strategies expose every participating device.
        return [torch._utils._get_device_index(device) for device in self.strategy.parallel_devices]
    elif isinstance(self.strategy, SingleDeviceStrategy):
        # Single-device strategies expose only the root device.
        return [torch._utils._get_device_index(self.strategy.root_device)]
    return []

@property
def num_devices(self) -> int:
    return len(self.device_ids)
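The dispatch above can be exercised without Lightning or torch installed using a self-contained sketch. Everything here is a hypothetical stand-in: `FakeDevice` replaces `torch.device`, the strategy classes only mimic the attributes used in the pitch, and `_index_of` is an assumed substitute for `torch._utils._get_device_index` that treats a missing index as 0:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class FakeDevice:
    """Stand-in for torch.device (assumption: only type and index matter here)."""
    type: str
    index: Optional[int] = None


class Strategy:
    """Stand-in base class for strategies."""


class SingleDeviceStrategy(Strategy):
    def __init__(self, root_device: FakeDevice) -> None:
        self.root_device = root_device


class ParallelStrategy(Strategy):
    def __init__(self, parallel_devices: List[FakeDevice]) -> None:
        self.parallel_devices = list(parallel_devices)


def _index_of(device: FakeDevice) -> int:
    # Assumed behavior for this sketch: a device without an index maps to 0.
    return 0 if device.index is None else device.index


class StrategyTrainer:
    """Hypothetical stand-in for Trainer, holding only a strategy."""

    def __init__(self, strategy: Strategy) -> None:
        self.strategy = strategy

    @property
    def device_ids(self) -> List[int]:
        if isinstance(self.strategy, ParallelStrategy):
            return [_index_of(d) for d in self.strategy.parallel_devices]
        if isinstance(self.strategy, SingleDeviceStrategy):
            return [_index_of(self.strategy.root_device)]
        # Unknown strategy types report no devices.
        return []

    @property
    def num_devices(self) -> int:
        return len(self.device_ids)


multi = StrategyTrainer(ParallelStrategy([FakeDevice("cuda", 0), FakeDevice("cuda", 1)]))
print(multi.device_ids, multi.num_devices)  # [0, 1] 2
```

Deriving `num_devices` from `device_ids` keeps the two properties consistent by construction, which is the main design point of the pitch.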
    

cc @justusschock @awaelchli @rohitgr7 @four4fish


