
Conversation

@jstjohn (Contributor) commented Jul 1, 2022

What does this PR do?

Possibly fixes #13500, needs more testing.

Does your PR introduce any breaking changes? If yes, please list them.

Potential breaking changes (unsure, because behavior looks OK on my single-node, multi-GPU setup):

  • In the FSDP code block below, this change may alter the interaction between the optimizer, the model, and sharding. The model does train and its weights are optimized, but I do not have any large-language-model-scale tests to confirm that sharding still works correctly. Prior to this change, the outer shard wrapper was apparently never unwrapped while the inner wrappers were. It is worth checking that these changes do not break sharding for large language models; that said, the current behavior is arguably worse, since the checkpoints themselves are likely corrupted in some way. (A minimal unwrapping sketch follows the code block.)
    def configure_ddp(self) -> None:
        log.detail(f"{self.__class__.__name__}: configuring FSDP... (cpu_offload: [{self.cpu_offload}])")
        if not self.cpu_offload:
            # When using CPU Offload, FSDP will manage the CUDA movement for us.
            # Note: this would be problematic for large model (which could not fit in one GPU)
            # as FSDP module.to(device) would first summon all parameters
            # (TODO: need to figure out solution)
            self.model_to_device()
        # setup optimizers after fully sharded has wrapped the lightning module
        self.setup_optimizers(self.lightning_module.trainer)
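
For context, the fix is conceptually just recursive unwrapping of the strategy's model before its state_dict is taken for the checkpoint. The snippet below is a minimal, hypothetical sketch of that idea, not the helper touched by this PR; it assumes that both torch's DistributedDataParallel and fairscale's FullyShardedDataParallel expose the wrapped module via a .module attribute (worth double-checking for fairscale, where the inner module can additionally sit behind a FlattenParamsWrapper).

    from fairscale.nn import FullyShardedDataParallel
    from torch.nn.parallel import DistributedDataParallel

    def unwrap_model(module):
        """Strip DDP/FSDP wrappers so state_dict keys match the plain module."""
        # Hypothetical helper: peel wrappers until the user's module is reached.
        # Assumes each wrapper exposes the inner module as `.module`.
        while isinstance(module, (DistributedDataParallel, FullyShardedDataParallel)):
            module = module.module
        return module

    # e.g. checkpoint["state_dict"] = unwrap_model(strategy_model).state_dict()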

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

… around not unwrapping FullyShardedDataParallel wrapped models
@awaelchli added the labels "bug" (Something isn't working) and "strategy: fairscale fsdp (removed)" (Fully Sharded Data Parallel) on Jul 1, 2022
@awaelchli (Contributor) commented:

@jstjohn Thanks for looking into this.
The unwrapping logic is a bit overcomplicated and, as you found, a source of bugs. I propose to remove it completely and simplify it greatly. Here is the proposal: #13502. I'm not sure whether it would solve your bug as well, but I wanted to point out the related issue here.

@jstjohn (Contributor, Author) commented Jul 5, 2022

@awaelchli your solution seems great. Should I close this PR and let you take this on with that approach? I am also wondering whether any of the existing tests should have caught this. I see tests in place to make sure that weights match after checkpointing, but I don't know whether they cover the parallel case, and it also looks like they only run on PyTorch 1.12.
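
For what it's worth, the coverage I have in mind would look roughly like the hypothetical sketch below. It is not an existing Lightning test, and the strategy name, BoringModel import path, and checkpoint layout are assumptions: train briefly on two GPUs with the fully sharded strategy, save a checkpoint, and assert that the saved state_dict keys match those of a plain, unwrapped module (i.e. no FSDP wrapper prefixes leak into the checkpoint).

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel  # assumed import path

    def test_fsdp_checkpoint_has_unwrapped_keys(tmp_path):
        model = BoringModel()
        trainer = Trainer(
            accelerator="gpu",
            devices=2,
            strategy="fsdp",  # assumed registered name for the fairscale strategy
            max_steps=2,
            default_root_dir=tmp_path,
        )
        trainer.fit(model)
        ckpt_path = str(tmp_path / "after_fit.ckpt")
        trainer.save_checkpoint(ckpt_path)

        if trainer.global_rank == 0:
            ckpt = torch.load(ckpt_path)
            # Keys should refer to the plain LightningModule, with no wrapper
            # prefixes such as "_fsdp_wrapped_module" in the checkpoint.
            assert set(ckpt["state_dict"]) == set(BoringModel().state_dict())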

@jstjohn (Contributor, Author) commented Jul 14, 2022

Closing this PR since I like the solution outlined in #13502 better, and will let @awaelchli execute on that.

@jstjohn closed this Jul 14, 2022


Development

Successfully merging this pull request may close these issues.

FullyShardedDataParallel wrapped models not being unwrapped, leading to incorrect checkpoints.
