[bug] Update broadcast + reduce decision [ModelCheckpoint] #6410
Merged
70 commits
597ae27
resolve bug
tchaton ef11927
update
tchaton 85b327d
update changelog
tchaton 47f0b2c
update PR
tchaton bbe4255
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton 1c33b48
Update pytorch_lightning/trainer/connectors/logger_connector/epoch_re…
tchaton 6cd4713
add todo
tchaton 45d7239
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton b58d7fb
resolve issues
tchaton e3a084a
resolve flake8
tchaton 77edbed
update
tchaton 6bcc88d
add coverage for reduce
tchaton c63bca5
wip
tchaton e26d301
restore back to brodbact
tchaton ce239fd
remove test.py
tchaton d8f1dc9
resolve flake8
tchaton 237bbd2
update
tchaton f546ae4
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton 6fbe70d
check world size
tchaton 5f25fc5
resolve test
tchaton 46cf2c6
update
tchaton 7029b31
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton 8523167
use pytorch version when defined
tchaton f28f950
update on comments
tchaton 6eae79d
update on comments
tchaton 1cd9431
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton 9448964
flake8
tchaton 1b5c90a
resolve bugs
tchaton a1264d9
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton 9f3eb41
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton e88ef07
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton c21f148
Update CHANGELOG.md
tchaton 94e9aa9
update
tchaton 4626310
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton b260bf6
update
tchaton dd60ed1
update
tchaton 45b65f1
update
tchaton dcd6884
remove test
tchaton 2e046e8
update
tchaton 68ffb5b
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton 23b2c10
resolve flake8
tchaton b4c663b
update
tchaton aa89d5d
Merge branch 'master' into bugfix/5604_ddp_model_checkpoint
tchaton 73e83f7
update
tchaton c060444
Merge branch 'bugfix/5604_ddp_model_checkpoint' of https://github.com…
tchaton 2eb6db4
update
tchaton 060992b
proxy
tchaton 5bad135
update
tchaton 4579842
update
tchaton 5276cd0
Merge branch 'master' into bugfix/broadcast_2
tchaton 8027838
resolve typo
tchaton aa9a6ca
prune
tchaton 4b6a6c5
update parallel
tchaton 4b55c52
update
tchaton cbacf48
update changelog
tchaton 057fbf3
update
tchaton 7f515ea
Merge branch 'master' into bugfix/broadcast_2
tchaton 7cbf38b
try running pipe
tchaton 928cf2c
Merge branch 'master' into bugfix/broadcast_2
carmocca 690b61f
Update pytorch_lightning/callbacks/model_checkpoint.py
tchaton 300a632
update on comments
tchaton 5e30377
Update pytorch_lightning/callbacks/model_checkpoint.py
tchaton 015fbac
update on comennts
tchaton c213716
Merge branch 'bugfix/broadcast_2' of https://github.com/PyTorchLightn…
tchaton f668c3a
update
tchaton 30feb40
update
tchaton a4bf623
update
tchaton b482589
fix
tchaton 1ad9c62
remove comments
tchaton b64e105
resolve bugs
tchaton
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| import logging | ||
| import pickle | ||
| | ||
| import torch | ||
| | ||
| from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 | ||
| | ||
| log = logging.getLogger(__name__) | ||
| | ||
| if torch.distributed.is_available(): | ||
|     from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember | ||
| | ||
| # The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py`` | ||
| # and enable broadcasting for PyTorch 1.6 and lower. | ||
| | ||
| | ||
| # https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160 | ||
| def _rank_not_in_group(group): | ||
|     """ | ||
|     Helper that checks if the current process's rank is not in a given group. | ||
|     """ | ||
|     if group is None: | ||
|         return False | ||
|     return group == GroupMember.NON_GROUP_MEMBER | ||
| | ||
| | ||
| # Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164 | ||
| def _object_to_tensor(obj): | ||
|     buffer = pickle.dumps(obj) | ||
|     byte_storage = torch.ByteStorage.from_buffer(buffer)  # type: ignore[attr-defined] | ||
|     byte_tensor = torch.ByteTensor(byte_storage) | ||
|     local_size = torch.LongTensor([byte_tensor.numel()]) | ||
|     return byte_tensor, local_size | ||
| | ||
| | ||
| # Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py | ||
| def _tensor_to_object(tensor, tensor_size): | ||
|     buf = tensor.numpy().tobytes()[:tensor_size] | ||
|     out = pickle.loads(buf) | ||
|     return out | ||
| | ||
| | ||
| # Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327 | ||
| def _broadcast_object_list(object_list, src=0, group=None): | ||
|     if _rank_not_in_group(group): | ||
|         return | ||
| | ||
|     my_rank = get_rank() | ||
|     # Serialize object_list elements to tensors on src rank. | ||
|     if my_rank == src: | ||
|         tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) | ||
|         object_sizes_tensor = torch.cat(size_list) | ||
|     else: | ||
|         object_sizes_tensor = torch.LongTensor(len(object_list)) | ||
| | ||
|     group_backend = get_backend(group) | ||
|     is_nccl_backend = group_backend == Backend.NCCL | ||
|     current_device = torch.device("cpu") | ||
|     if is_nccl_backend: | ||
|         # See note about using torch.cuda.current_device() here in docstring. | ||
|         # We cannot simply use my_rank since rank == device is not necessarily | ||
|         # true. | ||
|         current_device = torch.device('cuda', torch.cuda.current_device()) | ||
|         object_sizes_tensor = object_sizes_tensor.to(current_device) | ||
|     object_sizes_tensor = object_sizes_tensor.to(current_device) | ||
| | ||
|     # Broadcast object sizes | ||
|     broadcast(object_sizes_tensor, src=src, group=group) | ||
| | ||
|     # Concatenate and broadcast serialized object tensors | ||
|     if my_rank == src: | ||
|         object_tensor = torch.cat(tensor_list) | ||
|     else: | ||
|         object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) | ||
| | ||
|     if is_nccl_backend: | ||
|         object_tensor = object_tensor.to(current_device) | ||
| | ||
|     broadcast(object_tensor, src=src, group=group) | ||
| | ||
|     # Deserialize objects using their stored sizes. | ||
|     offset = 0 | ||
|     if my_rank != src: | ||
|         for i, obj_size in enumerate(object_sizes_tensor): | ||
|             obj_view = object_tensor[offset:offset + obj_size] | ||
|             obj_view = obj_view.type(torch.ByteTensor)  # type: ignore[call-overload] | ||
|             offset += obj_size | ||
|             object_list[i] = _tensor_to_object(obj_view, obj_size) | ||
| | ||
| | ||
| if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available(): | ||
|     from torch.distributed.distributed_c10d import broadcast_object_list | ||
| else: | ||
|     broadcast_object_list = _broadcast_object_list | ||
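
The packing scheme used by `_broadcast_object_list` in the diff above (pickle each object, record its byte length, concatenate the payloads, then slice by the recorded sizes on the receiving side) can be illustrated without a process group. The following is only a minimal sketch under stated assumptions: it uses plain `bytes` in place of `torch.ByteTensor`, performs no actual broadcast, and the names `pack_objects` and `unpack_objects` are hypothetical helpers that do not appear in the PR.

```python
import pickle


def pack_objects(objects):
    """Serialize each object and return (payload, sizes).

    Roughly mirrors what the source rank does: one pickled chunk per
    object, plus a list of chunk lengths sent ahead of the payload.
    """
    chunks = [pickle.dumps(obj) for obj in objects]
    sizes = [len(chunk) for chunk in chunks]
    return b"".join(chunks), sizes


def unpack_objects(payload, sizes):
    """Rebuild the objects by slicing the payload at running offsets.

    Roughly mirrors what a non-source rank does after both broadcasts.
    """
    objects = []
    offset = 0
    for size in sizes:
        objects.append(pickle.loads(payload[offset:offset + size]))
        offset += size
    return objects


# Illustrative values only; not taken from the PR.
payload, sizes = pack_objects([{"best_score": 0.25}, "epoch=3.ckpt", None])
restored = unpack_objects(payload, sizes)
```

In the diff itself the sizes travel in a first `broadcast` call so that non-source ranks can allocate a receive buffer of the right length before the second `broadcast` delivers the payload; the sketch skips that step because both sides share memory here.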