Labels: distributed, feature, help wanted
🚀 Feature
Motivation
DDP communication hooks (https://pytorch.org/docs/1.8.0/ddp_comm_hooks.html) control how gradients are communicated across workers during the all-reduce in DistributedDataParallel. For example, fp16_compress_hook converts gradients to fp16 before the all-reduce, which is an effective way to improve training speed in multi-node training.
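For reference, this is roughly what registering such a hook looks like on a bare DistributedDataParallel module today (a minimal sketch, assuming torch >= 1.8 and an already-initialized process group; the toy module and device ids are only illustrative):

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

# Wrap the model as usual; assumes init_process_group() has already been
# called and the module lives on the local GPU.
model = DistributedDataParallel(nn.Linear(32, 2).cuda(), device_ids=[0])

# Compress gradients to fp16 before the all-reduce, decompress afterwards.
model.register_comm_hook(state=None, hook=default.fp16_compress_hook)
```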
Pitch
In DDPPlugin, provide an option for specifying a DDP communication hook, and register it in configure_ddp (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/ddp.py#L227):
```python
def configure_ddp(self):
    self.pre_configure_ddp()
    self._model = DistributedDataParallel(
        LightningDistributedModule(self.model),
        device_ids=self.determine_ddp_device_ids(),
        **self._ddp_kwargs,
    )
    # Register the user-provided communication hook on the wrapped model.
    register_ddp_comm_hook(
        model=self._model,
        ddp_comm_state=self._ddp_comm_state,
        ddp_comm_hook=self._ddp_comm_hook,
        ddp_comm_wrapper=self._ddp_comm_wrapper,
    )
```
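On the user side, the API could then look something like this (a sketch of the proposal; the ddp_comm_hook argument to DDPPlugin does not exist yet, and the Trainer flags are only illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

# Proposed: pass the hook to DDPPlugin so that configure_ddp registers it on
# the wrapped DistributedDataParallel model.
trainer = Trainer(
    gpus=2,
    accelerator="ddp",
    plugins=[DDPPlugin(ddp_comm_hook=default.fp16_compress_hook)],
)
```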
Alternatives
Additional context