diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 658f349a22215..c28770a07deb7 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -172,7 +172,7 @@ def backward(ctx, *grad_output):
 
         torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
 
-        return grad_output[torch.distributed.get_rank()]
+        return grad_output[torch.distributed.get_rank()], None
 
 
 def all_gather_ddp_if_available(
diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index 259f9f4c09871..d67c9473bbb2e 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -95,3 +95,26 @@ def training_epoch_end(self, outputs) -> None:
     trainer.fit(model)
 
     assert model.training_epoch_end_called
+
+
+@RunIf(min_gpus=2, skip_windows=True, special=True)
+def test_all_gather_sync_grads(tmpdir):
+
+    class TestModel(BoringModel):
+
+        training_step_called = False
+
+        def training_step(self, batch, batch_idx):
+            self.training_step_called = True
+            tensor = torch.rand(2, 2, requires_grad=True, device=self.device)
+            gathered_tensor = self.all_gather(tensor, sync_grads=True)
+            assert gathered_tensor.shape == torch.Size([2, 2, 2])
+
+            loss = gathered_tensor.sum()
+
+            return loss
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
+    trainer.fit(model)
+    assert model.training_step_called
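
Note on the first hunk: PyTorch requires a custom `torch.autograd.Function` to return one gradient from `backward()` for every argument passed to `forward()`; non-tensor or non-differentiable arguments (here, the process `group`) get `None`. The patched `backward` therefore returns `(grad_output[rank], None)` instead of a single value. Below is a minimal, self-contained sketch (not PyTorch Lightning code; the `Scale` class and `factor` argument are invented for illustration) showing the same convention:

```python
import torch


class Scale(torch.autograd.Function):
    """Toy autograd Function: forward takes a tensor plus a non-tensor argument."""

    @staticmethod
    def forward(ctx, tensor, factor):
        ctx.factor = factor
        return tensor * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward() input: a gradient for `tensor`,
        # and None for the non-differentiable `factor` slot. Returning only
        # a single value here would raise a "returned an incorrect number
        # of gradients" error, analogous to the bug fixed in the diff above.
        return grad_output * ctx.factor, None


x = torch.ones(2, 2, requires_grad=True)
Scale.apply(x, 3.0).sum().backward()
print(x.grad)  # every entry is 3.0
```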