Fix: corrected fsdp in GRPO trainer #3582


Open · wants to merge 22 commits into main

Conversation


@tryumanshow commented Jun 13, 2025

What does this PR do?

Fixes a bug where GRPOTrainer using Fully Sharded Data Parallel (FSDP) with vLLM inference fails with AssertionError: Non-root FSDP instance's _is_root should not have been set yet or should have been set to False during parameter sync.

Fixes: #3394 (🧑‍🤝‍🧑 Co-Locating vLLM w/ training for higher throughput and GPU utilization)

Root Cause

When syncing FSDP parameters to vLLM, summon_full_params was recursively called for every FSDP submodule, causing PyTorch's FSDP internal state to become inconsistent. PyTorch expects only the root FSDP module to perform summon_full_params(recurse=True).
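
For context, the problematic pattern roughly has the following shape (a simplified sketch, not the exact previous implementation): full parameters are gathered once per FSDP-wrapped submodule, which interferes with FSDP's lazy root initialization and can raise the assertion quoted above.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def gather_params_per_submodule(model):
    # Simplified illustration of the old pattern: summon_full_params is
    # entered separately for each FSDP-wrapped submodule instead of once at
    # the root, which can trip the `_is_root` assertion described above.
    gathered = {}
    for module_name, module in model.named_modules():
        if isinstance(module, FSDP):
            with FSDP.summon_full_params(module, recurse=False, writeback=False):
                for param_name, param in module.named_parameters(recurse=False):
                    gathered[f"{module_name}.{param_name}"] = param.data.clone()
    return gathered
```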

Improvements:

  • Refactors GRPOTrainer._sync_fsdp_params_to_vllm to call FSDP.summon_full_params once at the root (recurse=True), instead of recursively calling it for each FSDP submodule (see the sketch after this list).
  • Prevents assertion errors in multi-GPU FSDP training with vLLM parameter syncing.
  • Ensures memory-efficient, correct traversal for parameter extraction and weight updates to vLLM.
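
A minimal sketch of the intended pattern follows; the actual code in this PR may differ in details, and `update_fn` and the prefix handling are illustrative assumptions, not part of the PR.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def sync_fsdp_params_to_vllm(model, update_fn):
    # Core of the fix: summon_full_params is entered exactly once, on the
    # root FSDP module, with recurse=True, instead of once per submodule.
    if not isinstance(model, FSDP):
        raise TypeError("expected the root module to be wrapped in FSDP")

    with FSDP.summon_full_params(model, recurse=True, writeback=False):
        for name, param in model.named_parameters():
            # FSDP inserts wrapper prefixes into parameter names; strip them
            # so the names match the keys vLLM expects.
            clean_name = name.replace("_fsdp_wrapped_module.", "")
            # `update_fn(name, tensor)` is a placeholder for whatever pushes a
            # single named weight to vLLM (the server client in server mode,
            # or the colocated engine's weight loader).
            update_fn(clean_name, param.data)
```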

Testing

Click to view fsdp.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: false
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

vLLM server mode

click to view config.yaml
# Model arguments
model_name_or_path: Qwen/Qwen2.5-3B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_config: default
dataset_prompt_column: problem
system_prompt: "You are a helpful AI Assistant, designed to provide well-reasoned and detailed responses. You FIRST think about the reasoning process as an internal monologue and then provide the user with the answer. The reasoning process MUST BE enclosed within <think> and </think> tags."

# GRPO trainer config
bf16: true
use_vllm: true
do_eval: false
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 3.0e-06
log_completions: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: 50
num_generations: 16
num_train_epochs: 1
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 16
push_to_hub: false
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: steps
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1

# rollout
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-3B-Instruct

# train
CUDA_VISIBLE_DEVICES=1,2,3,4 ACCELERATE_LOG_LEVEL=info  accelerate launch --config_file ./accelerate_configs/fsdp.yaml --num_processes 4  open_r1.grpo --config config.yaml

click to view test images

  • GPU Occupation (image)
  • Training Log (image)

vLLM colocate mode

click to view config.yaml
# Model arguments
model_name_or_path: Qwen/Qwen2.5-3B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_config: default
dataset_prompt_column: problem
system_prompt: "You are a helpful AI Assistant, designed to provide well-reasoned and detailed responses. You FIRST think about the reasoning process as an internal monologue and then provide the user with the answer. The reasoning process MUST BE enclosed within <think> and </think> tags."

# GRPO trainer config
bf16: true
use_vllm: true
vllm_mode: "colocate"
vllm_tensor_parallel_size: 4
vllm_gpu_memory_utilization: 0.3
do_eval: false
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 3.0e-06
log_completions: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: 50
num_generations: 16
num_train_epochs: 1
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 16
push_to_hub: false
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: steps
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1

# rollout & train
CUDA_VISIBLE_DEVICES=0,1,2,3 ACCELERATE_LOG_LEVEL=info  accelerate launch --config_file ./accelerate_configs/fsdp.yaml --num_processes 4  open_r1.grpo --config config.yaml

click to view test images

  • GPU Occupation (image)
  • Training Log (image)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

This is a new version of #3394 (with the FSDP section partially modified).

CC @qgallouedec

@tryumanshow marked this pull request as draft June 13, 2025 19:36
@tryumanshow marked this pull request as ready for review June 13, 2025 19:58
@kashif requested a review from Copilot June 21, 2025 06:19

@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR fixes a bug in GRPOTrainer related to FSDP parameter syncing with vLLM, avoiding assertion errors by refactoring the synchronization logic.

  • Refactored _sync_fsdp_params_to_vllm to restrict FSDP.summon_full_params with recurse=True to only the root FSDP module.
  • Revised recursive traversal to ensure proper syncing of parameters without reprocessing submodules.
Comments suppressed due to low confidence (1)

trl/trainer/grpo_trainer.py:884

  • Using getattr(module, '_is_root', True) defaults to True, which may inadvertently treat modules without the '_is_root' attribute as root FSDP modules. Consider ensuring that the attribute is explicitly set for non-root FSDP modules to avoid potential unintended behavior.
        if isinstance(module, FSDP) and getattr(module, '_is_root', True):
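
As an illustration of Copilot's point, a stricter check could look like the following hypothetical helper (not code from this PR); note that `_is_root` is an internal FSDP attribute.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def is_root_fsdp(module) -> bool:
    # Hypothetical stricter variant of the flagged check: only treat a module
    # as the FSDP root once FSDP has explicitly marked it as such, instead of
    # defaulting to True when `_is_root` has not been set yet.
    return isinstance(module, FSDP) and getattr(module, "_is_root", None) is True
```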

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif (Collaborator) commented Jun 23, 2025

would you mind fixing the formatting issue by doing: make precommit in the root of the TRL repo?

@tryumanshow (Author)

> would you mind fixing the formatting issue by doing: make precommit in the root of the TRL repo?

Sure! I'm done with it!

@shirinyamani self-requested a review June 24, 2025 15:09
@shirinyamani (Member)

Hi @tryumanshow, thanks for your contribution!
I want to test your PR, and to do that I want to make sure I understand the flow of your work correctly.

click to view config.yaml

# rollout
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-3B-Instruct

# train
CUDA_VISIBLE_DEVICES=1,2,3,4 ACCELERATE_LOG_LEVEL=info  accelerate launch --config_file ./accelerate_configs/fsdp.yaml --num_processes 4  open_r1.grpo --config config.yaml

Here, with open_r1.grpo, are you using the GRPO script from open-r1 as your training script?

@tryumanshow (Author)

> Hi @tryumanshow, thanks for your contribution! I want to test your PR, and to do that I want to make sure I understand the flow of your work correctly.
>
> click to view config.yaml
>
> # rollout
> CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-3B-Instruct
>
> # train
> CUDA_VISIBLE_DEVICES=1,2,3,4 ACCELERATE_LOG_LEVEL=info  accelerate launch --config_file ./accelerate_configs/fsdp.yaml --num_processes 4  open_r1.grpo --config config.yaml
>
> Here, with open_r1.grpo, are you using the GRPO script from open-r1 as your training script?

Hi, @shirinyamani.

Yes, I used the script from open-r1!

@kashif (Collaborator) commented Jun 28, 2025

@tryumanshow just to replicate, can you also paste your trl env?

@tryumanshow (Author)

> @tryumanshow just to replicate, can you also paste your trl env?

I used trl==0.18.1.

@mcleish7

I had the same error and this fix worked for me.
