diff --git a/examples/geo3k/prepare_geo3k_data.py b/examples/geo3k/prepare_geo3k_data.py
new file mode 100755
index 000000000..09f536581
--- /dev/null
+++ b/examples/geo3k/prepare_geo3k_data.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+"""Prepare GEO3K multimodal geometry dataset for RLLM training."""
+
+import base64
+from io import BytesIO
+from typing import Iterable, List
+
+from datasets import load_dataset
+from PIL import Image
+
+from rllm.data.dataset import DatasetRegistry
+
+DATA_URI_PREFIX = "data:image/png;base64,"
+
+
+def _serialize_images(images: Iterable[Image.Image]) -> List[str]:
+ """Serialize a list of PIL images into base64 data URIs for Parquet storage."""
+
+ serialized: list[str] = []
+ for image in images or []:
+ if not isinstance(image, Image.Image):
+ continue
+ buffer = BytesIO()
+ image.convert("RGB").save(buffer, format="PNG")
+ encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
+ serialized.append(f"{DATA_URI_PREFIX}{encoded}")
+ return serialized
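+
+# Illustrative round-trip (assumption: downstream consumers decode the same
+# data-URI format, e.g. rllm.data.multimodal.as_pil_image):
+#   uri = _serialize_images([Image.new("RGB", (4, 4))])[0]
+#   _, encoded = uri.split(",", 1)
+#   img = Image.open(BytesIO(base64.b64decode(encoded)))  # back to a PIL image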
+
+
+def prepare_geo3k_data():
+ """
+ Prepare GEO3K dataset following RLLM conventions.
+
+ Returns:
+ Tuple of (train_dataset, test_dataset) registered with DatasetRegistry
+ """
+ print("š„ Loading GEO3K dataset from HuggingFace...")
+ data_source = "hiyouga/geometry3k"
+ dataset = load_dataset(data_source)
+
+ train_dataset = dataset["train"]
+ test_dataset = dataset["test"]
+
+ print(f"ā
Dataset loaded:")
+ print(f" - Train samples: {len(train_dataset)}")
+ print(f" - Test samples: {len(test_dataset)}")
+
+ # Instruction template based on Verl's GEO3K processing
+ instruction_following = (
+ "You must strictly follow this answer template: "
+ "(1) write all reasoning as an internal monologue inside ...; "
+ "(2) after , output exactly one line in the form `Final answer: \\boxed{value}` with the numeric solution."
+ )
+
+ formatting_example = (
+ "Reason step by step here....\n"
+ "Final answer: \\boxed{42}"
+ )
+
+ def preprocess_fn(example, idx):
+ """
+ Preprocess function to convert GEO3K data to RLLM format.
+
+ This follows the pattern from verl/examples/data_preprocess/geo3k.py
+ but adapts it for RLLM's simpler format requirements.
+ """
+ problem = example["problem"]
+ answer = example["answer"]
+ images = _serialize_images(example.get("images", []))
+
+ # Create the full prompt with instruction following
+ prompt = problem + "\n\n" + instruction_following + "\nExample format:\n" + formatting_example
+
+ # Return RLLM-compatible format
+ return {
+ "question": prompt, # RLLM expects 'question' field
+ "ground_truth": answer, # RLLM expects 'ground_truth' field
+ "data_source": "hiyouga/geometry3k", # Data source identifier matching verl's reward function
+ "images": images, # Serialized data URIs; reconstructed during training/inference
+ "extra_info": {
+ "original_problem": problem,
+ "answer": answer,
+ "has_images": len(images) > 0,
+ "num_images": len(images),
+ "formatting_example": formatting_example,
+ }
+ }
+
+ print("š Preprocessing datasets...")
+
+ # Apply preprocessing to both splits
+ train_dataset = train_dataset.map(preprocess_fn, with_indices=True)
+ test_dataset = test_dataset.map(preprocess_fn, with_indices=True)
+
+ # Register datasets with RLLM DatasetRegistry
+ print("š Registering datasets with RLLM...")
+
+ train_dataset = DatasetRegistry.register_dataset("geo3k_train", train_dataset, "train")
+ test_dataset = DatasetRegistry.register_dataset("geo3k_test", test_dataset, "test")
+
+ print("ā
Datasets registered:")
+ print(" - geo3k_train (training data)")
+ print(" - geo3k_test (evaluation data)")
+
+ return train_dataset, test_dataset
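+
+# Once registered, the splits can be reloaded elsewhere (illustrative; mirrors the
+# lookup used by run_geo3k_agent.py):
+#   train = DatasetRegistry.load_dataset("geo3k_train", "train")
+#   test = DatasetRegistry.load_dataset("geo3k_test", "test")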
+
+
+if __name__ == "__main__":
+ train_dataset, test_dataset = prepare_geo3k_data()
+ print(f"\nš Dataset Statistics:")
+ print(f" - Training samples: {len(train_dataset)}")
+ print(f" - Test samples: {len(test_dataset)}")
+
+ # Show a sample to verify format
+ sample = train_dataset[0]
+ print(f"\nš Sample data format:")
+ print(f" - Question length: {len(sample['question'])} chars")
+ print(f" - Has images: {len(sample['images']) > 0}")
+ print(f" - Number of images: {len(sample['images'])}")
+ print(f" - Ground truth: {sample['ground_truth'][:100]}...")
+
+ print(f"\nš GEO3K dataset preparation complete!")
+ print(f"\nNext steps:")
+ print(f"1. Configure multimodal model in training config")
+ print(f"2. Run training: python train_geo3k_agent.py")
+ print(f"3. Test inference: python run_geo3k_agent.py")
diff --git a/examples/geo3k/run_geo3k_agent.py b/examples/geo3k/run_geo3k_agent.py
new file mode 100755
index 000000000..ae1f27b66
--- /dev/null
+++ b/examples/geo3k/run_geo3k_agent.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python3
+"""
+Inference script for GEO3K multimodal geometry agent.
+
+This script tests the trained multimodal agent on GEO3K geometry problems
+and evaluates its performance on problems with images.
+"""
+
+import asyncio
+import os
+
+from transformers import AutoTokenizer
+
+from rllm.agents import Geo3kAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.engine.agent_execution_engine import AgentExecutionEngine
+from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+from rllm.rewards.reward_fn import math_reward_fn
+from rllm.utils import compute_pass_at_k
+
+
+def main():
+ """
+ Main inference function for testing GEO3K multimodal agent.
+ """
+ print("š GEO3K Multimodal Agent Inference")
+ print("=" * 40)
+
+ # Set environment variables
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+ # Configuration
+ n_parallel_agents = 16 # Adjust based on your GPU memory
+ model_name = "Qwen/Qwen2-VL-2B-Instruct" # Default multimodal model
+
+ print(f"š§ Configuration:")
+ print(f" - Model: {model_name}")
+ print(f" - Parallel agents: {n_parallel_agents}")
+
+ # Initialize tokenizer
+ print("š Loading tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ # Agent configuration for multimodal geometry reasoning
+ agent_args = {
+ "accumulate_thinking": True,
+ "include_images_in_completion": True, # Include image info in completions for debugging
+ }
+
+ # Environment configuration
+ env_args = {
+ "reward_fn": math_reward_fn,
+ }
+
+ # Sampling parameters for inference
+ sampling_params = {
+ "temperature": 0.7, # Slightly higher for reasoning diversity
+ "top_p": 0.95,
+ "max_tokens": 2048,
+ "model": model_name
+ }
+
+ print("š§ Initializing agent execution engine...")
+
+ # Initialize execution engine
+ # Note: You can switch between "openai" and "verl" engine
+ # For multimodal models, "verl" engine with SGLang backend is recommended
+ engine = AgentExecutionEngine(
+ agent_class=Geo3kAgent,
+ agent_args=agent_args,
+ env_class=SingleTurnEnvironment,
+ env_args=env_args,
+ engine_name="openai", # Can be "openai" or "verl"
+ rollout_engine_args={
+ "base_url": "http://localhost:30000/v1", # SGLang server URL
+ "api_key": "None"
+ },
+ tokenizer=tokenizer,
+ sampling_params=sampling_params,
+ max_response_length=2048,
+ max_prompt_length=2048,
+ n_parallel_agents=n_parallel_agents,
+ )
+
+ # Load test dataset
+ print("š Loading test dataset...")
+ try:
+ test_dataset = DatasetRegistry.load_dataset("geo3k_test", "test")
+ print(f"ā
Test dataset loaded: {len(test_dataset)} samples")
+ except Exception as e:
+ print(f"ā Dataset not found: {e}")
+ print("š Preparing dataset...")
+ from prepare_geo3k_data import prepare_geo3k_data
+
+ _, test_dataset = prepare_geo3k_data()
+
+ test_data = test_dataset.get_data()
+
+ # Take a smaller subset for quick testing
+ test_samples = 50 # Adjust as needed
+ subset_size = min(test_samples, len(test_data))
+ test_subset = test_data[:subset_size]
+ print(f"šÆ Testing on {subset_size} samples")
+
+ # Check multimodal content statistics
+ multimodal_count = sum(1 for sample in test_subset if sample.get("images"))
+ print(f"šø Multimodal samples: {multimodal_count}/{subset_size}")
+
+ # Repeat samples for pass@k evaluation
+ n_repeats = 4 # Number of attempts per problem
+ tasks = [sample for sample in test_subset for _ in range(n_repeats)]
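+    # e.g. with n_repeats=4 and test_subset == [s1, s2], tasks == [s1, s1, s1, s1, s2, s2, s2, s2]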
+ print(f"š Total tasks (with repeats): {len(tasks)}")
+
+ # Show sample problem
+ if not test_subset:
+ print("ā ļø No samples found in test dataset. Exiting.")
+ return
+
+ sample = test_subset[0]
+ print(f"\nš Sample Problem:")
+ question_preview = sample.get("question", "")[:200]
+ ground_truth_preview = str(sample.get("ground_truth", ""))[:100]
+ has_images = bool(sample.get("images"))
+ print(f" Question: {question_preview}...")
+ print(f" Ground Truth: {ground_truth_preview}...")
+ print(f" Has Images: {has_images}")
+
+ # Run inference
+ print("\nš Starting inference...")
+ print("ā³ This may take a while depending on the number of samples and model speed...")
+
+ try:
+ results = asyncio.run(engine.execute_tasks(tasks))
+ print(f"\nā
Inference completed!")
+ print(f"š Results: {len(results)} task completions")
+
+ # Compute and display pass@k metrics
+ print("\nš Computing Pass@K metrics...")
+ pass_at_k_results = compute_pass_at_k(results)
+
+ # Display results
+ print(f"\nšÆ Performance Summary:")
+ for k, score in pass_at_k_results.items():
+ print(f" - Pass@{k}: {score:.3f}")
+
+ # Show some example results
+ print(f"\nš Example Results:")
+ for i, result in enumerate(results[:3]): # Show first 3 results
+ reward = result.get('reward', 0)
+ success = "ā
" if reward > 0.5 else "ā"
+ print(f" {success} Task {i+1}: Reward = {reward:.3f}")
+
+ except Exception as e:
+ print(f"ā Inference failed: {e}")
+ print("š” Make sure your model server is running:")
+ print(" SGLang: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-2B-Instruct")
+ print(" vLLM: python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2-VL-2B-Instruct")
+ raise
+
+ print(f"\nš GEO3K inference completed!")
+ print(f"š” For better performance, consider:")
+ print(f" - Using a larger multimodal model (e.g., Qwen2-VL-7B)")
+ print(f" - Fine-tuning on GEO3K training data")
+ print(f" - Adjusting sampling parameters")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/geo3k/train_geo3k_agent.py b/examples/geo3k/train_geo3k_agent.py
new file mode 100644
index 000000000..54f9103d2
--- /dev/null
+++ b/examples/geo3k/train_geo3k_agent.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+"""
+Training script for GEO3K multimodal geometry agent using RLLM framework.
+
+This script follows RLLM patterns and leverages Verl's multimodal capabilities
+for training on geometry problems with images.
+"""
+
+import hydra
+from omegaconf import OmegaConf, open_dict
+
+from rllm.agents import Geo3kAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+from rllm.rewards.reward_fn import math_reward_fn
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="geo3k_multimodal_trainer", version_base=None)
+def main(config):
+ """
+ Main training function for GEO3K multimodal agent.
+
+ This follows the same pattern as other RLLM examples while enabling
+ multimodal training through Verl's capabilities.
+ """
+ # Load datasets (must be prepared first using prepare_geo3k_data.py)
+ train_dataset = DatasetRegistry.load_dataset("geo3k_train", "train")
+ test_dataset = DatasetRegistry.load_dataset("geo3k_test", "test")
+
+ # Ensure configuration matches Verl GEO3K example defaults
+ if hasattr(config, "actor_rollout_ref") and hasattr(config.actor_rollout_ref, "model"):
+ with open_dict(config.actor_rollout_ref.model):
+ config.actor_rollout_ref.model.path = "Qwen/Qwen2.5-VL-7B-Instruct"
+
+ if not hasattr(config, "multimodal") or config.multimodal is None:
+ with open_dict(config):
+ config.multimodal = OmegaConf.create({"enable": True, "image_key": "images"})
+ else:
+ with open_dict(config.multimodal):
+ config.multimodal.enable = True
+ config.multimodal.image_key = "images"
+
+ rejection_multiplier = 1
+ if hasattr(config, "rllm") and hasattr(config.rllm, "rejection_sample") and config.rllm.rejection_sample is not None:
+ rejection_multiplier = getattr(config.rllm.rejection_sample, "multiplier", 1) or 1
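+    # e.g. with rllm.rejection_sample.multiplier == 2, gen_batch_size below becomes 64
+    # while the train/val batch sizes stay at 32.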
+
+ if hasattr(config, "data"):
+ with open_dict(config.data):
+ config.data.image_key = "images"
+ config.data.return_multi_modal_inputs = True
+ config.data.train_batch_size = 32
+ config.data.val_batch_size = 32
+ config.data.gen_batch_size = 32 * max(1, rejection_multiplier)
+
+ # Ensure PPO mini-batch sizes do not exceed data batch sizes
+ train_batch_size = getattr(config.data, "train_batch_size", None)
+ val_batch_size = getattr(config.data, "val_batch_size", None)
+
+ if hasattr(config, "actor_rollout_ref") and hasattr(config.actor_rollout_ref, "actor"):
+ with open_dict(config.actor_rollout_ref.actor):
+ mini_batch = getattr(config.actor_rollout_ref.actor, "ppo_mini_batch_size", None)
+ if train_batch_size is not None and (mini_batch is None or mini_batch > train_batch_size):
+ config.actor_rollout_ref.actor.ppo_mini_batch_size = train_batch_size
+
+ if hasattr(config, "critic"):
+ with open_dict(config.critic):
+ mini_batch = getattr(config.critic, "ppo_mini_batch_size", None)
+ if val_batch_size is not None and (mini_batch is None or mini_batch > val_batch_size):
+ config.critic.ppo_mini_batch_size = val_batch_size
+
+ # Disable workflow mode for single-turn geo3k training
+ if hasattr(config, "rllm") and hasattr(config.rllm, "workflow"):
+ with open_dict(config.rllm.workflow):
+ config.rllm.workflow.use_workflow = False
+
+ # Reduce rollout parallelism to avoid GPU OOM
+ if hasattr(config, "actor_rollout_ref") and hasattr(config.actor_rollout_ref, "rollout"):
+ with open_dict(config.actor_rollout_ref.rollout):
+ config.actor_rollout_ref.rollout.n = 1
+ config.actor_rollout_ref.rollout.gpu_memory_utilization = min(0.6, config.actor_rollout_ref.rollout.gpu_memory_utilization)
+ config.actor_rollout_ref.rollout.max_num_batched_tokens = 4096
+ config.actor_rollout_ref.rollout.max_num_seqs = 128
+
+ if hasattr(config.actor_rollout_ref.rollout, "agent"):
+ with open_dict(config.actor_rollout_ref.rollout.agent):
+ current_workers = getattr(config.actor_rollout_ref.rollout.agent, "num_workers", 0) or 0
+ config.actor_rollout_ref.rollout.agent.num_workers = max(1, current_workers)
+
+ # Agent configuration - minimal args following RLLM patterns
+ agent_args = {
+ "accumulate_thinking": True,
+ "include_images_in_completion": True, # Enable image info in completions for multimodal training
+ }
+
+ # Environment configuration - simple single-turn math environment
+ env_args = {
+ "reward_fn": math_reward_fn,
+ }
+
+ # Initialize trainer following RLLM patterns
+ trainer = AgentTrainer(
+ agent_class=Geo3kAgent,
+ env_class=SingleTurnEnvironment,
+ agent_args=agent_args,
+ env_args=env_args,
+ config=config,
+ train_dataset=train_dataset,
+ val_dataset=test_dataset,
+ )
+
+ trainer.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/geo3k/train_geo3k_agent.sh b/examples/geo3k/train_geo3k_agent.sh
new file mode 100755
index 000000000..3b5e2bb0d
--- /dev/null
+++ b/examples/geo3k/train_geo3k_agent.sh
@@ -0,0 +1,80 @@
+#!/bin/bash
+# GEO3K Multimodal Agent Training Script
+# This script follows RLLM patterns while enabling multimodal training through Verl
+
+set -x
+
+# Environment setup for multimodal training
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export TOKENIZERS_PARALLELISM=false
+
+# Use SGLang for multimodal support (following Verl's GEO3K example)
+# Uncomment these for vLLM if preferred
+# export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+# export VLLM_USE_V1=1
+# export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+# export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000
+
+# Find the directory where rllm package is located
+RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")
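+# (Assumption) run from the repository root so `examples.geo3k` resolves as a module:
+# cd "$RLLM_DIR"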
+
+# Train GEO3K agent with multimodal support
+# Configuration based on Verl's geo3k_multiturn examples and RLLM patterns
+python3 -m examples.geo3k.train_geo3k_agent \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=64 \
+ data.val_batch_size=64 \
+ data.max_prompt_length=2048 \
+ data.max_response_length=2048 \
+ data.return_raw_chat=True \
+ data.return_multi_modal_inputs=True \
+ data.trust_remote_code=True \
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
+ actor_rollout_ref.model.trust_remote_code=True \
+ actor_rollout_ref.hybrid_engine=True \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.loss_agg_mode=token-mean \
+ actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
+ actor_rollout_ref.actor.use_dynamic_bsz=False \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=16384 \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ actor_rollout_ref.actor.clip_ratio=0.2 \
+ actor_rollout_ref.actor.entropy_coeff=0.01 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=sglang \
+ actor_rollout_ref.rollout.mode=async \
+ actor_rollout_ref.rollout.dtype=bfloat16 \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
+ actor_rollout_ref.rollout.temperature=1.0 \
+ actor_rollout_ref.rollout.top_p=1.0 \
+ actor_rollout_ref.rollout.n=4 \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.temperature=0.7 \
+ actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+ algorithm.kl_ctrl.kl_coef=0.001 \
+ critic.ppo_micro_batch_size_per_gpu=2 \
+ rllm.agent.name=geo3k_agent \
+ rllm.agent.max_steps=1 \
+ rllm.env.name=math \
+ rllm.mask_truncated_samples=False \
+ rllm.stepwise_advantage.enable=False \
+ trainer.critic_warmup=0 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name='rllm-geo3k-multimodal' \
+    trainer.experiment_name='geo3k-multimodal-qwen2.5-vl-7b' \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=2 \
+ trainer.nnodes=1 \
+ trainer.save_freq=100 \
+ trainer.test_freq=50 \
+ trainer.total_epochs=10 \
+ trainer.default_hdfs_dir=null
\ No newline at end of file
diff --git a/rllm/agents/__init__.py b/rllm/agents/__init__.py
index 43924560d..d4e9f007e 100644
--- a/rllm/agents/__init__.py
+++ b/rllm/agents/__init__.py
@@ -15,6 +15,7 @@ def safe_import(module_path, class_name):
# Define all agent imports
AGENT_IMPORTS = [
+ ("rllm.agents.geo3k_agent", "Geo3kAgent"),
("rllm.agents.miniwob_agent", "MiniWobAgent"),
("rllm.agents.frozenlake_agent", "FrozenLakeAgent"),
("rllm.agents.swe_agent", "SWEAgent"),
diff --git a/rllm/agents/geo3k_agent.py b/rllm/agents/geo3k_agent.py
new file mode 100644
index 000000000..628ae4a6c
--- /dev/null
+++ b/rllm/agents/geo3k_agent.py
@@ -0,0 +1,245 @@
+"""Geometry3k Agent with multimodal support.
+
+This agent is designed to solve geometry problems that include both text and images,
+leveraging RLLM's multimodal capabilities built on top of Verl.
+"""
+
+from typing import Any, List, Union, Dict
+from PIL import Image
+
+from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
+from rllm.data.multimodal import MultimodalMessage, as_pil_image, create_multimodal_conversation
+
+
+class Geo3kAgent(BaseAgent):
+ """
+ A geometry agent that solves mathematical problems with multimodal inputs (text + images).
+
+ This agent extends the standard MathAgent pattern to handle images that are common
+ in geometry problems, while maintaining full compatibility with RLLM's training infrastructure.
+ """
+
+ def __init__(self, accumulate_thinking=True, include_images_in_completion=False):
+ """
+ Initialize the Geo3kAgent.
+
+ Args:
+ accumulate_thinking: Whether to accumulate thinking across turns
+ include_images_in_completion: Whether to include image info in chat_completions
+ """
+ self._trajectory = Trajectory()
+ self.messages = [] # Store MultimodalMessage objects
+ self.accumulate_thinking = accumulate_thinking
+ self.include_images_in_completion = include_images_in_completion
+ self._current_images = [] # Store images for current problem
+
+ def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
+ """Process environment feedback and update internal state with multimodal support."""
+
+ # If observation is None, this is a reward update for the existing step
+ if observation is None:
+ if self.trajectory.steps:
+ cur_step = self.get_current_state()
+ cur_step.reward = reward
+ cur_step.done = done
+ cur_step.info = info
+ return
+
+ # Extract multimodal data from observation
+ images = []
+ if isinstance(observation, dict):
+ formatted_observation = observation.get("question", str(observation))
+ raw_images = observation.get("images", [])
+ print(f"DEBUG: Observation dict has images: {len(raw_images) if raw_images else 0}")
+ elif isinstance(observation, str):
+ formatted_observation = observation
+ # Check if kwargs contains images
+ raw_images = kwargs.get("images", [])
+ print(f"DEBUG: String observation, kwargs has images: {len(raw_images) if raw_images else 0}")
+ else:
+ raise ValueError(f"Invalid observation type: {type(observation)}")
+
+ print(f"DEBUG: Raw images count: {len(raw_images) if raw_images else 0}")
+ if raw_images:
+ print(f"DEBUG: First raw image type: {type(raw_images[0])}")
+ if isinstance(raw_images[0], dict):
+ print(f"DEBUG: First raw image dict keys: {raw_images[0].keys()}")
+
+ decoded_images: list[Image.Image] = []
+ # Handle numpy arrays properly
+ images_to_process = []
+ if raw_images is not None:
+ if hasattr(raw_images, '__len__') and len(raw_images) > 0:
+ images_to_process = raw_images
+
+ for i, image in enumerate(images_to_process):
+ pil_image = as_pil_image(image)
+ print(f"DEBUG: Image {i} decoded successfully: {pil_image is not None}")
+ if pil_image is None and isinstance(image, str):
+ try:
+ pil_image = Image.open(image).convert("RGB")
+ print(f"DEBUG: Image {i} decoded from file path")
+ except Exception as e:
+ print(f"DEBUG: Image {i} failed file decode: {e}")
+ pil_image = None
+ if pil_image is not None:
+ decoded_images.append(pil_image)
+ print(f"DEBUG: Image {i} added to decoded list, size: {pil_image.size}")
+
+ images = decoded_images
+ self._current_images = images
+ print(f"DEBUG: Final decoded images count: {len(images)}")
+
+ # Create multimodal message
+ multimodal_message = MultimodalMessage(
+ role="user",
+ text=formatted_observation,
+ images=images if images else None
+ )
+ self.messages.append(multimodal_message)
+
+ # Create new step
+ new_step = Step(observation=formatted_observation)
+ # Store multimodal info in step's info dict for compatibility
+ if images:
+ new_step.info = {"images": images}
+
+ self._trajectory.steps.append(new_step)
+
+ def update_from_model(self, response: str, **kwargs) -> Action:
+ """
+ Updates the agent's internal state based on the model's response.
+ """
+
+ # Create assistant message (models typically don't generate images in geo3k)
+ assistant_message = MultimodalMessage(
+ role="assistant",
+ text=response
+ )
+ self.messages.append(assistant_message)
+
+ # Update the latest step
+ cur_step = self.get_current_state()
+ cur_step.chat_completions = self.chat_completions
+ cur_step.model_response = response
+
+ # Parse thinking and action (same as MathAgent)
+ if response.count("") == 1:
+ thought, sep, action = response.partition("")
+ thought = thought + sep
+ action = Action(action.strip())
+ else:
+ thought = None
+ action = Action(response.strip())
+
+ cur_step.thought = thought
+ cur_step.action = action
+
+ return action
+
+ def reset(self) -> None:
+ """Reset agent state for new episode (wipes trajectory and messages)."""
+ self._trajectory = Trajectory()
+ self.messages = []
+ self._current_images = []
+
+ @property
+ def chat_completions(self) -> List[Dict[str, str]]:
+ """Return conversation history for model interaction.
+
+ For multimodal agents, this returns the Verl-compatible format.
+ The actual multimodal data is provided via get_multimodal_data().
+ """
+ # Convert multimodal messages to Verl format
+ conversation = self.get_multimodal_conversation()
+
+ # For backward compatibility, also provide text-only format
+ completions = []
+ for msg in conversation:
+ completion = {"role": msg["role"]}
+
+ if isinstance(msg.get("content"), list):
+ # Extract text parts from multimodal content
+ text_parts = []
+ for item in msg["content"]:
+ if item.get("type") == "text":
+ text_parts.append(item["text"])
+ elif item.get("type") == "image" and self.include_images_in_completion:
+ text_parts.append(f"[Images: {len(item.get('image', []))} geometry diagrams]")
+ completion["content"] = " ".join(text_parts) if text_parts else ""
+ else:
+ completion["content"] = msg.get("content", "")
+
+ completions.append(completion)
+
+ # Apply thinking accumulation logic (same as MathAgent)
+ if not self.accumulate_thinking:
+ for msg in completions[:-1]:
+ if msg["role"] == "assistant":
+ _, sep, after = msg["content"].partition("")
+ if sep:
+ msg["content"] = after
+
+ return completions
+
+ @property
+ def trajectory(self) -> Trajectory:
+ """Return complete interaction trajectory."""
+ return self._trajectory
+
+ def get_current_state(self) -> Step:
+ """Returns the current step/state of the agent."""
+ assert self._trajectory.steps, "Trajectory should not be empty when get_current_state is called."
+ return self._trajectory.steps[-1]
+
+ # Multimodal-specific methods
+
+ def get_multimodal_conversation(self) -> List[Dict[str, Any]]:
+ """Get the conversation in Verl's multimodal format."""
+ return create_multimodal_conversation(self.messages)
+
+ def get_multimodal_data(self) -> Dict[str, List[Any]]:
+ """Extract all multimodal data for Verl processing."""
+ # Collect all images from current trajectory
+ all_images = []
+ all_videos = []
+
+ # Add images from current observation if any
+ if self._current_images:
+ all_images.extend(self._current_images)
+
+ # Add images from all messages in conversation
+ for msg in self.messages:
+ if msg.images:
+ all_images.extend(msg.images)
+ if msg.videos:
+ all_videos.extend(msg.videos)
+
+ print(f"DEBUG: get_multimodal_data - total images: {len(all_images)}, total videos: {len(all_videos)}")
+ if all_images:
+ print(f"DEBUG: get_multimodal_data - first image type: {type(all_images[0])}")
+
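+        # Illustrative return value for a single-image problem (hypothetical sizes):
+        #   {"image": [<PIL.Image.Image mode=RGB size=640x480>], "video": []}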
+ # Return in Verl's expected format
+ return {
+ "image": all_images,
+ "video": all_videos
+ }
+
+ def get_current_images(self) -> List[Union[str, Dict, Image.Image]]:
+ """Get images from the current problem."""
+ return self._current_images
+
+ def has_multimodal_content(self) -> bool:
+ """Check if the current trajectory contains any multimodal content."""
+ return any(msg.images for msg in self.messages if msg.images)
+
+ def get_multimodal_summary(self) -> Dict[str, Any]:
+ """Get a summary of multimodal content in the trajectory."""
+ total_images = sum(len(msg.images) for msg in self.messages if msg.images)
+ return {
+ "total_images": total_images,
+ "has_multimodal": self.has_multimodal_content(),
+ "current_images": len(self._current_images),
+ "steps_with_images": sum(1 for step in self._trajectory.steps
+ if step.info and "images" in step.info)
+ }
diff --git a/rllm/data/__init__.py b/rllm/data/__init__.py
index 6a8f1be83..415d75090 100644
--- a/rllm/data/__init__.py
+++ b/rllm/data/__init__.py
@@ -1,6 +1,15 @@
from rllm.data.dataset import Dataset, DatasetRegistry
from rllm.data.dataset_types import Dataset as DatasetEnum
from rllm.data.dataset_types import DatasetConfig, Problem, TestDataset, TrainDataset
+from rllm.data.multimodal import (
+ MultimodalMessage,
+ create_multimodal_conversation,
+ extract_multimodal_data,
+ create_text_message,
+ create_image_message,
+ create_video_message,
+ create_multimodal_message,
+)
__all__ = [
"TrainDataset",
@@ -10,4 +19,12 @@
"DatasetRegistry",
"Problem",
"DatasetConfig",
+ # Multimodal support
+ "MultimodalMessage",
+ "create_multimodal_conversation",
+ "extract_multimodal_data",
+ "create_text_message",
+ "create_image_message",
+ "create_video_message",
+ "create_multimodal_message",
]
diff --git a/rllm/data/multimodal.py b/rllm/data/multimodal.py
new file mode 100644
index 000000000..6df738aea
--- /dev/null
+++ b/rllm/data/multimodal.py
@@ -0,0 +1,220 @@
+"""Multimodal data utilities for RLLM."""
+
+import base64
+from dataclasses import dataclass
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Union
+
+from PIL import Image
+
+# Import Verl's vision utilities - these are the core multimodal processing functions
+try:
+ from verl.utils.dataset.vision_utils import process_image, process_video
+ from verl.workers.rollout.schemas import AsyncRolloutRequest, Message
+ VERL_MULTIMODAL_AVAILABLE = True
+except ImportError:
+ VERL_MULTIMODAL_AVAILABLE = False
+ # Fallback stubs for type hints
+ class AsyncRolloutRequest:
+ pass
+ class Message:
+ pass
+
+ def process_image(image): # type: ignore[override]
+ return image
+
+ def process_video(video): # type: ignore[override]
+ return video
+
+
+DATA_URI_PREFIX = "data:image/"
+
+
+def as_pil_image(image: Any) -> Image.Image | None:
+ """Convert supported payloads (PIL image, data URI, byte dict) to a PIL image."""
+
+ if hasattr(image, "mode") and hasattr(image, "size"):
+ return image # Already a PIL image
+
+ if isinstance(image, str) and image.startswith(DATA_URI_PREFIX):
+ try:
+ header, encoded = image.split(",", 1)
+ image_bytes = base64.b64decode(encoded)
+ return Image.open(BytesIO(image_bytes)).convert("RGB")
+ except Exception:
+ return None
+
+ if isinstance(image, dict):
+ if "bytes" in image and image["bytes"] is not None:
+ try:
+ return Image.open(BytesIO(image["bytes"])).convert("RGB")
+ except Exception:
+ return None
+
+ # HuggingFace datasets often store images as dicts with "path" pointing to a
+ # file or a data URI. Handle both cases, falling back to disk loading.
+ data_str: str | None = None
+ if "data" in image and isinstance(image["data"], str):
+ data_str = image["data"]
+ elif "path" in image and isinstance(image["path"], str):
+ data_str = image["path"]
+
+ if data_str:
+ if data_str.startswith(DATA_URI_PREFIX):
+ try:
+ _, encoded = data_str.split(",", 1)
+ image_bytes = base64.b64decode(encoded)
+ return Image.open(BytesIO(image_bytes)).convert("RGB")
+ except Exception:
+ return None
+ else:
+ try:
+ return Image.open(data_str).convert("RGB")
+ except Exception:
+ return None
+
+ return None
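+
+# Example payloads accepted by as_pil_image (illustrative):
+#   as_pil_image(Image.new("RGB", (2, 2)))            -> same PIL image, returned as-is
+#   as_pil_image("data:image/png;base64,iVBORw0...")  -> decoded PIL image
+#   as_pil_image({"bytes": encoded_image_bytes})      -> decoded PIL image
+#   as_pil_image(42)                                  -> None (unsupported payload)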
+
+
+def ensure_multimodal_available():
+ """Ensure Verl multimodal capabilities are available."""
+ if not VERL_MULTIMODAL_AVAILABLE:
+ raise ImportError(
+ "Verl multimodal support not available. Please ensure Verl is properly installed."
+ )
+
+
+@dataclass
+class MultimodalMessage:
+ """A message that can contain text, images, and videos.
+
+ This is a lightweight wrapper around Verl's Message format for easier use in RLLM.
+ """
+
+ role: str # "user", "assistant", or "system"
+ text: Optional[str] = None
+ images: Optional[List[Union[str, Dict, Image.Image]]] = None
+ videos: Optional[List[Dict]] = None
+
+ def to_verl_message(self) -> Dict[str, Any]:
+ """Convert to Verl's Message format."""
+ ensure_multimodal_available()
+
+ content = []
+
+ # Add text content
+ if self.text:
+ content.append({"type": "text", "text": self.text})
+
+ # Add image content - process through Verl's utilities
+ if self.images:
+ processed_images = []
+ for image in self.images:
+ pil_image = as_pil_image(image)
+ if pil_image is not None:
+ processed_images.append(pil_image.convert("RGB"))
+ else:
+ processed_images.append(process_image(image))
+ content.append({"type": "image", "image": processed_images})
+
+ # Add video content - process through Verl's utilities
+ if self.videos:
+ processed_videos = []
+ for video in self.videos:
+ processed_video = process_video(video)
+ processed_videos.append(processed_video)
+ content.append({"type": "video", "video": processed_videos})
+
+ return {
+ "role": self.role,
+ "content": content if content else self.text or ""
+ }
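+
+# Example (assuming Verl is installed so ensure_multimodal_available() passes):
+#   MultimodalMessage(role="user", text="Find x.", images=[img]).to_verl_message()
+#   -> {"role": "user", "content": [{"type": "text", "text": "Find x."},
+#                                   {"type": "image", "image": [<PIL image>]}]}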
+
+
+def create_multimodal_conversation(messages: List[MultimodalMessage]) -> List[Dict[str, Any]]:
+ """Create a conversation from MultimodalMessage objects.
+
+ Args:
+ messages: List of MultimodalMessage objects
+
+ Returns:
+ List of messages in Verl format, ready for training
+ """
+ return [msg.to_verl_message() for msg in messages]
+
+
+def extract_multimodal_data(messages: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
+ """Extract multimodal data from messages for Verl processing.
+
+ This function extracts all images and videos from a conversation,
+ which is the format expected by Verl's AsyncRolloutRequest.
+
+ Args:
+ messages: List of messages in Verl format
+
+ Returns:
+ Dictionary with 'image' and 'video' keys containing processed media
+ """
+ ensure_multimodal_available()
+
+ all_images = []
+ all_videos = []
+
+ for message in messages:
+ content = message.get("content", [])
+ if isinstance(content, str):
+ continue
+
+ for item in content:
+ if item.get("type") == "image":
+ images = item.get("image", [])
+ if isinstance(images, list):
+ all_images.extend(images)
+ else:
+ all_images.append(images)
+ elif item.get("type") == "video":
+ videos = item.get("video", [])
+ if isinstance(videos, list):
+ all_videos.extend(videos)
+ else:
+ all_videos.append(videos)
+
+ return {
+ "image": all_images,
+ "video": all_videos
+ }
+
+
+# Convenience functions for common use cases
+
+def create_text_message(role: str, text: str) -> MultimodalMessage:
+ """Create a text-only message."""
+ return MultimodalMessage(role=role, text=text)
+
+
+def create_image_message(
+ role: str,
+ text: str,
+ images: List[Union[str, Dict, Image.Image]]
+) -> MultimodalMessage:
+ """Create a message with text and images."""
+ return MultimodalMessage(role=role, text=text, images=images)
+
+
+def create_video_message(
+ role: str,
+ text: str,
+ videos: List[Dict]
+) -> MultimodalMessage:
+ """Create a message with text and videos."""
+ return MultimodalMessage(role=role, text=text, videos=videos)
+
+
+def create_multimodal_message(
+ role: str,
+ text: str,
+ images: Optional[List[Union[str, Dict, Image.Image]]] = None,
+ videos: Optional[List[Dict]] = None
+) -> MultimodalMessage:
+ """Create a message with text, images, and/or videos."""
+ return MultimodalMessage(role=role, text=text, images=images, videos=videos)
diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py
index c4982a6de..5a26b433f 100644
--- a/rllm/engine/agent_execution_engine.py
+++ b/rllm/engine/agent_execution_engine.py
@@ -40,6 +40,7 @@ def __init__(
max_response_length=8192,
max_prompt_length=1024,
config=None,
+ processor=None,
agent_class=None,
env_class=None,
agent_args=None,
@@ -60,6 +61,7 @@ def __init__(
self.config = config
self.tokenizer = tokenizer
self.engine_name = engine_name
+ self.processor = processor
self.n_parallel_agents = n_parallel_agents
self.overlong_filter = overlong_filter
@@ -111,13 +113,22 @@ def __init__(
config=self.config,
rollout_manager=rollout_engine,
tokenizer=self.tokenizer,
+ processor=self.processor,
disable_thinking=self.config.rllm.disable_thinking,
)
# Create a thread pool executor for environment interactions (i.e. step, reset, close)
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
- async def get_model_response(self, prompt, application_id, **kwargs) -> str:
+ async def get_model_response(
+ self,
+ messages,
+ application_id,
+ *,
+ multimodal_messages=None,
+ multimodal_data=None,
+ **kwargs,
+ ) -> str:
"""
Compute model response asynchronously based on the engine type.
@@ -140,12 +151,20 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str:
sampling_params.update(kwargs)
if self.engine_name == "openai":
- output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, **sampling_params)
+ request_messages = messages
+ output = await self.rollout_engine.get_model_response(request_messages, application_id=application_id, **sampling_params)
return output.text
elif self.engine_name == "verl":
+ request_messages = multimodal_messages or messages
meta_data = sampling_params.pop("meta_info", {})
validate = meta_data.get("validate", False)
- output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, **sampling_params)
+ output = await self.rollout_engine.get_model_response(
+ request_messages,
+ application_id=application_id,
+ validate=validate,
+ multimodal_data=multimodal_data,
+ **sampling_params,
+ )
return output.text
else:
raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
@@ -195,11 +214,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
# Reset agent
agent.reset()
# Update agent internal state from environment.
+ # Extract multimodal data from observation and pass to agent
+ agent_kwargs = {}
+ if isinstance(observation, dict) and "images" in observation:
+ agent_kwargs["images"] = observation["images"]
+
agent.update_from_env(
observation=observation, # Raw observation from environment
reward=0.0,
done=False,
info=info,
+ **agent_kwargs,
)
messages = agent.chat_completions
prompt_tokens, _ = convert_messages_to_tokens_and_masks(messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=True, contains_generation_msg=True)
@@ -212,6 +237,21 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
for step_idx in range(self.max_steps):
# Get action from agent
prompt_messages = agent.chat_completions.copy()
+
+ multimodal_messages = None
+ multimodal_data = None
+ if hasattr(agent, "get_multimodal_data"):
+ try:
+ multimodal_data = agent.get_multimodal_data()
+ has_multimodal = isinstance(multimodal_data, dict) and (
+ multimodal_data.get("image") or multimodal_data.get("video")
+ )
+ if has_multimodal and hasattr(agent, "get_multimodal_conversation"):
+ multimodal_messages = agent.get_multimodal_conversation()
+ except Exception as exc:
+ colorful_print(f"Warning: failed to collect multimodal data: {exc}", "yellow")
+ multimodal_messages = None
+ multimodal_data = None
# Max remaining tokens left for the response
# For enforced max prompt at each step, no need to deduct here
if not self.enforce_max_prompt_length:
@@ -229,7 +269,13 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
kwargs["max_tokens"] = max_tokens
start_time = time.time()
- response = await self.get_model_response(prompt_messages, application_id, **kwargs)
+ response = await self.get_model_response(
+ prompt_messages,
+ application_id,
+ multimodal_messages=multimodal_messages,
+ multimodal_data=multimodal_data,
+ **kwargs,
+ )
delta_time = time.time() - start_time
llm_time += delta_time
total_time += delta_time
@@ -267,11 +313,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
info["cur_tokens"] = response_token_len
# Update agent internal state.
+ # Extract multimodal data from next observation and pass to agent
+ next_agent_kwargs = {}
+ if isinstance(next_observation, dict) and "images" in next_observation:
+ next_agent_kwargs["images"] = next_observation["images"]
+
agent.update_from_env(
observation=next_observation,
reward=reward,
done=done,
info=info,
+ **next_agent_kwargs,
)
cur_step = agent.get_current_state()
@@ -383,6 +435,9 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
"trajectory_reward": trajectory.reward,
"idx": env.idx,
"chat_completions": agent.chat_completions,
+ "prompt_messages": prompt_messages,
+ "multimodal_messages": multimodal_messages,
+ "multi_modal_data": multimodal_data,
"metrics": {
# Total number of steps taken in the trajectory
"steps": len(trajectory.steps),
diff --git a/rllm/engine/rollout/verl_engine.py b/rllm/engine/rollout/verl_engine.py
index cc1476e66..f7a318eb9 100644
--- a/rllm/engine/rollout/verl_engine.py
+++ b/rllm/engine/rollout/verl_engine.py
@@ -1,4 +1,8 @@
import uuid
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine
from rllm.parser import ChatTemplateParser, ToolParser
@@ -6,11 +10,12 @@
class VerlEngine(RolloutEngine):
- def __init__(self, config, rollout_manager, tokenizer, **kwargs):
+ def __init__(self, config, rollout_manager, tokenizer, processor=None, **kwargs):
self.config = config
self.rollout_manager = rollout_manager
self.server_manager = AsyncLLMServerManager(config, rollout_manager.async_llm_servers)
self.tokenizer = tokenizer
+ self.processor = processor # Store processor for multimodal processing
self.chat_parser = ChatTemplateParser.get_parser(self.tokenizer, disable_thinking=kwargs.get("disable_thinking", False))
try:
@@ -21,10 +26,13 @@ def __init__(self, config, rollout_manager, tokenizer, **kwargs):
self.validate = False
- async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput:
+ async def get_model_response(self, messages: list[dict], multimodal_messages: list[dict] | None = None, **kwargs) -> ModelOutput:
application_id = kwargs.pop("application_id", str(uuid.uuid4()))
validate = self.validate or kwargs.pop("validate", False)
+ # Extract multimodal data if present
+ multimodal_data = kwargs.pop("multimodal_data", None)
+
if validate:
sampling_params = dict(
temperature=0.0 if self.config.actor_rollout_ref.rollout.val_kwargs.do_sample is False else self.config.actor_rollout_ref.rollout.val_kwargs.temperature,
@@ -41,10 +49,164 @@ async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutpu
max_tokens = sampling_params.pop("max_tokens", self.config.data.max_response_length)
- prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True)
- prompt_ids = self.tokenizer.encode(prompt)
+ base_messages = multimodal_messages if multimodal_messages is not None else messages
+
+ def ensure_text_messages(messages_list: list[dict]) -> list[dict]:
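+            # Collapses structured content into plain text, e.g. (illustrative):
+            #   [{"role": "user", "content": [{"type": "text", "text": "Q1"}, {"type": "image", ...}]}]
+            #     -> [{"role": "user", "content": "Q1"}]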
+ text_only = []
+ for msg in messages_list:
+ new_msg = dict(msg)
+ content = new_msg.get("content")
+ if isinstance(content, list):
+ text_parts = []
+ for item in content:
+ if isinstance(item, dict) and item.get("type") == "text":
+ text_parts.append(item.get("text", ""))
+ new_msg["content"] = " ".join(part for part in text_parts if part).strip()
+ text_only.append(new_msg)
+ return text_only
+
+        model_inputs = None
+        images = None
+        videos = None
+
+        if self.processor is not None:
+            raw_prompt = self.processor.apply_chat_template(
+                base_messages,
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+
+            images = multimodal_data.get("image") if isinstance(multimodal_data, dict) else None
+            if isinstance(images, (list, tuple)) and len(images) == 0:
+                images = None
+            videos = multimodal_data.get("video") if isinstance(multimodal_data, dict) else None
+            if isinstance(videos, (list, tuple)) and len(videos) == 0:
+                videos = None
+
+            model_inputs = self.processor(
+                text=[raw_prompt],
+                images=images,
+                videos=videos,
+                return_tensors="pt",
+            )
+
+            prompt_ids = model_inputs["input_ids"][0].tolist()
+        else:
+            text_messages = ensure_text_messages(base_messages)
+            raw_prompt = self.chat_parser.parse(text_messages, add_generation_prompt=True, is_first_msg=True)
+            prompt_ids = self.tokenizer.encode(raw_prompt)
- response_ids = await self.server_manager.generate(request_id=application_id, prompt_ids=prompt_ids, sampling_params=sampling_params)
+        try:
+            prompt_tokens = self.tokenizer.convert_ids_to_tokens(prompt_ids)
+        except Exception:  # noqa: BLE001
+            prompt_tokens = None
+
+        # Debug logging disabled; uncomment for detailed prompt inspection.
+        # print("\n=== FINAL PROMPT TO MODEL ===")
+        # print(raw_prompt)
+        # print(f"Prompt ids ({len(prompt_ids)}): {prompt_ids}")
+        # if prompt_tokens is not None:
+        #     print(f"Prompt tokens: {prompt_tokens}")
+        # if model_inputs is not None:
+        #     pixel_values = model_inputs.get("pixel_values")
+        #     if pixel_values is not None:
+        #         try:
+        #             print(f"pixel_values shape: {tuple(pixel_values.shape)}")
+        #         except Exception:  # noqa: BLE001
+        #             print(f"pixel_values type: {type(pixel_values)}")
+        #     image_grid = model_inputs.get("image_grid_thw")
+        #     if image_grid is not None:
+        #         try:
+        #             grid_repr = image_grid.tolist() if hasattr(image_grid, "tolist") else image_grid
+        #         except Exception:  # noqa: BLE001
+        #             grid_repr = image_grid
+        #         print(f"image_grid_thw: {grid_repr}")
+        # if images:
+        #     payload = images if isinstance(images, (list, tuple)) else [images]
+        #     print(f"image payload count: {len(payload)}")
+        #     for idx, image in enumerate(payload):
+        #         size = getattr(image, "size", None)
+        #         print(f"image[{idx}] type={type(image)} size={size}")
+        # if videos:
+        #     payload_v = videos if isinstance(videos, (list, tuple)) else [videos]
+        #     print(f"video payload count: {len(payload_v)}")
+        # print("=== END MODEL PROMPT ===\n")
+
+        if multimodal_data and (multimodal_data.get("image") or multimodal_data.get("video")):
+            debug_logged = getattr(self, "_geo3k_multimodal_debug", 0)
+            if debug_logged < 10:
+                debug_path = Path("outputs/debug_geo3k_multimodal.log")
+                debug_path.parent.mkdir(parents=True, exist_ok=True)
+                with debug_path.open("a", encoding="utf-8") as fp:
+                    fp.write("\n" + "=" * 80 + "\n")
+                    fp.write(f"{datetime.now().isoformat()} | multimodal request\n")
+                    fp.write(f"Messages: {base_messages}\n")
+                    fp.write(f"Has images: {len(multimodal_data.get('image', []))}\n")
+                    if self.processor is None:
+                        fp.write("WARNING: No multimodal processor available; falling back to text-only tokenization.\n")
+                    else:
+                        fp.write("Multimodal processor detected; logging tokenizer inputs.\n")
+                    fp.write(f"Raw prompt string: {raw_prompt}\n")
+                    fp.write(f"Prompt id count: {len(prompt_ids)}\n")
+                    fp.write(f"Prompt ids: {prompt_ids}\n")
+                    try:
+                        token_strings = self.tokenizer.convert_ids_to_tokens(prompt_ids)
+                        fp.write(f"Prompt tokens: {token_strings}\n")
+                    except Exception as exc:  # noqa: BLE001
+                        fp.write(f"Failed to convert prompt ids to tokens: {exc}\n")
+
+                    model_input_keys = sorted(model_inputs.keys()) if model_inputs is not None else []
+                    fp.write(f"Model input keys: {model_input_keys}\n")
+
+                    if images:
+                        payload = images if isinstance(images, (list, tuple)) else [images]
+                        fp.write(f"Image payload types: {[type(img).__name__ for img in payload]}\n")
+                        for idx, image in enumerate(payload):
+                            size = getattr(image, "size", None)
+                            fp.write(f"Image[{idx}] info: type={type(image)}, size={size}\n")
+                self._geo3k_multimodal_debug = debug_logged + 1
+
+            # Use Verl's generate_sequences for multimodal data
+            from verl import DataProto
+
+            # Create batch with multimodal data in Verl format
+            non_tensor_batch = {
+                "raw_prompt_ids": np.array([prompt_ids], dtype=object),
+                "raw_prompt": np.array([np.array(base_messages, dtype=object)], dtype=object),
+                "raw_prompt_text": np.array([raw_prompt], dtype=object),
+                "multi_modal_data": np.array([multimodal_data], dtype=object),
+            }
+
+            batch = DataProto.from_dict(non_tensors=non_tensor_batch)
+
+            # Update sampling params in rollout manager config temporarily
+            original_config = {}
+            for key, value in sampling_params.items():
+                if hasattr(self.rollout_manager.config.actor_rollout_ref.rollout, key):
+                    original_config[key] = getattr(self.rollout_manager.config.actor_rollout_ref.rollout, key)
+                    setattr(self.rollout_manager.config.actor_rollout_ref.rollout, key, value)
+
+            try:
+                # Use rollout_manager's generate_sequences for multimodal
+                output_batch = self.rollout_manager.generate_sequences(batch)
+
+                # Extract response_ids from output
+                if hasattr(output_batch, 'batch') and 'responses' in output_batch.batch:
+                    response_ids = output_batch.batch['responses'][0].tolist()
+                else:
+                    raise RuntimeError("Failed to get response from multimodal generation")
+
+            finally:
+                # Restore original config
+                for key, value in original_config.items():
+                    setattr(self.rollout_manager.config.actor_rollout_ref.rollout, key, value)
+        else:
+            # Original text-only path
+            response_ids = await self.server_manager.generate(
+                request_id=application_id,
+                prompt_ids=prompt_ids,
+                sampling_params=sampling_params,
+            )
# verl sets max_tokens as max_model_len - len(prompt_ids), where max_model_len is config.data.max_prompt_length + config.data.max_response_length
# so we truncate the response to max_tokens if it exceeds max_tokens
diff --git a/rllm/environments/base/multi_turn_env.py b/rllm/environments/base/multi_turn_env.py
index 5f4aa9aa1..408072d7b 100644
--- a/rllm/environments/base/multi_turn_env.py
+++ b/rllm/environments/base/multi_turn_env.py
@@ -61,7 +61,9 @@ def step(self, action):
# Check if we've reached the maximum number of turns
if self.current_turn >= self.max_turns:
self.done = True
- return {}, reward, self.done, self.task
+ # For multimodal tasks, preserve task info even when done
+ final_obs = self.task if isinstance(self.task, dict) else {}
+ return final_obs, reward, self.done, self.task
return next_obs, reward, self.done, self.task
diff --git a/rllm/environments/base/single_turn_env.py b/rllm/environments/base/single_turn_env.py
index 49e5549a4..a8a16b786 100644
--- a/rllm/environments/base/single_turn_env.py
+++ b/rllm/environments/base/single_turn_env.py
@@ -37,7 +37,17 @@ def get_reward_and_next_obs(self, task: dict, action: Any) -> tuple[float, dict]
"""
reward_output = self.reward_fn(task_info=task, action=action)
- return reward_output.reward, {}
+ metadata = getattr(reward_output, "metadata", {}) or {}
+ data_source = None
+ if isinstance(task, dict):
+ data_source = task.get("data_source")
+ if metadata and (data_source == "hiyouga/geometry3k" or metadata.get("data_source") == "hiyouga/geometry3k"):
+ print(f"DEBUG: GEO3K reward metadata -> {metadata}")
+
+ # For single-turn environments, preserve the original task for next observation
+ # This ensures multimodal data (like images) is maintained
+ next_obs = task if isinstance(task, dict) else {}
+ return reward_output.reward, next_obs
@staticmethod
def from_dict(env_args: dict) -> "SingleTurnEnvironment":
diff --git a/rllm/rewards/math_reward.py b/rllm/rewards/math_reward.py
index 4a4c17bbb..a7a462381 100644
--- a/rllm/rewards/math_reward.py
+++ b/rllm/rewards/math_reward.py
@@ -1,8 +1,9 @@
-"""
-This module contains the RewardMathFn class, which evaluates mathematical answers
-and assigns rewards based on their correctness. It utilizes a language model to
-validate answers when necessary.
-"""
+"""Reward helpers for math-style tasks, including GEO3K multimodal problems."""
+
+import re
+from datetime import datetime
+from pathlib import Path
+
+from rllm.agents.agent import Action
from rllm.globals import THOUGHT_DELIMITER_END
from rllm.rewards.math_utils.utils import extract_answer, grade_answer_mathd, grade_answer_sympy
@@ -41,11 +42,94 @@ def __call__(self, task_info: dict, action: str) -> RewardOutput:
# problem = task_info.get("problem", "")
model_response = action
- # Handle None or empty response
- if model_response is None or model_response == "":
+        # Extract raw text when an Action wrapper is provided
+        if isinstance(model_response, Action):
+            model_response = model_response.action
+
+        # Normalize response into a string for downstream processing
+        if model_response is None:
+            print("DEBUG: Empty or None response")
+            return RewardOutput(reward=self.config.format_error_reward, is_correct=False)
+
+        if not isinstance(model_response, str):
+            model_response = str(model_response)
+
+        model_response = model_response.strip()
+
+        # Handle empty response after stripping whitespace
+        if model_response == "":
print("DEBUG: Empty or None response")
return RewardOutput(reward=self.config.format_error_reward, is_correct=False)
+        data_source = task_info.get("data_source")
+
+        if data_source == "hiyouga/geometry3k":
+            model_response = self._normalize_geo3k_response(model_response)
+            geo3k_reward = None
+            try:
+                from verl.utils.reward_score import geo3k as geo3k_reward  # type: ignore
+            except ImportError:
+                try:
+                    from verl.verl.utils.reward_score import geo3k as geo3k_reward  # type: ignore
+                except ImportError:
+                    module_path = Path(__file__).resolve().parents[2] / "verl" / "verl" / "utils" / "reward_score" / "geo3k.py"
+                    if module_path.exists():
+                        import importlib.util
+
+                        spec = importlib.util.spec_from_file_location("rllm_vendor_verl_geo3k", module_path)
+                        if spec and spec.loader:
+                            geo3k_module = importlib.util.module_from_spec(spec)
+                            spec.loader.exec_module(geo3k_module)
+                            geo3k_reward = geo3k_module  # type: ignore
+            if geo3k_reward is None:
+                print("DEBUG: Failed to import geo3k reward module: module not found")
+                return RewardOutput(reward=self.config.unk_error_reward, is_correct=False)
+
+            ground_truths = task_info.get("ground_truth")
+            if ground_truths is None:
+                return RewardOutput(reward=self.config.unk_error_reward, is_correct=False)
+
+            if isinstance(ground_truths, (list, tuple, set)):
+                gt_list = [str(gt) for gt in ground_truths]
+            else:
+                gt_list = [str(ground_truths)]
+
+            format_score = geo3k_reward.format_reward(model_response)
+            accuracy_scores = [geo3k_reward.acc_reward(model_response, gt) for gt in gt_list]
+            score_candidates = [geo3k_reward.compute_score(model_response, gt) for gt in gt_list]
+
+            reward = float(max(score_candidates)) if score_candidates else self.config.unk_error_reward
+            max_accuracy = max(accuracy_scores, default=0.0)
+
+            if format_score < 1.0:
+                print("DEBUG: GEO3K format violation detected; response missing required structure")
+            if max_accuracy == 0.0:
+                print("DEBUG: GEO3K accuracy check failed for all ground truths")
+
+            if format_score < 1.0:
+                debug_limit = 50
+                counter = getattr(self, "_geo3k_debug_logged", 0)
+                if counter < debug_limit:
+                    debug_path = Path("outputs/debug_geo3k_responses.log")
+                    debug_path.parent.mkdir(parents=True, exist_ok=True)
+                    with debug_path.open("a", encoding="utf-8") as fp:
+                        fp.write("\n" + "=" * 80 + "\n")
+                        fp.write(f"{datetime.now().isoformat()} | format={format_score:.3f} | max_acc={max_accuracy:.3f}\n")
+                        fp.write("Model response:\n")
+                        fp.write(model_response + "\n")
+                        fp.write("Ground truths:\n")
+                        for gt in gt_list:
+                            fp.write(str(gt) + "\n")
+                    self._geo3k_debug_logged = counter + 1
+
+            metadata = {
+                "data_source": data_source,
+                "geo3k_accuracy_scores": accuracy_scores,
+                "geo3k_format_reward": format_score,
+            }
+
+            return RewardOutput(reward=reward, metadata=metadata, is_correct=bool(max_accuracy))
+
# Extract solution.
if THOUGHT_DELIMITER_END in model_response:
model_solution = model_response.split(THOUGHT_DELIMITER_END)[1]
@@ -93,6 +177,33 @@ def __call__(self, task_info: dict, action: str) -> RewardOutput:
return RewardOutput(reward=self.config.incorrect_reward, is_correct=False)
+ @staticmethod
+ def _normalize_geo3k_response(model_response: str) -> str:
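+        # Illustrative effect (hypothetical input): an unclosed <think> block such as
+        #   "<think>reasoning Final answer: \boxed{12}"
+        # is rewritten to
+        #   "<think>reasoning </think>\nFinal answer: \boxed{12}"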
+ response = model_response
+
+ final_answer_pattern = re.compile(r"Final answer\s*:\s*\\boxed\{.*?\}", re.DOTALL)
+ match = final_answer_pattern.search(response)
+ if not match:
+ return response
+
+ final_start = match.start()
+ prefix = response[:final_start]
+ suffix = response[final_start:]
+
+ has_think = "" in prefix
+ has_think_end = "" in prefix
+
+ if has_think and not has_think_end:
+ prefix = prefix + "\n"
+ elif not has_think:
+ cleaned_prefix = prefix.strip()
+ if cleaned_prefix:
+ prefix = f"{cleaned_prefix}\n"
+ else:
+ prefix = "\n"
+
+ return prefix + suffix
+
def rllm_reward_fn_math(data_source: str, llm_solution: str, ground_truth: str | list[str], extra_info=None, **kwargs):
"""Evaluates mathematical solutions against ground truth answers.
diff --git a/rllm/trainer/config/geo3k_multimodal_trainer.yaml b/rllm/trainer/config/geo3k_multimodal_trainer.yaml
new file mode 100644
index 000000000..699a3e704
--- /dev/null
+++ b/rllm/trainer/config/geo3k_multimodal_trainer.yaml
@@ -0,0 +1,288 @@
+hydra:
+ searchpath:
+ - pkg://verl.trainer.config
+
+defaults:
+ - ppo_trainer
+ - _self_
+
+# Enable multimodal processing
+multimodal:
+ enable: true
+
+# Model configuration for multimodal training
+actor_rollout_ref:
+ model:
+ path: "Qwen/Qwen2.5-VL-7B-Instruct"
+ trust_remote_code: true
+ use_shm: false
+ hybrid_engine: true
+ actor:
+ strategy: fsdp
+ ppo_mini_batch_size: 64
+ ppo_micro_batch_size: null
+ ppo_micro_batch_size_per_gpu: 2
+ use_dynamic_bsz: false
+ ppo_max_token_len_per_gpu: 16384
+ clip_ratio: 0.2
+ clip_ratio_low: 0.2
+ clip_ratio_high: 0.2
+ policy_loss:
+ loss_mode: vanilla
+ clip_cov_ratio: 0.0002
+ clip_cov_lb: 1.0
+ clip_cov_ub: 5.0
+ kl_cov_ratio: 0.0002
+ ppo_kl_coef: 0.1
+ clip_ratio_c: 3.0
+ loss_agg_mode: token-mean
+ entropy_coeff: 0.01
+ use_kl_loss: false
+ use_torch_compile: true
+ kl_loss_coef: 0.001
+ kl_loss_type: low_var_kl
+ ppo_epochs: 1
+ shuffle: false
+ checkpoint:
+ save_contents:
+ - model
+ - optimizer
+ - extra
+ load_contents: ${.save_contents}
+ optim:
+ lr: 5.0e-07
+ lr_warmup_steps_ratio: 0.1
+ total_training_steps: -1
+ weight_decay: 0.01
+ lr_warmup_steps: -1
+ min_lr_ratio: 0.0
+ num_cycles: 0.5
+ warmup_style: constant
+ grad_clip: 1.0
+ ulysses_sequence_parallel_size: 1
+ entropy_from_logits_with_chunking: false
+ entropy_checkpointing: false
+ fsdp_config:
+ wrap_policy:
+ min_num_params: 0
+ param_offload: false
+ optimizer_offload: false
+ offload_policy: false
+ reshard_after_forward: true
+ fsdp_size: -1
+ forward_prefetch: false
+ ref:
+ strategy: ${actor_rollout_ref.actor.strategy}
+ use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}
+ log_prob_micro_batch_size: null
+ log_prob_micro_batch_size_per_gpu: null
+ log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
+ log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
+ fsdp_config:
+ param_offload: false
+ reshard_after_forward: true
+ forward_prefetch: false
+ wrap_policy:
+ min_num_params: 0
+ ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}
+ entropy_from_logits_with_chunking: false
+ entropy_checkpointing: false
+ rollout:
+ name: sglang
+ mode: async
+ temperature: 1.0
+ top_k: -1
+ top_p: 1
+ prompt_length: ${oc.select:data.max_prompt_length,2048}
+ response_length: ${oc.select:data.max_response_length,2048}
+ dtype: bfloat16
+ gpu_memory_utilization: 0.85
+ ignore_eos: false
+ enforce_eager: false
+ free_cache_engine: true
+ tensor_model_parallel_size: 1
+ max_num_batched_tokens: 8192
+ max_model_len: null
+ max_num_seqs: 256
+ log_prob_micro_batch_size: null
+ log_prob_micro_batch_size_per_gpu: null
+ log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
+ log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
+ disable_log_stats: true
+ do_sample: true
+ 'n': 1
+ multi_stage_wake_up: false
+ multi_turn:
+ enable: true
+ max_assistant_turns: 5
+ agent:
+ num_workers: 0
+ val_kwargs:
+ do_sample: true
+ engine_kwargs:
+ sglang:
+ swap_space: null
+ disable_mm_preprocessor_cache: false
+
+# Critic configuration
+critic:
+ strategy: fsdp
+ use_torch_compile: true
+ critic_loss_coef: 1.0
+ value_head_mode: separate
+ use_kl_loss: false
+ kl_loss_coef: 0.001
+ kl_loss_type: low_var_kl
+ ppo_epochs: 1
+ ppo_micro_batch_size: null
+ ppo_micro_batch_size_per_gpu: 2
+ use_dynamic_bsz: false
+ ppo_max_token_len_per_gpu: 16384
+ shuffle: false
+ checkpoint:
+ save_contents:
+ - model
+ - optimizer
+ - extra
+ load_contents: ${.save_contents}
+ optim:
+ lr: 5.0e-07
+ lr_warmup_steps_ratio: 0.1
+ total_training_steps: -1
+ weight_decay: 0.01
+ lr_warmup_steps: -1
+ min_lr_ratio: 0.0
+ num_cycles: 0.5
+ warmup_style: constant
+ grad_clip: 1.0
+ ulysses_sequence_parallel_size: 1
+ entropy_from_logits_with_chunking: false
+ entropy_checkpointing: false
+ fsdp_config:
+ wrap_policy:
+ min_num_params: 0
+ param_offload: false
+ optimizer_offload: false
+ offload_policy: false
+ reshard_after_forward: true
+ fsdp_size: -1
+ forward_prefetch: false
+
+# Data configuration for multimodal training
+data:
+ max_prompt_length: 2048
+ max_response_length: 2048
+ train_batch_size: 64
+ micro_batch_size_per_gpu: 1
+ val_batch_size: 64
+ val_micro_batch_size_per_gpu: 1
+ train_files: null
+ val_files: null
+ dataset: null
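+ # Generation batches are oversampled by the rejection-sampling multiplier.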
+ gen_batch_size: ${mul:${data.train_batch_size},${rllm.rejection_sample.multiplier}}
+ return_raw_chat: true
+ return_multi_modal_inputs: true
+ trust_remote_code: true
+ pad_to_length: null
+ max_epochs: 10
+ shuffle_before_pack: true
+ pack_sequences: false
+ tokenizer_chat_template: null
+ response_only_loss_mask: true
+ normalize_reward: false
+ filter_length: -1
+ max_token_len: null
+ balance_sequence_length: false
+ balance_sequence_length_mode: packed_token
+ balance_sequence_length_sequence_weight: 0.7
+ balance_sequence_length_balance_weight: 0.3
+ balance_sequence_length_max_ratio: 10.0
+ parquet_data_file_fraction: 1.0
+
+# RLLM-specific configuration
+rllm:
+ agent:
+ name: geo3k_agent
+ max_steps: 10
+ trajectory_timeout: null
+ overlong_filter: false
+ agent_args:
+ accumulate_thinking: true
+ include_images_in_completion: true
+ engine_args: {}
+ env:
+ name: math
+ env_args: {}
+ workflow:
+ use_workflow: true
+ name: multi_turn_workflow
+ workflow_args:
+ max_prompt_length: ${data.max_prompt_length}
+ max_response_length: ${data.max_response_length}
+ timeout: 1.0e+6
+ gamma: 1.0
+ reward_bonus_coeff: 0.0
+ accumulate_response_length: true
+ n_parallel_tasks: 64
+ retry_limit: 3
+ disable_thinking: false
+ mask_truncated_samples: false
+ stepwise_advantage:
+ enable: false
+ mode: broadcast
+ normalize_by_steps: false
+ compact_filtering:
+ enable: false
+ compact_ratio: 0.8
+ compact_mode: packed_token
+ rejection_sample:
+ enable: false
+ multiplier: 4
+ reward_threshold: 0.0
+ keep_best_n: 1
+
+# Algorithm configuration
+algorithm:
+ kl_penalty: 0.05
+ kl_target: 6.0
+ gamma: 1.0
+ lam: 0.95
+ cliprange: 0.2
+ cliprange_value: 0.2
+ vf_coef: 0.5
+ entropy_bonus: 0.01
+ use_kl_in_reward: false
+ adv_estimator: gae
+ normalization:
+ advantage: true
+ value_target: false
+
+# Reward model configuration
+reward_model:
+ enable: false
+ path: null
+ trust_remote_code: true
+
+# Trainer configuration
+trainer:
+ n_gpus_per_node: 4
+ nnodes: 1
+ save_freq: 100
+ eval_freq: 50
+ logging_freq: 10
+ save_dir: ./checkpoints/geo3k_multimodal
+ project_name: geo3k_multimodal_training
+ experiment_name: ${oc.env:USER,unknown}_geo3k_${now:%Y%m%d_%H%M%S}
+ tags: ["multimodal", "geo3k", "qwen2vl"]
+ disable_fast_tokenizer: false
+ profile_steps: null
+ controller_nsight_options:
+ trace: ["cuda", "osrt", "nvtx"]
+ output: "./nsight_traces"
+ force_overwrite: true
+ use_legacy_worker_impl: auto
+
+# Ray configuration
+ray_init:
+ num_cpus: 16
+ timeline_json_file: null
diff --git a/rllm/trainer/env_agent_mappings.py b/rllm/trainer/env_agent_mappings.py
index 4807464ee..4a1575591 100644
--- a/rllm/trainer/env_agent_mappings.py
+++ b/rllm/trainer/env_agent_mappings.py
@@ -24,6 +24,7 @@ def safe_import(module_path, class_name):
"tool_agent": safe_import("rllm.agents.tool_agent", "ToolAgent"),
"sweagent": safe_import("rllm.agents.swe_agent", "SWEAgent"),
"math_agent": safe_import("rllm.agents.math_agent", "MathAgent"),
+ "geo3k_agent": safe_import("rllm.agents.geo3k_agent", "Geo3kAgent"),
"code_agent": safe_import("rllm.agents.code_agent", "CompetitionCodingAgent"),
}
diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py
index 734fef051..e0294ef1a 100644
--- a/rllm/trainer/verl/agent_ppo_trainer.py
+++ b/rllm/trainer/verl/agent_ppo_trainer.py
@@ -14,6 +14,7 @@
from omegaconf import OmegaConf
from rllm.engine.agent_execution_engine import AsyncAgentExecutionEngine
+from rllm.parser import ChatTemplateParser
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor
from verl.trainer.ppo.ray_trainer import (
@@ -45,12 +46,109 @@ def __init__(
agent_class=None,
env_args=None,
agent_args=None,
+ processor=None,
):
- super().__init__(config=config, tokenizer=tokenizer, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn)
+ super().__init__(config=config, tokenizer=tokenizer, processor=processor, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn)
self.env_class = env_class
self.agent_class = agent_class
self.env_args = env_args or {}
self.agent_args = agent_args or {}
+ self.processor = processor
+ disable_thinking = False
+ if hasattr(self.config, "rllm") and hasattr(self.config.rllm, "disable_thinking"):
+ disable_thinking = bool(self.config.rllm.disable_thinking)
+ self.chat_parser = ChatTemplateParser.get_parser(self.tokenizer, disable_thinking=disable_thinking)
+ self._mm_debug_prints = 0
+
+ def _dump_multimodal_debug(self, stage: str, data_proto: DataProto, max_prints: int = 6) -> None:
+ """Best-effort logging of multimodal payloads for debugging."""
+ if self._mm_debug_prints >= max_prints:
+ return
+
+ self._mm_debug_prints += 1
+
+ try:
+ batch_size = len(data_proto)
+ except Exception: # noqa: BLE001
+ batch_size = "unknown"
+
+ print("\n[MM DEBUG] stage=", stage, " batch_size=", batch_size)
+ try:
+ print(" non_tensor_keys=", list(data_proto.non_tensor_batch.keys()))
+ except Exception as exc: # noqa: BLE001
+ print(f" failed to list non_tensor_keys: {exc}")
+
+ # Inspect first element when available
+ try:
+ if len(data_proto) == 0:
+ return
+ except Exception: # noqa: BLE001
+ return
+
+ try:
+ sample = data_proto[0]
+ except Exception as exc: # noqa: BLE001
+ print(f" failed to fetch sample: {exc}")
+ return
+
+ sample_non_tensor = getattr(sample, "non_tensor_batch", {}) or {}
+ sample_batch = getattr(sample, "batch", None)
+
+ mm_data = sample_non_tensor.get("multi_modal_data")
+ mm_inputs = sample_non_tensor.get("multi_modal_inputs")
+ raw_prompt_text = sample_non_tensor.get("raw_prompt_text")
+
+ if raw_prompt_text:
+ print(" raw_prompt_text_head=", raw_prompt_text[:160])
+
+ if isinstance(mm_data, dict):
+ print(" multi_modal_data keys=", list(mm_data.keys()))
+ for key, value in mm_data.items():
+ if isinstance(value, list):
+ print(f" {key}: list(len={len(value)})")
+ else:
+ print(f" {key}: type={type(value)}")
+ else:
+ print(" multi_modal_data type=", type(mm_data))
+
+ if isinstance(mm_inputs, dict) and mm_inputs:
+ print(" multi_modal_inputs summary:")
+ for key, value in mm_inputs.items():
+ try:
+ if torch.is_tensor(value):
+ print(f" {key}: tensor shape={tuple(value.shape)} dtype={value.dtype}")
+ flat = value.flatten()
+ preview = flat[: min(10, flat.shape[0])].tolist()
+ print(f" sample[{key}]: {preview}")
+ elif isinstance(value, list) and value and torch.is_tensor(value[0]):
+ shapes = [tuple(v.shape) for v in value]
+ print(f" {key}: list[tensor] shapes={shapes}")
+ flat = value[0].flatten()
+ preview = flat[: min(10, flat.shape[0])].tolist()
+ print(f" sample[{key}[0]]: {preview}")
+ else:
+ print(f" {key}: type={type(value)}")
+ except Exception as exc: # noqa: BLE001
+ print(f" {key}: failed to inspect ({exc})")
+ else:
+ print(" multi_modal_inputs type=", type(mm_inputs))
+
+ if sample_batch is not None:
+ try:
+ prompts_tensor = data_proto.batch["prompts"][0]
+ prompt_ids = prompts_tensor.tolist() if hasattr(prompts_tensor, "tolist") else []
+ decoded_prompt = self.tokenizer.decode([pid for pid in prompt_ids if pid != self.tokenizer.pad_token_id])
+ tokens_head = prompt_ids[:64]
+ try:
+ token_strs = self.tokenizer.convert_ids_to_tokens(tokens_head)
+ except Exception: # noqa: BLE001
+ token_strs = ""
+ print(" prompt_token_ids_head=", tokens_head)
+ print(" prompt_token_strs_head=", token_strs)
+ print(" decoded_prompt_head=", decoded_prompt[:160])
+ except Exception as exc: # noqa: BLE001
+ print(f" failed to decode prompt tokens: {exc}")
+
assert self.config.actor_rollout_ref.hybrid_engine, "Only hybrid engine is supported"
assert self.config.actor_rollout_ref.rollout.mode == "async", "Only async rollout mode is supported"
@@ -68,6 +166,7 @@ def init_workers(self):
config=self.config,
engine_name="verl",
tokenizer=self.tokenizer,
+ processor=self.processor,
model_path=self.config.actor_rollout_ref.model.path,
max_steps=self.config.rllm.agent.max_steps,
max_response_length=self.config.data.max_response_length,
@@ -162,7 +261,17 @@ def fit_agent(self):
metrics = {}
timing_raw = {}
- batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"])
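+ # Drop dataset-provided prompt/multimodal caches along with the token tensors;
+ # fresh copies are rebuilt from the agent trajectories after rollout.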
+ non_tensor_keys_to_pop = [
+ "raw_prompt",
+ "raw_prompt_text",
+ "raw_prompt_ids",
+ "multi_modal_data",
+ "multi_modal_inputs",
+ ]
+ batch.pop(
+ batch_keys=["input_ids", "attention_mask", "position_ids"],
+ non_tensor_batch_keys=[key for key in non_tensor_keys_to_pop if key in batch.non_tensor_batch],
+ )
with marked_timer("step", timing_raw):
self.init_envs_and_agents(batch)
@@ -175,10 +284,12 @@ def fit_agent(self):
final_gen_batch_output.meta_info.pop("repeat_counts", None) # no longer needed after this
# batch needs to be padded to divisor of world size, we will pad with everything masked out
batch = batch.union(final_gen_batch_output)
+ self._dump_multimodal_debug("train_union_steps", batch)
batch = self._pad_dataproto_to_world_size(batch=batch)
else:
final_gen_batch_output, generate_metrics = self.generate_agent_trajectory(timing_raw=timing_raw, meta_info=batch.meta_info)
batch = batch.union(final_gen_batch_output)
+ self._dump_multimodal_debug("train_union_traj", batch)
metrics.update(generate_metrics)
# compute values
@@ -425,7 +536,17 @@ def _validate_agent(self):
test_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object)
n_val_samples = self.config.actor_rollout_ref.rollout.val_kwargs.n
test_batch = test_batch.repeat(repeat_times=n_val_samples, interleave=True)
- test_batch.pop(["input_ids", "attention_mask", "position_ids"]) # these are not needed for environment based interaction
+ non_tensor_keys_to_pop = [
+ "raw_prompt",
+ "raw_prompt_text",
+ "raw_prompt_ids",
+ "multi_modal_data",
+ "multi_modal_inputs",
+ ]
+ test_batch.pop(
+ ["input_ids", "attention_mask", "position_ids"],
+ non_tensor_batch_keys=[key for key in non_tensor_keys_to_pop if key in test_batch.non_tensor_batch],
+ ) # remove tensors and multimodal caches before agent rollout
test_batch.meta_info = {
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
@@ -445,6 +566,7 @@ def _validate_agent(self):
test_output_gen_batch, _ = self.generate_agent_trajectory(meta_info=test_batch.meta_info)
test_batch = test_batch.union(test_output_gen_batch)
+ self._dump_multimodal_debug("validate_union", test_batch)
reward_tensor = test_batch.batch["token_level_scores"]
@@ -567,9 +689,24 @@ def _transform_agent_trajectories(self, trajectories: list[dict]):
all_masks_list = []
traj_scores = []
chat_completions = []
+ raw_prompt_texts = []
+ raw_prompt_ids_list = []
+ multi_modal_data_list = []
+ multi_modal_inputs_list = []
traj_metrics = []
metrics = {}
+ def _move_to_cpu(obj):
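+ # Detach tensors to CPU (recursively through dicts/lists/tuples) so the
+ # numpy object arrays built below do not hold references to GPU memory.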
+ if torch.is_tensor(obj):
+ return obj.detach().cpu()
+ if isinstance(obj, dict):
+ return {k: _move_to_cpu(v) for k, v in obj.items()}
+ if isinstance(obj, list):
+ return [_move_to_cpu(v) for v in obj]
+ if isinstance(obj, tuple):
+ return tuple(_move_to_cpu(v) for v in obj)
+ return obj
+
for traj in trajectories:
prompt_tokens = traj["prompt_tokens"]
response_tokens = traj["response_tokens"]
@@ -580,6 +717,69 @@ def _transform_agent_trajectories(self, trajectories: list[dict]):
all_masks_list.append(traj["response_masks"])
traj_scores.append(traj["trajectory_reward"])
chat_completions.append(traj["chat_completions"])
+ prompt_messages = traj.get("prompt_messages") or traj.get("chat_completions") or []
+ multimodal_messages = traj.get("multimodal_messages") or []
+
+ multi_modal_data = traj.get("multi_modal_data")
+ if not isinstance(multi_modal_data, dict):
+ multi_modal_data = {}
+ else:
+ multi_modal_data = dict(multi_modal_data)
+
+ raw_prompt_text = None
+ raw_prompt_ids = None
+ multi_modal_inputs = {}
+
+ if self.processor is not None and multimodal_messages:
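+ # Render the conversation with the processor's chat template (which inserts
+ # image placeholder tokens), then run the processor over text + images to
+ # collect the extra model inputs (e.g. pixel values) kept in multi_modal_inputs.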
+ try:
+ raw_prompt_text = self.processor.apply_chat_template(
+ multimodal_messages,
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+
+ images = multi_modal_data.get("image") if multi_modal_data else None
+ if isinstance(images, list) and len(images) == 0:
+ images = None
+ videos = multi_modal_data.get("video") if multi_modal_data else None
+ if isinstance(videos, list) and len(videos) == 0:
+ videos = None
+
+ model_inputs = self.processor(
+ text=[raw_prompt_text],
+ images=images,
+ videos=videos,
+ return_tensors="pt",
+ )
+ raw_prompt_ids = model_inputs["input_ids"][0].tolist()
+ multi_modal_inputs = {
+ key: _move_to_cpu(value)
+ for key, value in model_inputs.items()
+ if key not in {"input_ids", "attention_mask"}
+ }
+ except Exception as exc: # noqa: BLE001
+ print(f"Warning: failed to build multimodal inputs: {exc}")
+ raw_prompt_text = None
+ raw_prompt_ids = None
+ multi_modal_inputs = {}
+
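+ # Fall back to the text-only chat template and the original prompt token ids
+ # when multimodal processing is unavailable or fails.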
+ if raw_prompt_text is None and prompt_messages:
+ try:
+ raw_prompt_text = self.chat_parser.parse(
+ prompt_messages,
+ add_generation_prompt=True,
+ is_first_msg=True,
+ )
+ except Exception: # noqa: BLE001
+ raw_prompt_text = None
+
+ if raw_prompt_ids is None:
+ raw_prompt_ids = prompt_tokens.tolist()
+
+ raw_prompt_texts.append(raw_prompt_text)
+ raw_prompt_ids_list.append(raw_prompt_ids)
+ multi_modal_data_list.append(_move_to_cpu(multi_modal_data))
+ multi_modal_inputs_list.append(_move_to_cpu(multi_modal_inputs))
traj_metrics.append(traj["metrics"])
# Flatten traj_metrics into a dict of lists
@@ -655,9 +855,20 @@ def _transform_agent_trajectories(self, trajectories: list[dict]):
"response_mask": traj_mask,
}
- self.visualize_trajectory(DataProto.from_dict(tensors=tensor_batch))
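+ # Expose rendered prompts and multimodal payloads as per-sample object arrays
+ # alongside the tensor batch so they survive the DataProto union downstream.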
+ non_tensor_batch = {
+ "raw_prompt": np.array(raw_prompt_texts, dtype=object),
+ "raw_prompt_text": np.array(raw_prompt_texts, dtype=object),
+ "raw_prompt_ids": np.array(raw_prompt_ids_list, dtype=object),
+ "multi_modal_data": np.array(multi_modal_data_list, dtype=object),
+ "multi_modal_inputs": np.array(multi_modal_inputs_list, dtype=object),
+ }
- return DataProto.from_dict(tensors=tensor_batch), metrics
+ traj_dataproto = DataProto.from_dict(tensors=tensor_batch, non_tensors=non_tensor_batch)
+
+ self.visualize_trajectory(traj_dataproto)
+ self._dump_multimodal_debug("transform_trajectory", traj_dataproto)
+
+ return traj_dataproto, metrics
def visualize_trajectory(self, tensor_batch, sample_idx=0, max_samples=1, mask_key="response_mask"):
"""
@@ -881,6 +1092,7 @@ def _transform_agent_steps(self, steps: list[dict], uids: np.ndarray):
sample_indices = np.random.choice(last_step_indices, size=min(2, len(last_step_indices)), replace=False)
for idx in sample_indices:
self.visualize_trajectory(result, sample_idx=idx, max_samples=1)
+ self._dump_multimodal_debug("transform_steps", result)
return result
def _stepwise_advantage_broadcast(self, last_step_batch, other_step_batch):
@@ -956,3 +1168,12 @@ def _pad_dataproto_to_world_size(self, batch):
batch.non_tensor_batch["is_pad_step"][idx] = True
return batch
+
+ def shutdown(self):
+ """Gracefully release resources used by the async agent execution engine."""
+ engine = getattr(self, "agent_execution_engine", None)
+ if engine is not None and hasattr(engine, "shutdown"):
+ try:
+ engine.shutdown()
+ except Exception as exc: # noqa: BLE001
+ print(f"Warning: failed to shutdown agent execution engine: {exc}")
diff --git a/rllm/trainer/verl/train_agent_ppo.py b/rllm/trainer/verl/train_agent_ppo.py
index f05966fc2..b5668242d 100644
--- a/rllm/trainer/verl/train_agent_ppo.py
+++ b/rllm/trainer/verl/train_agent_ppo.py
@@ -89,11 +89,20 @@ def run(self, config, workflow_class=None, workflow_args=None, agent_class=None,
# Instantiate the tokenizer and processor.
from verl.utils import hf_tokenizer
+ from verl.utils.tokenizer import hf_processor
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
- # Used for multimodal LLM, could be None
- # processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
+
+ # Enable processor for multimodal models
+ processor = None
+ if config.get("multimodal", {}).get("enable", False):
+ try:
+ processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
+ print(f"Multimodal processor loaded: {type(processor)}")
+ except Exception as e:
+ print(f"Failed to load multimodal processor: {e}")
+ print("Falling back to text-only training")
# Define worker classes based on the actor strategy.
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
@@ -192,6 +201,7 @@ def run(self, config, workflow_class=None, workflow_args=None, agent_class=None,
trainer = AgentPPOTrainer(
config=config,
tokenizer=tokenizer,
+ processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,