Fix: corrected fsdp in GRPO trainer #3582
base: main
Conversation
Pull Request Overview
This PR fixes a bug in GRPOTrainer related to FSDP parameter syncing with vLLM, avoiding assertion errors by refactoring the synchronization logic.
- Refactored _sync_fsdp_params_to_vllm to restrict FSDP.summon_full_params with recurse=True to only the root FSDP module.
- Revised recursive traversal to ensure proper syncing of parameters without reprocessing submodules.
Comments suppressed due to low confidence (1)
trl/trainer/grpo_trainer.py:884
- Using getattr(module, '_is_root', True) defaults to True, which may inadvertently treat modules without the '_is_root' attribute as root FSDP modules. Consider ensuring that the attribute is explicitly set for non-root FSDP modules to avoid potential unintended behavior.
if isinstance(module, FSDP) and getattr(module, '_is_root', True):
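One way to address this concern, sketched below under the assumption that the whole policy is wrapped in a single top-level FSDP module (`is_root_fsdp` is a hypothetical helper, not part of the PR), is to identify the root wrapper explicitly instead of relying on the permissive default:

```python
# Hedged sketch: detect the outermost FSDP wrapper explicitly rather than
# relying on getattr(module, "_is_root", True) and its default of True.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def is_root_fsdp(module, model):
    """Return True only if `module` is the outermost FSDP wrapper in `model`."""
    if not isinstance(module, FSDP):
        return False
    # nn.Module.modules() walks the tree pre-order, so the first FSDP
    # instance encountered is the outermost (root) wrapper.
    for candidate in model.modules():
        if isinstance(candidate, FSDP):
            return candidate is module
    return False
```

With a helper like this, the traversal can check `is_root_fsdp(module, model)` and avoid treating an FSDP module that merely lacks the `_is_root` attribute as the root.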
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
would you mind fixing the formatting issue by doing:
Sure! I'm done with it!
Hi @tryumanshow thanks for your contribution!
Here
Hi, @shirinyamani. Yes, I used the script from
@tryumanshow just to replicate can you also paste your
I used the
I had the same error and this fix worked for me.
What does this PR do?
Fixes a bug where GRPOTrainer using Fully Sharded Data Parallel (FSDP) with vLLM inference fails during parameter sync with:
AssertionError: Non-root FSDP instance's _is_root should not have been set yet or should have been set to False
Fixes: #3394 (🧑🤝🧑 Co-Locating vLLM w/ training for higher throughput and GPU utilization)
Root Cause
When syncing FSDP parameters to vLLM, summon_full_params was recursively called for every FSDP submodule, causing PyTorch's FSDP internal state to become inconsistent. PyTorch expects only the root FSDP module to perform summon_full_params(recurse=True).
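For illustration, a hedged sketch (not the previous implementation verbatim) of the per-submodule pattern described above; `push_to_vllm` is a hypothetical callable standing in for the actual weight transfer:

```python
# Illustrative anti-pattern (sketch only): summoning full parameters on every
# FSDP submodule during a recursive walk, which can trigger the AssertionError
# quoted above.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def sync_per_submodule(model, push_to_vllm):
    for prefix, module in model.named_modules():
        if isinstance(module, FSDP):
            # Nested summon on a non-root FSDP instance conflicts with the
            # root-only expectation of FSDP's lazy initialization.
            with FSDP.summon_full_params(module, recurse=True):
                for name, param in module.named_parameters(recurse=False):
                    full_name = f"{prefix}.{name}" if prefix else name
                    push_to_vllm(full_name, param.data)
```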
Improvements:
Refactored GRPOTrainer._sync_fsdp_params_to_vllm to call FSDP.summon_full_params once at the root (recurse=True), instead of recursively calling it for each FSDP submodule; a sketch of this approach is shown below.
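A minimal sketch of the root-only approach, assuming a single top-level FSDP-wrapped model and a hypothetical `push_to_vllm(name, tensor)` callable in place of the trainer's actual vLLM weight-update path:

```python
# Sketch of the fix: one summon_full_params at the root with recurse=True,
# then a single pass over the gathered parameters. Not the exact PR code.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def sync_fsdp_params_to_vllm(model, push_to_vllm):
    with FSDP.summon_full_params(model, recurse=True):
        for name, param in model.named_parameters():
            # FSDP inserts "_fsdp_wrapped_module." into parameter names;
            # strip it so names line up with the unwrapped vLLM model.
            clean_name = name.replace("_fsdp_wrapped_module.", "")
            push_to_vllm(clean_name, param.data)
```

Summoning once at the root gathers full parameters for the entire module tree, so no per-submodule traversal (and no extra summon calls) is needed.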
Testing
Click to view fsdp.yaml
vllm server mode
click to view config.yaml
click to view test images
GPU Occupation
Training Log
vllm colocate mode
click to view config.yaml
# rollout & train CUDA_VISIBLE_DEVICES=0,1,2,3 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file ./accelerate_configs/fsdp.yaml --num_processes 4 open_r1.grpo --config config.yaml
click to view test images
GPU Occupation
Training Log
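For context, a hedged sketch of the trainer-side setup that the colocate test exercises. The model name, dataset, and toy reward are placeholders, and GRPOConfig fields such as use_vllm / vllm_mode are assumed from recent TRL releases, so check them against the installed version:

```python
# Minimal GRPO + vLLM colocate sketch, launched with the accelerate/FSDP
# command shown above. Assumes a TRL version whose GRPOConfig exposes
# use_vllm and vllm_mode; the reward function is a toy example.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions about 50 characters long.
    return [-abs(50 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="grpo-fsdp-vllm-colocate",
    use_vllm=True,
    vllm_mode="colocate",  # run vLLM inside the training processes
    per_device_train_batch_size=4,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=training_args,
    reward_funcs=reward_len,
    train_dataset=dataset,
)
trainer.train()
```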
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
New version of #3394 (partially modified the FSDP section)
CC @qgallouedec