Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions apps/grpo/qwen3_30b_a3b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Grouped Relative Policy Optimization (GRPO) for Qwen3 30B-A3B
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_30b_a3b.yaml

# Global configuration
# group_size feeds policy.sampling_params.n below: number of completions
# sampled per prompt for the GRPO group-relative advantage.
group_size: 4 # Reduced for initial testing to avoid OOM
local_batch_size: 1 # per-device batch size (reduced for 30B MoE model to avoid OOM)
max_req_tokens: 512 # Reduced for initial testing
max_res_tokens: 512 # Reduced for initial testing
model: "Qwen/Qwen3-30B-A3B"
# Used as replay_buffer.max_policy_age below: how many versions stale a
# rollout's policy may be before it is dropped.
off_by_n: 1 # Off by one by default

# GPU allocation for single-node (8 GPUs total):
# - Trainer: 4 GPUs (EP=4 for MoE experts)
# - Policy: 2 GPUs (EP=2 for MoE experts)
# - Ref Model: 2 GPUs (EP=2 for MoE experts)

# Main loop configuration
rollout_threads: 1 # Recommended to set equal to policy.num_replicas

# Observability configuration
metric_logging:
  wandb:
    project: grpo-training
    # ${oc.env:USER} is the OmegaConf env resolver: run groups are keyed by
    # the launching user's $USER environment variable.
    group: qwen3_30b_a3b_exp_${oc.env:USER}
    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
  console:
    logging_mode: global_reduce

# Dataset configuration
dataset:
  path: "openai/gsm8k" # HuggingFace dataset repo id
  revision: "main"
  data_split: "train"
  streaming: true # stream samples instead of downloading the full dataset
  # presumably used to load the matching tokenizer/chat template for this
  # model — TODO confirm against the dataset actor implementation
  model: ${model}

# Policy configuration (uses vLLM for generation)
policy:
  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
    model: ${model}
    # Must match services.policy.procs (2) so each replica gets one GPU per
    # tensor-parallel rank.
    tensor_parallel_size: 2 # 2 GPUs for policy
    pipeline_parallel_size: 1
    enable_expert_parallel: true # Enable expert parallelism for MoE
    enforce_eager: false # allow CUDA graph capture
  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
    n: ${group_size} # one GRPO group = n completions per prompt
    max_tokens: ${max_res_tokens} # cap on generated (response) tokens
    temperature: 1.0
    top_p: 1.0 # no nucleus truncation — sample from the full distribution

# Trainer configuration
trainer:
  model:
    name: qwen3
    flavor: 30B-A3B
    hf_assets_path: hf://${model}
  optimizer:
    name: AdamW
    lr: 8e-5
    eps: 1e-8
  lr_scheduler:
    warmup_steps: 100
    decay_ratio: 0.8 # fraction of total steps spent decaying
    decay_type: "linear"
    min_lr_factor: 0.0 # decay all the way to zero
  training:
    local_batch_size: ${local_batch_size}
    # NOTE: ${sum:...} is an app-registered OmegaConf resolver, not a
    # built-in — verify it is registered before config load.
    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
    max_norm: 1.0 # gradient clipping norm
    steps: 1000000
    dtype: bfloat16
    gc_freq: 1 # run garbage collection every step
  compile:
    enable: false
    components: ["model", "loss"]
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 4 # Must satisfy: dp_replicate * dp_shard * cp * tp * pp = world_size (4)
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 4 # EP borrows from dp_shard for MoE
    expert_tensor_parallel_degree: 1
    disable_loss_parallel: true
  checkpoint:
    enable: true
    folder: ./checkpoint # The folder to save checkpoints to.
    initial_load_path: hf://${model}
    initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
    last_save_in_hf: true # export the final checkpoint in HF format
    interval: 500 # save every 500 steps
    async_mode: "disabled"
  activation_checkpoint:
    mode: selective
    selective_ac_option: op # checkpoint at op granularity, not per-layer
  # float8 quantization for dense linears and MoE grouped matmuls;
  # router gate and output head are excluded via filter_fqns.
  quantize:
    linear:
      float8:
        enable_fsdp_float8_all_gather: false
        precompute_float8_dynamic_scale_for_fsdp: false
        recipe_name: "rowwise"
        filter_fqns: ["output", "router.gate"]
    grouped_mm:
      float8:
        fqns: ["experts"]
  comm:
    trace_buf_size: 0 # disable flight-recorder trace buffer

# Replay buffer configuration
replay_buffer:
  batch_size: ${local_batch_size}
  max_policy_age: ${off_by_n} # drop rollouts older than this many policy versions
  # Must equal the trainer's total data-parallel degree so batches are
  # sharded one-per-DP-rank.
  dp_size: 4 # Total DP degree: dp_replicate * dp_shard = 1 * 4 = 4

