All gather with grads #5012
Merged
Commits (28):

- 0d1fbd2 all_gather (ananyahjha93)
- 5bb4b54 ddp (ananyahjha93)
- db652cf horovod (ananyahjha93)
- 0a821f4 grad tests (ananyahjha93)
- 2dd7680 fixed ddp (ananyahjha93)
- 9f031a0 ddp fixed, removed tpu, horovod for now (ananyahjha93)
- eec2ff3 changelog (ananyahjha93)
- fbfdb43 windows fix (ananyahjha93)
- ea1d3d8 windows fix (ananyahjha93)
- f666020 removed batch from ctx (ananyahjha93)
- b3a50c3 Merge branch 'master' into all_gather (Borda)
- ab1a864 all_gather (ananyahjha93)
- 993aa51 ddp (ananyahjha93)
- 6fc03bf horovod (ananyahjha93)
- 309a7e0 grad tests (ananyahjha93)
- 519711d fixed ddp (ananyahjha93)
- 2ab162d ddp fixed, removed tpu, horovod for now (ananyahjha93)
- 23a4779 changelog (ananyahjha93)
- 3ca4054 windows fix (ananyahjha93)
- f01d800 windows fix (ananyahjha93)
- 586bb50 removed batch from ctx (ananyahjha93)
- b9d182d removed code duplication (ananyahjha93)
- 7b865a7 merge (ananyahjha93)
- f2161ee merge (ananyahjha93)
- d67d47e Merge branch 'master' into all_gather (ananyahjha93)
- 347902a Merge branch 'master' into all_gather (ananyahjha93)
- e625b4a Merge branch 'master' into all_gather (ananyahjha93)
- a0603c4 Merge branch 'master' into all_gather (ananyahjha93)
Changes to the distributed utilities (two hunks; the flattened diff table is reconstructed below, with the uncalled `torch.no_grad` context manager and the single-argument `Union[torch.Tensor]` type hint fixed):

```python
# @@ -22,10 +22,14 @@
if torch.distributed.is_available():
    from torch.distributed import ReduceOp
    from torch.distributed import group
else:
    class ReduceOp:
        SUM = None

    class group:
        WORLD = None


def rank_zero_only(fn):


# @@ -155,3 +159,54 @@ def sync_ddp(
        result = result / torch.distributed.get_world_size(group)

    return result


class AllGatherGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group=group.WORLD):
        ctx.group = group

        gathered_tensor = [
            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
        ]

        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
        gathered_tensor = torch.stack(gathered_tensor, dim=0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, *grad_output):
        grad_output = torch.cat(grad_output)

        torch.distributed.all_reduce(
            grad_output,
            op=torch.distributed.ReduceOp.SUM,
            async_op=False,
            group=ctx.group,
        )

        return grad_output[torch.distributed.get_rank()]


def all_gather_ddp_if_available(
    tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
) -> torch.Tensor:
    """
    Function to gather a tensor from several distributed processes.

    Args:
        tensor: tensor of shape (batch, ...)
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for all_gather op

    Return:
        A tensor of shape (world_size, batch, ...)
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if sync_grads:
            return AllGatherGrad.apply(tensor, group)
        else:
            # torch.no_grad is a context-manager class and must be instantiated
            with torch.no_grad():
                return AllGatherGrad.apply(tensor, group)
    return tensor
```
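The gradient logic above can be illustrated without a process group. The following single-process sketch (the helper names are invented here, not part of PyTorch Lightning) mimics what `AllGatherGrad` does: the forward pass copies every rank's tensor to all ranks, and the backward pass sums (SUM all-reduce) the incoming gradients across ranks, with each rank keeping the slot that corresponds to its own tensor:

```python
# Hypothetical single-process simulation of AllGatherGrad's math.

def all_gather_forward(xs):
    """Every 'rank' receives the full list of per-rank values."""
    return [list(xs) for _ in xs]

def all_gather_backward(grads_per_rank, rank):
    """SUM all-reduce across ranks, then slice out this rank's gradient."""
    world_size = len(grads_per_rank)
    summed = [sum(g[i] for g in grads_per_rank) for i in range(world_size)]
    return summed[rank]

# Mirror the PR's test: rank r scales its gathered copy by r, so its loss
# contributes gradient r to every slot; after the sum, each rank's tensor
# receives 0 + 1 + 2 = 3.
world_size = 3
grads = [[r] * world_size for r in range(world_size)]
print([all_gather_backward(grads, r) for r in range(world_size)])  # [3, 3, 3]
```

This is why `backward` returns `grad_output[torch.distributed.get_rank()]`: after the all-reduce, every rank holds the same summed gradient tensor, and each rank only owns the slice for the tensor it contributed.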
New test file for the gradient-preserving all-gather (reconstructed from the flattened diff; "enviroment" typo fixed):

```python
import os
import pytest
import sys
import torch
import torch.nn as nn

from pytorch_lightning.utilities import AllGatherGrad


def setup_ddp(rank, world_size):
    """Setup ddp environment"""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "8088"

    if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def _test_all_gather_ddp(rank, world_size):
    setup_ddp(rank, world_size)

    tensor1 = torch.ones(8, requires_grad=True)
    tensor2 = torch.ones((8, 16, 32), requires_grad=True)

    tensor1_gathered = AllGatherGrad.apply(tensor1)
    tensor2_gathered = AllGatherGrad.apply(tensor2)

    tensor1_gathered = tensor1_gathered * rank
    tensor2_gathered = tensor2_gathered * rank

    tensor1_gathered.sum().backward()
    tensor2_gathered.sum().backward()

    grad1 = torch.zeros_like(tensor1.grad).fill_(torch.arange(world_size).sum().float())
    grad2 = torch.zeros_like(tensor2.grad).fill_(torch.arange(world_size).sum().float())

    assert torch.allclose(grad1, tensor1.grad)
    assert torch.allclose(grad2, tensor2.grad)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_all_gather_ddp():
    world_size = 3
    torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)
```
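The expected gradients in this test follow from the arithmetic of the backward all-reduce: rank r multiplies its gathered copy by r, so its loss contributes a gradient of r to every rank's tensor, and the SUM reduction makes each tensor's per-element gradient the sum of all rank indices:

```python
# Expected per-element gradient for world_size = 3: ranks contribute
# 0, 1 and 2 respectively, and the SUM all-reduce adds them up.
world_size = 3
expected_grad = sum(range(world_size))
print(expected_grad)  # 3
```

This matches the test's `torch.arange(world_size).sum()` reference value.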