Proposed refactor
Introduce Trainer.device_ids, which returns a list of device indices, and Trainer.num_devices, which returns len(Trainer.device_ids).
Deprecate Trainer.devices in favor of Trainer.device_ids and Trainer.num_devices.
https://github.com/PyTorchLightning/pytorch-lightning/blob/7e2f9fbad555242b0ceb2a24e5e4c004f0701bae/pytorch_lightning/trainer/connectors/accelerator_connector.py#L788-L794
Motivation
Accelerator.devices is not used within PyTorch Lightning anywhere other than Trainer.devices, and its implementation behaves more like a num_devices. Since Accelerator.devices is not meant to be a public property, we can remove it in favor of using Trainer.num_devices externally.
Introduce Trainer.device_ids, which returns a list of device indices, to deprecate Trainer.devices. It can be used to get the device indices directly, and also to build the list of devices by combining the indices with Trainer.accelerator (a rough sketch follows at the end of this section).
Currently, Trainer.devices is not used within PyTorch Lightning other than in tests, and its name is quite confusing -- it returns the number of devices instead of a list of devices or device indices.
Trainer.num_devices, which is derived from Trainer.device_ids, can also help with migrating several Trainer properties, since we plan to remove a bunch of unused properties in AcceleratorConnector (part of #11449), including num_processes, tpu_cores, ipus, and num_gpus.
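As a rough illustration of combining the indices with the accelerator/strategy, here is a hypothetical helper (not an existing Lightning API) that rebuilds torch.device objects from the proposed Trainer.device_ids, assuming Trainer.strategy.root_device exposes the device type:

from typing import List

import torch


def devices_from_trainer(trainer) -> List[torch.device]:
    # Hypothetical helper, not part of Lightning: combine the proposed
    # Trainer.device_ids with the root device type to rebuild device objects.
    device_type = trainer.strategy.root_device.type  # e.g. "cpu" or "cuda"
    return [torch.device(device_type, idx) for idx in trainer.device_ids]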
Pitch
# Assumes module-level imports: List (typing), torch, ParallelStrategy, SingleDeviceStrategy.
@property
def device_ids(self) -> List[int]:
    """Return the indices of the devices the current strategy runs on."""
    if isinstance(self.strategy, ParallelStrategy):
        return [torch._utils._get_device_index(device) for device in self.strategy.parallel_devices]
    if isinstance(self.strategy, SingleDeviceStrategy):
        return [torch._utils._get_device_index(self.strategy.root_device)]
    return []

@property
def num_devices(self) -> int:
    """Return the number of devices, i.e. len(self.device_ids)."""
    return len(self.device_ids)
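For the deprecation itself, a minimal sketch of what the shim on Trainer could look like, assuming the existing rank_zero_deprecation utility from pytorch_lightning.utilities (the exact message is a placeholder):

@property
def devices(self) -> int:
    # Deprecation shim: forward to the new property while warning once per rank.
    rank_zero_deprecation(
        "`Trainer.devices` is deprecated. Use `Trainer.num_devices` or `Trainer.device_ids` instead."
    )
    return self.num_devices

Rough usage under the proposal (the values assume a machine with two visible GPUs):

trainer = Trainer(accelerator="gpu", devices=2)
trainer.device_ids   # [0, 1]
trainer.num_devices  # 2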
cc @justusschock @awaelchli @rohitgr7 @four4fish
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.