Sharded Plugin #4178
Conversation
Hello @SeanNaren! Thanks for updating this PR.
Comment last updated at 2020-11-24 21:16:57 UTC
Going to put some stats here. I've been tracking CUDA memory using a callback inspired by this. Running this transformers script for token classification, I've got some average peak memory stats after the first epoch.
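For reference, a minimal sketch of such a memory-tracking callback; the class name and hook choices here are illustrative, not the exact callback used for those numbers:

```python
import torch
from pytorch_lightning import Callback


class CUDAMemoryCallback(Callback):
    """Illustrative sketch: report peak CUDA memory per training epoch."""

    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the peak-memory counter so each epoch is measured independently.
        torch.cuda.reset_peak_memory_stats()

    def on_train_epoch_end(self, trainer, pl_module):
        peak_mib = torch.cuda.max_memory_allocated() / 2 ** 20
        print(f"Epoch {trainer.current_epoch}: peak CUDA memory {peak_mib:.0f} MiB")
```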
Nice! FYI there are more savings on the way. It would require wrapping the model and not using DDP (an autograd hook does the job), but that solves the gradient-reduce issue and reduces the communications, which is important when multiple nodes are involved. It does not deprecate the version you tested, it just comes on top: if OSS is used with DDP you get what you measured; if OSS is used with a model wrap instead you get better gradient sharding.
Thanks @blefaudeux! Was excited to get numbers for this, really awesome work overall :) Are the additional savings only tied to fairscale? Is there a place I can have a look at the code? Would be good to get a head start on figuring out API changes to support this.
No PR yet, the work is in https://github.com/facebookresearch/fairscale/tree/oss_autograd, and then it's a normal training loop with an optimizer and model. Gradients are automatically reduced to the right rank, and the optimizer state and gradients are sharded, which shaves some more memory off what you have right now. What do you think?
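For context, this is the rough shape of the wrapped-model approach being described. The class names below follow the fairscale API that was eventually released (OSS + ShardedDataParallel) rather than the experimental oss_autograd branch, so treat it as a sketch:

```python
import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel


def build_sharded(model: torch.nn.Module, lr: float = 0.1):
    # Assumes torch.distributed is already initialized on every rank.
    # OSS shards the optimizer state across the participating ranks.
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=lr)
    # The model wrap replaces DDP: gradients are reduced (not all-reduced)
    # only to the rank that owns the matching optimizer shard.
    model = ShardedDataParallel(model, optimizer)
    return model, optimizer
```

The training loop itself stays the standard forward / backward / `optimizer.step()` sequence.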
Yep, the interface makes sense to me. I also see the model dispatch making its way into this branch, which is super cool! I don't think it makes sense to support both DDP + OSS and Sharded DDP + OSS, since we'll need to install fairscale regardless and Sharded DDP seems like the successor. I think it makes sense to go forward with ShardedDDP and have it replace the current implementation. Thoughts here?
It may require some benchmarking. Basically, one feature from DDP which is hard to replicate is the overlap between BW and reduce: in DDP the gradients are all-reduced step by step while walking back the graph, concurrently with the BW computations. Currently what's "easy" to do is FW, then BW, then reduce (not all-reduce), but the overlap is lost. When we shard the model (next steps, a little more involved) it's not that much of an issue, because we can overlap the reduce of the lower shard with the BW of the upper one. From what I can see, currently (state + gradient sharding):
Yeah, still not sure it's worth supporting the DDP version. Keeping the DDP Lightning wrapper just for a small speed benefit in single-node setups, where the total number of GPUs is small, doesn't seem necessary and adds confusion unless I'm mistaken (and personally I think reducing memory allocation is what matters more)!
Pushing WIP changes integrating ShardedDDP. Running into an issue using multiple GPUs where training hangs, currently investigating this. Also seeing if there is a nicer way to handle requiring grads for the input. The main issue being I've made the assumption (for now) that the inputs are tuples within the batch, i.e. within ModelDispatch:

```python
# All inputs need to require_grad for autograd to properly track the first dispatch layer
# Will currently break if a dict or something, may require a recursive check
if isinstance(inputs, tuple):
    for i in inputs:
        i.requires_grad = True
```

@blefaudeux after investigation I've noticed that it only hangs when using torch automatic mixed precision autocast with OSS + SDP. Any reasons you think this could happen off the top of your head? Will continue to investigate, but curious if you had any suggestions!
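The recursive check hinted at in that comment could look something like the following; `require_grad_on_inputs` is a hypothetical helper, not code from this PR:

```python
import torch


def require_grad_on_inputs(inputs):
    """Hypothetical helper: recursively mark floating-point tensors in a
    nested batch (tuples, lists, dicts) as requiring grad, instead of
    assuming the batch is a flat tuple of tensors."""
    if isinstance(inputs, torch.Tensor):
        # Only floating-point tensors can carry gradients.
        if inputs.is_floating_point():
            inputs.requires_grad_(True)
    elif isinstance(inputs, (tuple, list)):
        for item in inputs:
            require_grad_on_inputs(item)
    elif isinstance(inputs, dict):
        for item in inputs.values():
            require_grad_on_inputs(item)
```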
@williamFalcon I've been trying to keep up with the plugin API, which is neat! I still think the fairscale integration should live as a native accelerator, because there are more features to come from fairscale, but was curious on your thoughts. EDIT: offline conclusion is living as an accelerator :)
A few updates to track here. Running into issues supporting kwarg input for the forward/backward pass; a solution was discussed here: facebookresearch/fairscale#160 (comment) and here: facebookresearch/fairscale#160 (comment). There also seem to be some performance issues when moving to multi-node (which I haven't been able to test) that Ben has reported: facebookresearch/fairscale#157 (comment). Since I don't think it's worth it longer term to just get the OSS optimizer into Lightning without SDP, we'll put this on hold till the performance issues and function args are solved!
Naming nit: can we name this the sharded DDP accelerator? fairscale as a library will eventually have more components we may want to plug in elsewhere.
@ananthsub I was thinking about that... even ZeRO or DeepSpeed might be better names, but I don't mind just calling it ShardedDDP.
This pull request is now in conflict... :(
What does this PR do?
Closes #817. Lots of related comments in the issue, but overall fairscale has done a great job of taking the initial DeepSpeed code and making a PyTorch module to support the ZeRO optimization feature (the main feature from DeepSpeed, aside from some custom kernel ops + fp16). Thus I think for a V1 we should offer integration with fairscale, and assist in getting the DDP changes for model partitioning (model parallel) + await optimisations.
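Hypothetical end-user view, assuming the integration ends up exposed through the Trainer's plugins argument; the exact flag name and the accelerator-vs-plugin decision were still being discussed in this thread:

```python
from pytorch_lightning import Trainer

# "ddp_sharded" is a placeholder flag name; the final spelling was still
# under discussion when this PR was opened.
model = MyLightningModule()  # any ordinary LightningModule (hypothetical class)
trainer = Trainer(gpus=2, accelerator="ddp", plugins="ddp_sharded")
trainer.fit(model)
```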
Will require additional tests before merging.
I also note that the fairscale install crashes on a remote Ubuntu machine; installing from source, however, runs fine.
cc @blefaudeux and @ananthsub, who mentioned an integration already existing internally with Lightning, so this PR may unify efforts or be unnecessary. If I could get a look over this PR I would really appreciate it as well :)
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