Support DDP communication hook for speeding up training #6727

@shuyingsunshine21

Description

🚀 Feature

Motivation

DDP communication hooks (https://pytorch.org/docs/1.8.0/ddp_comm_hooks.html) control how gradients are communicated across workers for the all-reduce in DistributedDataParallel. For example, fp16_compress_hook converts gradients to fp16 before the all-reduce, which is an effective way to improve training speed when training across multiple nodes.
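
For reference, this is already supported directly in upstream PyTorch. A minimal sketch, assuming torch >= 1.8 and an already-initialized process group; `model` and `local_rank` are placeholders:

    from torch.nn.parallel import DistributedDataParallel
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

    # Assumes torch.distributed.init_process_group(...) has already been called;
    # `model` and `local_rank` are placeholders.
    ddp_model = DistributedDataParallel(model.to(local_rank), device_ids=[local_rank])

    # Cast gradients to fp16 before the all-reduce (and back afterwards),
    # roughly halving the communication volume.
    ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)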

Pitch

In DDPPlugin, provide an option for specifying a DDP communication hook, and register it in configure_ddp (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/ddp.py#L227):

    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )
        register_ddp_comm_hook(
            model=self._model,
            ddp_comm_state=self._ddp_comm_state,
            ddp_comm_hook=self._ddp_comm_hook,
            ddp_comm_wrapper=self._ddp_comm_wrapper,
        )
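
With that in place, user-facing usage could look roughly like the following. The DDPPlugin constructor arguments here are hypothetical, mirroring the attributes used in the sketch above; the exact names and signature are part of this proposal:

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

    # Hypothetical constructor argument mirroring the proposed attributes;
    # the exact name and signature are part of this proposal.
    trainer = Trainer(
        gpus=8,
        num_nodes=2,
        accelerator="ddp",
        plugins=[DDPPlugin(ddp_comm_hook=default_hooks.fp16_compress_hook)],
    )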

Alternatives

Additional context


Labels: distributed, feature, help wanted
