
Conversation

@SeanNaren
Contributor

@SeanNaren SeanNaren commented Oct 15, 2020

What does this PR do?

Closes #817. There are lots of related comments in the issue, but overall fairscale has done a great job of taking the initial DeepSpeed code and turning it into a PyTorch module that supports the ZeRO optimization feature (the main feature from DeepSpeed, aside from some custom kernel ops + fp16). Thus I think for a V1 we should offer integration with fairscale, assist in getting the DDP changes for model partitioning (model parallel), and await optimisations.

Will require additional tests before merging.

I also note that the fairscale pip install crashes on a remote Ubuntu machine; installing from source, however, runs fine.

cc @blefaudeux and @ananthsub, who mentioned an integration already existing internally with Lightning, so this PR may unify efforts or be unnecessary. If I could get a look over this PR, I would really appreciate it as well :)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@SeanNaren SeanNaren added the feature Is an improvement or enhancement label Oct 15, 2020
@SeanNaren SeanNaren self-assigned this Oct 15, 2020
@SeanNaren SeanNaren requested review from a team and Borda and removed request for Borda October 15, 2020 15:44
@pep8speaks

pep8speaks commented Oct 15, 2020

Hello @SeanNaren! Thanks for updating this PR.

Line 281:86: W292 no newline at end of file

Comment last updated at 2020-11-24 21:16:57 UTC

@SeanNaren
Contributor Author

SeanNaren commented Oct 16, 2020

Going to put some stats here. I've been tracking CUDA memory using a callback inspired by this:

import torch
from pytorch_lightning import Callback


class CUDACallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.global_rank)
        torch.cuda.synchronize(trainer.global_rank)

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.global_rank)
        # Report the peak allocation for this rank, converted from bytes to MiB
        max_memory = torch.cuda.max_memory_allocated(trainer.global_rank) / 2 ** 20

        print(f"[{trainer.global_rank}] : Peak memory {max_memory:.1f}MiB")

Running this transformers script for token classification, I've got some average peak memory stats after the first epoch:

Average peak memory allocated after 1 epoch on p3.8xlarge and p3.16xlarge instance types.

4GPUs DDP: 6840.1MiB
8GPUs DDP: 6840.1MiB
4GPUs FairScale OSS: 5263.2MiB (23% memory improvement compared to DDP)
8GPUs FairScale OSS: 4899.03MiB (28.38% memory improvement compared to DDP)

@blefaudeux

> Going to put some stats here. I've been tracking CUDA memory using a callback inspired by this: [callback code and details quoted above]
>
> 4GPUs DDP: 6840.1MiB
> 8GPUs DDP: 6840.1MiB
> 4GPUs FairScale OSS: 5263.2MiB (23% memory improvement compared to DDP)
> 8GPUs FairScale OSS: 4899.03MiB (28.38% memory improvement compared to DDP)

Nice! FYI there are more savings on the way. It would require wrapping the model and not using DDP (an autograd hook does the job), but that solves the gradient-reduce issue and reduces the communication volume, which is important when multiple nodes are involved. It does not deprecate the version you tested, it just comes on top: if OSS is used with DDP you get what you measured; if OSS is used with a model wrap instead, you get better gradient sharding.

@SeanNaren
Contributor Author

Thanks @blefaudeux! Was excited to get numbers for this, really awesome work overall :)

Are the additional savings only tied to fairscale? Is there a place I can have a look at the code? It would be good to get a head start on figuring out API changes to support this.

@blefaudeux

> Thanks @blefaudeux! Was excited to get numbers for this, really awesome work overall :)
>
> Are the additional savings only tied to fairscale? Is there a place I can have a look at the code? It would be good to get a head start on figuring out API changes to support this.

No PR yet, the work is in https://github.com/facebookresearch/fairscale/tree/oss_autograd
The interface could be:

model = myAwesomeModel()
optimizer = OSS(*optimizer_params)
model = ShardedDDP(model, optimizer, *some_basic_params)

and then a normal training loop with optimizer and model. Gradients are automatically reduced to the right rank and the optimizer state and gradients are sharded, which shaves some more memory from what you have right now.

What do you think?
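
For illustration, a minimal sketch of what a plain training loop on top of this interface might look like. The import paths, the OSS signature, and the model/dataloader names are assumptions (loosely based on later fairscale releases), not the exact API of the oss_autograd branch:

import torch
import torch.nn.functional as F
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

model = MyAwesomeModel()  # hypothetical user-defined nn.Module, already on the right device
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=1e-3)  # optimizer state sharded across ranks
model = ShardedDDP(model, optimizer)

for batch, target in train_loader:  # hypothetical dataloader
    optimizer.zero_grad()
    loss = F.cross_entropy(model(batch), target)
    loss.backward()   # each gradient is reduced to the rank that owns it, instead of all-reduced
    optimizer.step()  # each rank updates its shard of the optimizer state, then parameters are synced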

@SeanNaren
Contributor Author

Yep, the interface makes sense to me. I also see the model dispatch making its way into this branch, which is super cool!

I don't think it makes sense to support both DDP + OSS and Sharded DDP + OSS, since we'll need to install fairscale regardless and Sharded DDP seems like the successor. I think it makes sense to go forward with ShardedDDP and have it replace the current implementation. Thoughts?

@blefaudeux

blefaudeux commented Oct 16, 2020

> Yep, the interface makes sense to me. I also see the model dispatch making its way into this branch, which is super cool!
>
> I don't think it makes sense to support both DDP + OSS and Sharded DDP + OSS, since we'll need to install fairscale regardless and Sharded DDP seems like the successor. I think it makes sense to go forward with ShardedDDP and have it replace the current implementation. Thoughts?

It may require some benchmarking. Basically, one feature from DDP which is hard to replicate is the overlap between BW and reduce: in DDP the gradients are all-reduced step by step while walking back the graph, concurrently with the BW computations. Currently what's "easy" to do is FW, then BW, then reduce (not all-reduce), but the overlap is lost. When we shard the model (next steps, a little more involved) it's not that much of an issue, because we can overlap the reduce of the lower shard with the BW of the upper one.

From what I can see, currently (state+gradient sharding):

  • on a single node OSS + DDP is a little faster than OSS+ad-hoc-reduce but consumes more memory
  • when multiple nodes are involved, then OSS+ad-hoc-reduce wins it all (faster and less memory)

cc @msbaines and @mrshenli in case you're interested
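
To make the overlap point concrete, here is a generic sketch (not fairscale or DDP internals) of how an autograd hook can start reducing a gradient to its owning rank as soon as it is ready, while the backward pass continues on earlier layers. The `owner_of` mapping is hypothetical:

import torch.distributed as dist

def attach_reduce_hooks(model, owner_of):
    # `owner_of(param)` is assumed to return the rank that owns this parameter's optimizer state.
    handles = []

    def make_hook(param):
        def hook(grad):
            # Fires as soon as this parameter's gradient is computed, so the
            # communication overlaps with the rest of the backward pass.
            handles.append(dist.reduce(grad, dst=owner_of(param), async_op=True))
            return grad
        return hook

    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(make_hook(p))
    return handles  # wait on these handles before optimizer.step()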

@SeanNaren
Contributor Author

Yeah, I'm still not sure it's worth supporting the DDP version. Keeping a DDP Lightning wrapper just for a small speed benefit on single-node setups, where the total GPU count is small (and personally I think reducing memory allocation matters more), doesn't seem necessary and adds confusion, unless I'm mistaken!

@SeanNaren SeanNaren changed the title Introduce fairscale accelerator [WIP] Introduce fairscale accelerator Oct 18, 2020
@SeanNaren
Contributor Author

SeanNaren commented Oct 18, 2020

Pushing WIP changes integrating ShardedDDP using the oss_autograd fairscale branch, in a similar fashion to vanilla DDP. Unfortunately I had to do quite a bit of overriding, since the model requires multiple arguments (not just the input to the model, but multiple inputs/targets for the forward/loss calculation).

Running into an issue using multiple GPUs where training hangs; currently investigating this. Also seeing if there is a nicer way to handle requiring grads for the input. The main issue is that I've made the assumption (for now) that the inputs are tuples within the batch, i.e. within ModelDispatch:

# All inputs need requires_grad=True for autograd to properly track the first dispatch layer
# Will currently break if the batch is a dict or similar; may require a recursive check
if isinstance(inputs, tuple):
    for i in inputs:
        i.requires_grad = True
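
For reference, a hypothetical recursive variant of the check above (not part of this PR, just a sketch of how nested containers could be handled):

import torch

def require_grad_on_inputs(inputs):
    # Recursively flag floating point tensors so autograd tracks the first dispatch layer.
    if isinstance(inputs, torch.Tensor):
        if inputs.is_floating_point():
            inputs.requires_grad_(True)
    elif isinstance(inputs, (tuple, list)):
        for item in inputs:
            require_grad_on_inputs(item)
    elif isinstance(inputs, dict):
        for item in inputs.values():
            require_grad_on_inputs(item)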

@blefaudeux after investigation I've noticed that it only hangs when using torch automatic mixed precision (autocast) with OSS + SDP. Any reason off the top of your head why this could happen? Will continue to investigate, but curious if you have any suggestions!

@SeanNaren
Contributor Author

SeanNaren commented Oct 25, 2020

@williamFalcon I've been trying to keep up with the plugin API, which is neat! I still think the fairscale integration should live as a native accelerator, because there are more features to come from fairscale, but I was curious about your thoughts.

EDIT: offline conclusion is living as an accelerator :)
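
For a sense of how that decision could surface to users, a hypothetical sketch (the accelerator name here is a placeholder, not a settled flag):

from pytorch_lightning import Trainer

# "ddp_sharded" is a placeholder name, not the final user-facing flag.
trainer = Trainer(gpus=8, accelerator="ddp_sharded")
trainer.fit(model)  # `model` is a regular LightningModule; no model code changes needed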

@SeanNaren
Contributor Author

A few updates to track here: running into issues supporting kwarg inputs for the forward/backward pass. It seems a solution was discussed here: facebookresearch/fairscale#160 (comment) and here: facebookresearch/fairscale#160 (comment)

There also seem to be some performance issues when moving to multi-node, which Ben has reported (facebookresearch/fairscale#157 (comment)) but which I haven't been able to test.

Since I don't think it's worth it longer term to just get the OSS optimizer into Lightning without SDP, we'll put this on hold until the performance issues and function args are solved!

@ananthsub
Contributor

> @williamFalcon I've been trying to keep up with the plugin API, which is neat! I still think the fairscale integration should live as a native accelerator, because there are more features to come from fairscale, but I was curious about your thoughts.

Naming nit: can we name this the sharded DDP accelerator? fairscale as a library will eventually have more components we may want to plug in elsewhere.

@SeanNaren
Contributor Author

@ananthsub I was thinking about that... even ZeRO or DeepSpeed might be better names, but I don't mind just calling it ShardedDDP.

@SeanNaren SeanNaren changed the title [WIP] Introduce fairscale accelerator [WIP] Introduce Sharded Accelerator Nov 1, 2020
SeanNaren and others added 8 commits November 18, 2020 11:00
@Borda Borda marked this pull request as ready for review November 30, 2020 17:57
@Borda Borda changed the title [WIP] Introduce Sharded Plugin Sharded Plugin Nov 30, 2020
@edenlightning edenlightning removed this from the 1.1 milestone Dec 8, 2020
@mergify
Contributor

mergify bot commented Dec 12, 2020

This pull request is now in conflict... :(

@SeanNaren SeanNaren closed this Dec 12, 2020
@SeanNaren SeanNaren deleted the feature/817-fairscale branch December 12, 2020 17:25