Proposed refactor
Final state:
For device-related properties, we will keep device_ids, num_devices and num_nodes.
The properties below will be deprecated in favor of deriving directly from those three.
https://github.com/PyTorchLightning/pytorch-lightning/blob/d4d197070fc2c6c04d460bbfb8b1b9d3a2ebc944/pytorch_lightning/trainer/trainer.py#L2029-L2057
Motivation
1/ There are a bunch of device-related properties on Trainer that retrieve their values from Trainer._accelerator_connector. However, these values should be retrievable from Trainer.strategy or Trainer.accelerator instead. AcceleratorConnector is internal-facing, and its properties are not meant to be public.
2/ Some of the properties can easily be derived from the others, so we would also love to deprecate them in favor of deriving from the existing ones (see the sketch after this list). This includes num_processes, root_gpu, tpu_cores, ipus, num_gpus and data_parallel_device_ids.
3/ Trainer.devices currently returns the number of devices, which is confusing given its name, so we would also love to deprecate it in favor of Trainer.device_ids and Trainer.num_devices.
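A minimal sketch of how the deprecated values could be recomputed from the kept properties. This assumes the device_ids / num_devices properties proposed below already exist; the exact accelerator checks are illustrative, not a settled API:

# Sketch only: user-side derivation of deprecated values from the kept properties.
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator

trainer = Trainer(accelerator="auto", devices=1)

num_gpus = trainer.num_devices if isinstance(trainer.accelerator, GPUAccelerator) else 0
tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, TPUAccelerator) else 0
num_processes = trainer.num_devices  # assumption: one process per device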
Related discussions in #12126, #11624
Pitch
Kept properties
@property
def device_ids(self) -> List[int]:
    """List of device indexes per node."""
    devices = getattr(self.strategy, "parallel_devices", [self.strategy.root_device])
    device_ids = []
    for idx, device in enumerate(devices):
        if isinstance(device, torch.device):
            device_ids.append(device.index or idx)
        elif isinstance(device, int):
            device_ids.append(device)
    return device_ids

@property
def num_devices(self) -> int:
    """Number of devices the trainer uses per node."""
    return len(self.device_ids)

@property
def num_nodes(self) -> int:
    """Number of nodes used for training."""
    return getattr(self.strategy, "num_nodes", 1)
The other properties will be deprecated, and their implementations changed to derive from the above. Examples:
@property
def devices(self) -> Optional[Union[List[int], str, int]]:
    # previously: return self._accelerator_connector.devices
    rank_zero_deprecation(
        "`Trainer.devices` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
    )
    return self.num_devices
@property
def data_parallel_device_ids(self) -> List[int]:
    rank_zero_deprecation(
        "`Trainer.data_parallel_device_ids` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `self.device_ids if isinstance(self.accelerator, GPUAccelerator) else []` instead."
    )
    return self.device_ids if isinstance(self.accelerator, GPUAccelerator) else []
@property
def ipus(self) -> int:
    rank_zero_deprecation(
        "`Trainer.ipus` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0` instead."
    )
    return self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0
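A similar sketch for root_gpu, which is listed above but not spelled out; the suggested replacement expression is an assumption following the same pattern:

@property
def root_gpu(self) -> Optional[int]:
    # Sketch only: the replacement suggested in the message is an assumption, not a settled API.
    rank_zero_deprecation(
        "`Trainer.root_gpu` was deprecated in v1.6 and will be removed in v1.8."
        " Please use `Trainer.strategy.root_device.index` instead."
    )
    return self.strategy.root_device.index if isinstance(self.accelerator, GPUAccelerator) else None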
Implementation steps
1/ Introduce the new properties device_ids and num_devices.
2/ Deprecate the others: change their implementations to derive directly from the existing properties, and add deprecation messages.
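To go with step 2, a sketch of a deprecation test, assuming the messages proposed above and that rank_zero_deprecation emits a warning category that pytest.deprecated_call can catch:

import pytest
from pytorch_lightning import Trainer

def test_trainer_devices_deprecation():
    trainer = Trainer(accelerator="cpu", devices=1)
    # Accessing the deprecated property should emit the deprecation warning.
    with pytest.deprecated_call(match=r"`Trainer.devices` was deprecated in v1.6"):
        _ = trainer.devices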
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: Enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.
cc @justusschock @awaelchli @rohitgr7 @kaushikb11 @Borda @ananthsub @ninginthecloud @jjenniferdai @akihironitta