# Reference model configuration
ref_model:
  model:
    name: qwen3
    flavor: 30B-A3B
    hf_assets_path: hf://${model}
  training:
    # Inference only — must cover the same prompt+response length the
    # trainer uses.
    seq_len: ${trainer.training.seq_len}
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1 # Must satisfy: dp_replicate * dp_shard * cp * tp * pp = world_size (2)
    tensor_parallel_degree: 2
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    # NOTE(review): dp_shard is 1 here, so unlike the trainer EP cannot
    # borrow from dp_shard; with etp=1 and tp=2, EP=2 presumably spans the
    # TP ranks — confirm against torchtitan's MoE parallelism rules.
    expert_parallel_degree: 2 # EP borrows from dp_shard for MoE
    expert_tensor_parallel_degree: 1
  checkpoint:
    enable: true # load-only; no training state is saved for the ref model
    initial_load_path: hf://${model}
    initial_load_in_hf: true
  # Mirrors trainer.quantize so ref-model logprobs are computed under the
  # same float8 numerics as the trained policy.
  quantize:
    linear:
      float8:
        enable_fsdp_float8_all_gather: false
        precompute_float8_dynamic_scale_for_fsdp: false
        recipe_name: "rowwise"
        filter_fqns: ["output", "router.gate"]
    grouped_mm:
      float8:
        fqns: ["experts"]
  comm:
    trace_buf_size: 0 # disable flight-recorder trace buffer

# All resource allocations
# procs must agree with the parallelism degrees configured above:
# policy procs = tensor_parallel_size (2), ref_model procs = its world
# size (2), trainer procs = its world size (4). Total GPUs: 2+2+4 = 8.
services:
  policy:
    procs: 2 # 2 GPUs for policy with expert parallelism
    num_replicas: 1 # rollout_threads above is recommended to match this
    mesh_name: policy
    with_gpus: true
  ref_model:
    procs: 2 # 2 GPUs for reference model with expert parallelism
    num_replicas: 1
    mesh_name: ref_model
    with_gpus: true
  reward_actor:
    procs: 1 # CPU-only reward scoring
    num_replicas: 1
    mesh_name: reward_actor
    with_gpus: false

actors:
  dataset:
    procs: 1
    with_gpus: false
    mesh_name: dataset
  trainer:
    procs: 4 # 4 GPUs for trainer with expert parallelism
    with_gpus: true
    mesh_name: trainer
  replay_buffer:
    procs: 1
    with_gpus: false
    mesh_name: replay_buffer
  compute_advantages:
    procs: 1
    with_gpus: false
    mesh_name: compute_advantages
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies = [
# PyTorch
"torch==2.9.0",
"torchdata>=0.8.0",
"torchtitan==0.2.0",
"torchtitan==0.1.0.dev20251029",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed for the config to work?

"torchmonarch==0.1.2",
"torchstore==0.1.2",
# vLLM
Expand Down Expand Up @@ -83,13 +83,19 @@ members = [
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"

# pytorch nightly
[[tool.uv.index]]
name = "pytorch-nightly-cu128"
url = "https://download.pytorch.org/whl/nightly/cu128"

# vllm
[[tool.uv.index]]
name = "vllm-forge"
url = "https://download.pytorch.org/whl/preview/forge"

[tool.uv.sources]
torch = { index = "pytorch-cu128" }
torchtitan = { index = "pytorch-nightly-cu128" }
vllm = { index = "vllm-forge" }

[tool.uv]
Expand Down
3 changes: 3 additions & 0 deletions src/forge/actors/reference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Compile,
Model,
Parallelism,
Quantize,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want these changes to be part of this PR?

Training,
)
from torchtitan.experiments.forge.engine import ForgeEngine
Expand Down Expand Up @@ -61,6 +62,7 @@ class ReferenceModel(ForgeActor):
(TP, PP, CP, DP)
checkpoint (Checkpoint): Checkpoint loading configuration
compile (Compile): Torch compilation settings
quantize (Quantize): Quantization settings (float8, etc.)
comm (Comm): Communication backend configuration
training (Training): Training-related settings (dtype, garbage
collection, etc.)
Expand All @@ -71,6 +73,7 @@ class ReferenceModel(ForgeActor):
parallelism: Parallelism = field(default_factory=Parallelism)
checkpoint: Checkpoint = field(default_factory=Checkpoint)
compile: Compile = field(default_factory=Compile)
quantize: Quantize = field(default_factory=Quantize)
comm: Comm = field(default_factory=Comm)
training: Training = field(
default_factory=Training
Expand Down
Loading