diff --git a/apps/grpo/qwen3_30b_a3b.yaml b/apps/grpo/qwen3_30b_a3b.yaml
new file mode 100644
index 000000000..cf155c2b7
--- /dev/null
+++ b/apps/grpo/qwen3_30b_a3b.yaml
@@ -0,0 +1,186 @@
+# Grouped Relative Policy Optimization (GRPO) for Qwen3 30B-A3B
+# >>> python -m apps.grpo.main --config apps/grpo/qwen3_30b_a3b.yaml
+
+# Global configuration
+group_size: 4 # Reduced for initial testing to avoid OOM
+local_batch_size: 1 # per-device batch size (reduced for 30B MoE model to avoid OOM)
+max_req_tokens: 512 # Reduced for initial testing
+max_res_tokens: 512 # Reduced for initial testing
+model: "Qwen/Qwen3-30B-A3B"
+off_by_n: 1 # Off by one by default
+
+# GPU allocation for single-node (8 GPUs total):
+#   - Trainer: 4 GPUs (EP=4 for MoE experts)
+#   - Policy: 2 GPUs (EP=2 for MoE experts)
+#   - Ref Model: 2 GPUs (EP=2 for MoE experts)
+
+# Main loop configuration
+rollout_threads: 1 # Recommended to set equal to policy.num_replicas
+
+# Observability configuration
+metric_logging:
+  wandb:
+    project: grpo-training
+    group: qwen3_30b_a3b_exp_${oc.env:USER}
+    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+  console:
+    logging_mode: global_reduce
+
+# Dataset configuration
+dataset:
+  path: "openai/gsm8k"
+  revision: "main"
+  data_split: "train"
+  streaming: true
+  model: ${model}
+
+# Policy configuration (uses vLLM for generation)
+policy:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
+    model: ${model}
+    tensor_parallel_size: 2 # 2 GPUs for policy
+    pipeline_parallel_size: 1
+    enable_expert_parallel: true # Enable expert parallelism for MoE
+    enforce_eager: false
+  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
+    n: ${group_size}
+    max_tokens: ${max_res_tokens}
+    temperature: 1.0
+    top_p: 1.0
+
+# Trainer configuration
+trainer:
+  model:
+    name: qwen3
+    flavor: 30B-A3B
+    hf_assets_path: hf://${model}
+  optimizer:
+    name: AdamW
+    lr: 8e-5
+    eps: 1e-8
+  lr_scheduler:
+    warmup_steps: 100
+    decay_ratio: 0.8
+    decay_type: "linear"
+    min_lr_factor: 0.0
+  training:
+    local_batch_size: ${local_batch_size}
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
+    max_norm: 1.0
+    steps: 1000000
+    dtype: bfloat16
+    gc_freq: 1
+  compile:
+    enable: false
+    components: ["model", "loss"]
+  parallelism:
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: 4 # Must satisfy: dp_replicate * dp_shard * cp * tp * pp = world_size (4)
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    context_parallel_degree: 1
+    expert_parallel_degree: 4 # EP borrows from dp_shard for MoE
+    expert_tensor_parallel_degree: 1
+    disable_loss_parallel: true
+  checkpoint:
+    enable: true
+    folder: ./checkpoint # The folder to save checkpoints to.
+    initial_load_path: hf://${model}
+    initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
+    last_save_in_hf: true
+    interval: 500
+    async_mode: "disabled"
+  activation_checkpoint:
+    mode: selective
+    selective_ac_option: op
+  quantize:
+    linear:
+      float8:
+        enable_fsdp_float8_all_gather: false
+        precompute_float8_dynamic_scale_for_fsdp: false
+        recipe_name: "rowwise"
+        filter_fqns: ["output", "router.gate"]
+    grouped_mm:
+      float8:
+        fqns: ["experts"]
+  comm:
+    trace_buf_size: 0
+
+# Replay buffer configuration
+replay_buffer:
+  batch_size: ${local_batch_size}
+  max_policy_age: ${off_by_n}
+  dp_size: 4 # Total DP degree: dp_replicate * dp_shard = 1 * 4 = 4
+
+# Reference model configuration
+ref_model:
+  model:
+    name: qwen3
+    flavor: 30B-A3B
+    hf_assets_path: hf://${model}
+  training:
+    seq_len: ${trainer.training.seq_len}
+    dtype: bfloat16
+    gc_freq: 1
+  compile:
+    enable: false
+  parallelism:
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: 1 # Must satisfy: dp_replicate * dp_shard * cp * tp * pp = world_size (2)
+    tensor_parallel_degree: 2
+    pipeline_parallel_degree: 1
+    context_parallel_degree: 1
+    expert_parallel_degree: 2 # EP borrows from dp_shard for MoE
+    expert_tensor_parallel_degree: 1
+  checkpoint:
+    enable: true
+    initial_load_path: hf://${model}
+    initial_load_in_hf: true
+  quantize:
+    linear:
+      float8:
+        enable_fsdp_float8_all_gather: false
+        precompute_float8_dynamic_scale_for_fsdp: false
+        recipe_name: "rowwise"
+        filter_fqns: ["output", "router.gate"]
+    grouped_mm:
+      float8:
+        fqns: ["experts"]
+  comm:
+    trace_buf_size: 0
+
+# All resource allocations
+services:
+  policy:
+    procs: 2 # 2 GPUs for policy with expert parallelism
+    num_replicas: 1
+    mesh_name: policy
+    with_gpus: true
+  ref_model:
+    procs: 2 # 2 GPUs for reference model with expert parallelism
+    num_replicas: 1
+    mesh_name: ref_model
+    with_gpus: true
+  reward_actor:
+    procs: 1
+    num_replicas: 1
+    mesh_name: reward_actor
+    with_gpus: false
+
+actors:
+  dataset:
+    procs: 1
+    with_gpus: false
+    mesh_name: dataset
+  trainer:
+    procs: 4 # 4 GPUs for trainer with expert parallelism
+    with_gpus: true
+    mesh_name: trainer
+  replay_buffer:
+    procs: 1
+    with_gpus: false
+    mesh_name: replay_buffer
+  compute_advantages:
+    procs: 1
+    with_gpus: false
+    mesh_name: compute_advantages
diff --git a/pyproject.toml b/pyproject.toml
index 8460b5b78..9ad7363b7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ dependencies = [
     # PyTorch
     "torch==2.9.0",
     "torchdata>=0.8.0",
-    "torchtitan==0.2.0",
+    "torchtitan==0.1.0.dev20251029",
     "torchmonarch==0.1.2",
     "torchstore==0.1.2",
     # vLLM
@@ -83,6 +83,11 @@ members = [
 name = "pytorch-cu128"
 url = "https://download.pytorch.org/whl/cu128"
 
+# pytorch nightly
+[[tool.uv.index]]
+name = "pytorch-nightly-cu128"
+url = "https://download.pytorch.org/whl/nightly/cu128"
+
 # vllm
 [[tool.uv.index]]
 name = "vllm-forge"
@@ -90,6 +95,7 @@ url = "https://download.pytorch.org/whl/preview/forge"
 
 [tool.uv.sources]
 torch = { index = "pytorch-cu128" }
+torchtitan = { index = "pytorch-nightly-cu128" }
 vllm = { index = "vllm-forge" }
 
 [tool.uv]
diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py
index 02a6e1410..973220f6a 100644
--- a/src/forge/actors/reference_model.py
+++ b/src/forge/actors/reference_model.py
@@ -22,6 +22,7 @@
     Compile,
     Model,
     Parallelism,
+    Quantize,
     Training,
 )
 from torchtitan.experiments.forge.engine import ForgeEngine
@@ -61,6 +62,7 @@ class ReferenceModel(ForgeActor):
             (TP, PP, CP, DP)
         checkpoint (Checkpoint): Checkpoint loading configuration
         compile (Compile): Torch compilation settings
+        quantize (Quantize): Quantization settings (float8, etc.)
         comm (Comm): Communication backend configuration
         training (Training): Training-related settings (dtype, garbage collection, etc.)
 
@@ -71,6 +73,7 @@ class ReferenceModel(ForgeActor):
     parallelism: Parallelism = field(default_factory=Parallelism)
     checkpoint: Checkpoint = field(default_factory=Checkpoint)
     compile: Compile = field(default_factory=Compile)
+    quantize: Quantize = field(default_factory=Quantize)
     comm: Comm = field(default_factory=Comm)
     training: Training = field(
         default_factory=Training