From 8ce85d223e5ab88c56ec6aad87c9bded9798bfad Mon Sep 17 00:00:00 2001
From: Arvin Zhuang
Date: Mon, 22 Mar 2021 16:29:00 +1000
Subject: [PATCH 1/4] match the number of outputs of backward with that of
 inputs of forward for AllGatherGrad

---
 pytorch_lightning/utilities/distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 658f349a22215..c28770a07deb7 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -172,7 +172,7 @@ def backward(ctx, *grad_output):
 
         torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
 
-        return grad_output[torch.distributed.get_rank()]
+        return grad_output[torch.distributed.get_rank()], None
 
 
 def all_gather_ddp_if_available(

From 95f5b482c9ed32ee95d0a261850caf4a192b9ece Mon Sep 17 00:00:00 2001
From: ArvinZhuang
Date: Wed, 24 Mar 2021 11:53:15 +1000
Subject: [PATCH 2/4] add test_all_gather_sync_grads

---
 tests/utilities/test_all_gather_grad.py | 36 +++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index 259f9f4c09871..5dc479e1a7f46 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -95,3 +95,39 @@ def training_epoch_end(self, outputs) -> None:
 
     trainer.fit(model)
     assert model.training_epoch_end_called
+
+
+@RunIf(min_gpus=2, skip_windows=True, special=True)
+def test_all_gather_sync_grads(tmpdir):
+
+    class TestModel(BoringModel):
+
+        training_step_called = False
+
+        def training_step(self, batch, batch_idx):
+            self.training_step_called = True
+            tensor = torch.rand(2, 2, requires_grad=True, device=self.device)
+            gathered_tensor = self.all_gather(tensor, sync_grads=True)
+            assert gathered_tensor.shape == torch.Size([2, 2, 2])
+
+            loss = gathered_tensor.sum()
+
+            return loss
+
+    seed_everything(42)
+
+    model = TestModel()
+
+    limit_train_batches = 8
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        accumulate_grad_batches=2,
+        gpus=2,
+        accelerator="ddp",
+    )
+    trainer.fit(model)
+    assert model.training_step_called

From e3c4c1335c42c61ea532d1a54ba57b078ff2a7c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 24 Mar 2021 17:36:15 +0100
Subject: [PATCH 3/4] Update tests/utilities/test_all_gather_grad.py

---
 tests/utilities/test_all_gather_grad.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index 5dc479e1a7f46..d26780122f1da 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -118,16 +118,6 @@ def training_step(self, batch, batch_idx):
 
     model = TestModel()
 
-    limit_train_batches = 8
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
-        max_epochs=1,
-        log_every_n_steps=1,
-        accumulate_grad_batches=2,
-        gpus=2,
-        accelerator="ddp",
-    )
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
     trainer.fit(model)
     assert model.training_step_called

From bd0e22611a7c55e53c03de4ef18740c3393de7ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 24 Mar 2021 17:37:09 +0100
Subject: [PATCH 4/4] Update tests/utilities/test_all_gather_grad.py

---
 tests/utilities/test_all_gather_grad.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index d26780122f1da..d67c9473bbb2e 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -114,10 +114,7 @@ def training_step(self, batch, batch_idx):
 
             return loss
 
-    seed_everything(42)
-
     model = TestModel()
-
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
     trainer.fit(model)
     assert model.training_step_called
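
Note (illustrative, not part of the patch series): PATCH 1/4 restores the torch.autograd.Function contract that backward() must return one value per forward() argument; the added None pairs with the non-differentiable group argument. A minimal single-process sketch of that contract follows, using a hypothetical ToyAllGatherGrad rather than Lightning's AllGatherGrad.

# Minimal sketch of the autograd.Function contract fixed in PATCH 1/4.
# ToyAllGatherGrad is a hypothetical single-process stand-in for
# pytorch_lightning's AllGatherGrad; no process group is actually used.
import torch


class ToyAllGatherGrad(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, group=None):
        ctx.group = group
        # Pretend world_size == 1: "all-gather" just stacks the local tensor.
        return torch.stack([tensor], dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # backward must return one output per forward input: a tensor
        # gradient for `tensor`, and None for the non-differentiable
        # `group` argument.
        return grad_output[0], None


x = torch.rand(2, 2, requires_grad=True)
ToyAllGatherGrad.apply(x).sum().backward()
print(x.grad.shape)  # torch.Size([2, 2])

Without the trailing None, autograd raises an error about backward returning an incorrect number of gradients; that is the failure mode the patch addresses for self.all_gather(..., sync_grads=True), which the new test exercises under DDP.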