diff --git a/examples/geo3k/prepare_geo3k_data.py b/examples/geo3k/prepare_geo3k_data.py new file mode 100755 index 000000000..09f536581 --- /dev/null +++ b/examples/geo3k/prepare_geo3k_data.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +"""Prepare GEO3K multimodal geometry dataset for RLLM training.""" + +import base64 +from io import BytesIO +from typing import Iterable, List + +from datasets import load_dataset +from PIL import Image + +from rllm.data.dataset import DatasetRegistry + +DATA_URI_PREFIX = "data:image/png;base64," + + +def _serialize_images(images: Iterable[Image.Image]) -> List[str]: + """Serialize a list of PIL images into base64 data URIs for Parquet storage.""" + + serialized: list[str] = [] + for image in images or []: + if not isinstance(image, Image.Image): + continue + buffer = BytesIO() + image.convert("RGB").save(buffer, format="PNG") + encoded = base64.b64encode(buffer.getvalue()).decode("utf-8") + serialized.append(f"{DATA_URI_PREFIX}{encoded}") + return serialized + + +def prepare_geo3k_data(): + """ + Prepare GEO3K dataset following RLLM conventions. + + Returns: + Tuple of (train_dataset, test_dataset) registered with DatasetRegistry + """ + print("šŸ“„ Loading GEO3K dataset from HuggingFace...") + data_source = "hiyouga/geometry3k" + dataset = load_dataset(data_source) + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + print(f"āœ… Dataset loaded:") + print(f" - Train samples: {len(train_dataset)}") + print(f" - Test samples: {len(test_dataset)}") + + # Instruction template based on Verl's GEO3K processing + instruction_following = ( + "You must strictly follow this answer template: " + "(1) write all reasoning as an internal monologue inside <think>...</think>; " + "(2) after </think>, output exactly one line in the form `Final answer: \\boxed{value}` with the numeric solution." + ) + + formatting_example = ( + "<think>Reason step by step here....</think>\n" + "Final answer: \\boxed{42}" + ) + + def preprocess_fn(example, idx): + """ + Preprocess function to convert GEO3K data to RLLM format. + + This follows the pattern from verl/examples/data_preprocess/geo3k.py + but adapts it for RLLM's simpler format requirements.
+ """ + problem = example["problem"] + answer = example["answer"] + images = _serialize_images(example.get("images", [])) + + # Create the full prompt with instruction following + prompt = problem + "\n\n" + instruction_following + "\nExample format:\n" + formatting_example + + # Return RLLM-compatible format + return { + "question": prompt, # RLLM expects 'question' field + "ground_truth": answer, # RLLM expects 'ground_truth' field + "data_source": "hiyouga/geometry3k", # Data source identifier matching verl's reward function + "images": images, # Serialized data URIs; reconstructed during training/inference + "extra_info": { + "original_problem": problem, + "answer": answer, + "has_images": len(images) > 0, + "num_images": len(images), + "formatting_example": formatting_example, + } + } + + print("šŸ”„ Preprocessing datasets...") + + # Apply preprocessing to both splits + train_dataset = train_dataset.map(preprocess_fn, with_indices=True) + test_dataset = test_dataset.map(preprocess_fn, with_indices=True) + + # Register datasets with RLLM DatasetRegistry + print("šŸ“‹ Registering datasets with RLLM...") + + train_dataset = DatasetRegistry.register_dataset("geo3k_train", train_dataset, "train") + test_dataset = DatasetRegistry.register_dataset("geo3k_test", test_dataset, "test") + + print("āœ… Datasets registered:") + print(" - geo3k_train (training data)") + print(" - geo3k_test (evaluation data)") + + return train_dataset, test_dataset + + +if __name__ == "__main__": + train_dataset, test_dataset = prepare_geo3k_data() + print(f"\nšŸ“Š Dataset Statistics:") + print(f" - Training samples: {len(train_dataset)}") + print(f" - Test samples: {len(test_dataset)}") + + # Show a sample to verify format + sample = train_dataset[0] + print(f"\nšŸ“‹ Sample data format:") + print(f" - Question length: {len(sample['question'])} chars") + print(f" - Has images: {len(sample['images']) > 0}") + print(f" - Number of images: {len(sample['images'])}") + print(f" - Ground truth: {sample['ground_truth'][:100]}...") + + print(f"\nšŸŽ‰ GEO3K dataset preparation complete!") + print(f"\nNext steps:") + print(f"1. Configure multimodal model in training config") + print(f"2. Run training: python train_geo3k_agent.py") + print(f"3. Test inference: python run_geo3k_agent.py") diff --git a/examples/geo3k/run_geo3k_agent.py b/examples/geo3k/run_geo3k_agent.py new file mode 100755 index 000000000..ae1f27b66 --- /dev/null +++ b/examples/geo3k/run_geo3k_agent.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +""" +Inference script for GEO3K multimodal geometry agent. + +This script tests the trained multimodal agent on GEO3K geometry problems +and evaluates its performance on problems with images. +""" + +import asyncio +import os + +from transformers import AutoTokenizer + +from rllm.agents import Geo3kAgent +from rllm.data.dataset import DatasetRegistry +from rllm.engine.agent_execution_engine import AgentExecutionEngine +from rllm.environments.base.single_turn_env import SingleTurnEnvironment +from rllm.rewards.reward_fn import math_reward_fn +from rllm.utils import compute_pass_at_k + + +def main(): + """ + Main inference function for testing GEO3K multimodal agent. 
+ """ + print("šŸš€ GEO3K Multimodal Agent Inference") + print("=" * 40) + + # Set environment variables + os.environ["TOKENIZERS_PARALLELISM"] = "true" + + # Configuration + n_parallel_agents = 16 # Adjust based on your GPU memory + model_name = "Qwen/Qwen2-VL-2B-Instruct" # Default multimodal model + + print(f"šŸ”§ Configuration:") + print(f" - Model: {model_name}") + print(f" - Parallel agents: {n_parallel_agents}") + + # Initialize tokenizer + print("šŸ“ Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + # Agent configuration for multimodal geometry reasoning + agent_args = { + "accumulate_thinking": True, + "include_images_in_completion": True, # Include image info in completions for debugging + } + + # Environment configuration + env_args = { + "reward_fn": math_reward_fn, + } + + # Sampling parameters for inference + sampling_params = { + "temperature": 0.7, # Slightly higher for reasoning diversity + "top_p": 0.95, + "max_tokens": 2048, + "model": model_name + } + + print("šŸ”§ Initializing agent execution engine...") + + # Initialize execution engine + # Note: You can switch between "openai" and "verl" engine + # For multimodal models, "verl" engine with SGLang backend is recommended + engine = AgentExecutionEngine( + agent_class=Geo3kAgent, + agent_args=agent_args, + env_class=SingleTurnEnvironment, + env_args=env_args, + engine_name="openai", # Can be "openai" or "verl" + rollout_engine_args={ + "base_url": "http://localhost:30000/v1", # SGLang server URL + "api_key": "None" + }, + tokenizer=tokenizer, + sampling_params=sampling_params, + max_response_length=2048, + max_prompt_length=2048, + n_parallel_agents=n_parallel_agents, + ) + + # Load test dataset + print("šŸ“Š Loading test dataset...") + try: + test_dataset = DatasetRegistry.load_dataset("geo3k_test", "test") + print(f"āœ… Test dataset loaded: {len(test_dataset)} samples") + except Exception as e: + print(f"āŒ Dataset not found: {e}") + print("šŸ”„ Preparing dataset...") + from prepare_geo3k_data import prepare_geo3k_data + + _, test_dataset = prepare_geo3k_data() + + test_data = test_dataset.get_data() + + # Take a smaller subset for quick testing + test_samples = 50 # Adjust as needed + subset_size = min(test_samples, len(test_data)) + test_subset = test_data[:subset_size] + print(f"šŸŽÆ Testing on {subset_size} samples") + + # Check multimodal content statistics + multimodal_count = sum(1 for sample in test_subset if sample.get("images")) + print(f"šŸ“ø Multimodal samples: {multimodal_count}/{subset_size}") + + # Repeat samples for pass@k evaluation + n_repeats = 4 # Number of attempts per problem + tasks = [sample for sample in test_subset for _ in range(n_repeats)] + print(f"šŸ”„ Total tasks (with repeats): {len(tasks)}") + + # Show sample problem + if not test_subset: + print("āš ļø No samples found in test dataset. 
Exiting.") + return + + sample = test_subset[0] + print(f"\nšŸ“‹ Sample Problem:") + question_preview = sample.get("question", "")[:200] + ground_truth_preview = str(sample.get("ground_truth", ""))[:100] + has_images = bool(sample.get("images")) + print(f" Question: {question_preview}...") + print(f" Ground Truth: {ground_truth_preview}...") + print(f" Has Images: {has_images}") + + # Run inference + print("\nšŸš€ Starting inference...") + print("ā³ This may take a while depending on the number of samples and model speed...") + + try: + results = asyncio.run(engine.execute_tasks(tasks)) + print(f"\nāœ… Inference completed!") + print(f"šŸ“Š Results: {len(results)} task completions") + + # Compute and display pass@k metrics + print("\nšŸ“ˆ Computing Pass@K metrics...") + pass_at_k_results = compute_pass_at_k(results) + + # Display results + print(f"\nšŸŽÆ Performance Summary:") + for k, score in pass_at_k_results.items(): + print(f" - Pass@{k}: {score:.3f}") + + # Show some example results + print(f"\nšŸ“ Example Results:") + for i, result in enumerate(results[:3]): # Show first 3 results + reward = result.get('reward', 0) + success = "āœ…" if reward > 0.5 else "āŒ" + print(f" {success} Task {i+1}: Reward = {reward:.3f}") + + except Exception as e: + print(f"āŒ Inference failed: {e}") + print("šŸ’” Make sure your model server is running:") + print(" SGLang: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-2B-Instruct") + print(" vLLM: python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2-VL-2B-Instruct") + raise + + print(f"\nšŸŽ‰ GEO3K inference completed!") + print(f"šŸ’” For better performance, consider:") + print(f" - Using a larger multimodal model (e.g., Qwen2-VL-7B)") + print(f" - Fine-tuning on GEO3K training data") + print(f" - Adjusting sampling parameters") + + +if __name__ == "__main__": + main() diff --git a/examples/geo3k/train_geo3k_agent.py b/examples/geo3k/train_geo3k_agent.py new file mode 100644 index 000000000..54f9103d2 --- /dev/null +++ b/examples/geo3k/train_geo3k_agent.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +""" +Training script for GEO3K multimodal geometry agent using RLLM framework. + +This script follows RLLM patterns and leverages Verl's multimodal capabilities +for training on geometry problems with images. +""" + +import hydra +from omegaconf import OmegaConf, open_dict + +from rllm.agents import Geo3kAgent +from rllm.data.dataset import DatasetRegistry +from rllm.environments.base.single_turn_env import SingleTurnEnvironment +from rllm.rewards.reward_fn import math_reward_fn +from rllm.trainer.agent_trainer import AgentTrainer + + +@hydra.main(config_path="pkg://rllm.trainer.config", config_name="geo3k_multimodal_trainer", version_base=None) +def main(config): + """ + Main training function for GEO3K multimodal agent. + + This follows the same pattern as other RLLM examples while enabling + multimodal training through Verl's capabilities. 
+ """ + # Load datasets (must be prepared first using prepare_geo3k_data.py) + train_dataset = DatasetRegistry.load_dataset("geo3k_train", "train") + test_dataset = DatasetRegistry.load_dataset("geo3k_test", "test") + + # Ensure configuration matches Verl GEO3K example defaults + if hasattr(config, "actor_rollout_ref") and hasattr(config.actor_rollout_ref, "model"): + with open_dict(config.actor_rollout_ref.model): + config.actor_rollout_ref.model.path = "Qwen/Qwen2.5-VL-7B-Instruct" + + if not hasattr(config, "multimodal") or config.multimodal is None: + with open_dict(config): + config.multimodal = OmegaConf.create({"enable": True, "image_key": "images"}) + else: + with open_dict(config.multimodal): + config.multimodal.enable = True + config.multimodal.image_key = "images" + + rejection_multiplier = 1 + if hasattr(config, "rllm") and hasattr(config.rllm, "rejection_sample") and config.rllm.rejection_sample is not None: + rejection_multiplier = getattr(config.rllm.rejection_sample, "multiplier", 1) or 1 + + if hasattr(config, "data"): + with open_dict(config.data): + config.data.image_key = "images" + config.data.return_multi_modal_inputs = True + config.data.train_batch_size = 32 + config.data.val_batch_size = 32 + config.data.gen_batch_size = 32 * max(1, rejection_multiplier) + + # Ensure PPO mini-batch sizes do not exceed data batch sizes + train_batch_size = getattr(config.data, "train_batch_size", None) + val_batch_size = getattr(config.data, "val_batch_size", None) + + if hasattr(config, "actor_rollout_ref") and hasattr(config.actor_rollout_ref, "actor"): + with open_dict(config.actor_rollout_ref.actor): + mini_batch = getattr(config.actor_rollout_ref.actor, "ppo_mini_batch_size", None) + if train_batch_size is not None and (mini_batch is None or mini_batch > train_batch_size): + config.actor_rollout_ref.actor.ppo_mini_batch_size = train_batch_size + + if hasattr(config, "critic"): + with open_dict(config.critic): + mini_batch = getattr(config.critic, "ppo_mini_batch_size", None) + if val_batch_size is not None and (mini_batch is None or mini_batch > val_batch_size): + config.critic.ppo_mini_batch_size = val_batch_size + + # Disable workflow mode for single-turn geo3k training + if hasattr(config, "rllm") and hasattr(config.rllm, "workflow"): + with open_dict(config.rllm.workflow): + config.rllm.workflow.use_workflow = False + + # Reduce rollout parallelism to avoid GPU OOM + if hasattr(config, "actor_rollout_ref") and hasattr(config.actor_rollout_ref, "rollout"): + with open_dict(config.actor_rollout_ref.rollout): + config.actor_rollout_ref.rollout.n = 1 + config.actor_rollout_ref.rollout.gpu_memory_utilization = min(0.6, config.actor_rollout_ref.rollout.gpu_memory_utilization) + config.actor_rollout_ref.rollout.max_num_batched_tokens = 4096 + config.actor_rollout_ref.rollout.max_num_seqs = 128 + + if hasattr(config.actor_rollout_ref.rollout, "agent"): + with open_dict(config.actor_rollout_ref.rollout.agent): + current_workers = getattr(config.actor_rollout_ref.rollout.agent, "num_workers", 0) or 0 + config.actor_rollout_ref.rollout.agent.num_workers = max(1, current_workers) + + # Agent configuration - minimal args following RLLM patterns + agent_args = { + "accumulate_thinking": True, + "include_images_in_completion": True, # Enable image info in completions for multimodal training + } + + # Environment configuration - simple single-turn math environment + env_args = { + "reward_fn": math_reward_fn, + } + + # Initialize trainer following RLLM patterns + trainer = AgentTrainer( + 
agent_class=Geo3kAgent, + env_class=SingleTurnEnvironment, + agent_args=agent_args, + env_args=env_args, + config=config, + train_dataset=train_dataset, + val_dataset=test_dataset, + ) + + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/examples/geo3k/train_geo3k_agent.sh b/examples/geo3k/train_geo3k_agent.sh new file mode 100755 index 000000000..3b5e2bb0d --- /dev/null +++ b/examples/geo3k/train_geo3k_agent.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# GEO3K Multimodal Agent Training Script +# This script follows RLLM patterns while enabling multimodal training through Verl + +set -x + +# Environment setup for multimodal training +export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False" +export TOKENIZERS_PARALLELISM=false + +# Use SGLang for multimodal support (following Verl's GEO3K example) +# Uncomment these for vLLM if preferred +# export VLLM_ATTENTION_BACKEND=FLASH_ATTN +# export VLLM_USE_V1=1 +# export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 +# export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000 + +# Find the directory where rllm package is located +RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))") + +# Train GEO3K agent with multimodal support +# Configuration based on Verl's geo3k_multiturn examples and RLLM patterns +python3 -m examples.geo3k.train_geo3k_agent \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=64 \ + data.val_batch_size=64 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.return_raw_chat=True \ + data.return_multi_modal_inputs=True \ + data.trust_remote_code=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.hybrid_engine=True \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.loss_agg_mode=token-mean \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_dynamic_bsz=False \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=16384 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.clip_ratio=0.2 \ + actor_rollout_ref.actor.entropy_coeff=0.01 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.dtype=bfloat16 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=1.0 \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.temperature=0.7 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + critic.ppo_micro_batch_size_per_gpu=2 \ + rllm.agent.name=geo3k_agent \ + rllm.agent.max_steps=1 \ + rllm.env.name=math \ + rllm.mask_truncated_samples=False \ + rllm.stepwise_advantage.enable=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + 
trainer.project_name='rllm-geo3k-multimodal' \ + trainer.experiment_name='geo3k-multimodal-qwen2vl-2b' \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=100 \ + trainer.test_freq=50 \ + trainer.total_epochs=10 \ + trainer.default_hdfs_dir=null \ No newline at end of file diff --git a/rllm/agents/__init__.py b/rllm/agents/__init__.py index 43924560d..d4e9f007e 100644 --- a/rllm/agents/__init__.py +++ b/rllm/agents/__init__.py @@ -15,6 +15,7 @@ def safe_import(module_path, class_name): # Define all agent imports AGENT_IMPORTS = [ + ("rllm.agents.geo3k_agent", "Geo3kAgent"), ("rllm.agents.miniwob_agent", "MiniWobAgent"), ("rllm.agents.frozenlake_agent", "FrozenLakeAgent"), ("rllm.agents.swe_agent", "SWEAgent"), diff --git a/rllm/agents/geo3k_agent.py b/rllm/agents/geo3k_agent.py new file mode 100644 index 000000000..628ae4a6c --- /dev/null +++ b/rllm/agents/geo3k_agent.py @@ -0,0 +1,245 @@ +"""Geometry3k Agent with multimodal support. + +This agent is designed to solve geometry problems that include both text and images, +leveraging RLLM's multimodal capabilities built on top of Verl. +""" + +from typing import Any, List, Union, Dict +from PIL import Image + +from rllm.agents.agent import Action, BaseAgent, Step, Trajectory +from rllm.data.multimodal import MultimodalMessage, as_pil_image, create_multimodal_conversation + + +class Geo3kAgent(BaseAgent): + """ + A geometry agent that solves mathematical problems with multimodal inputs (text + images). + + This agent extends the standard MathAgent pattern to handle images that are common + in geometry problems, while maintaining full compatibility with RLLM's training infrastructure. + """ + + def __init__(self, accumulate_thinking=True, include_images_in_completion=False): + """ + Initialize the Geo3kAgent. 
+ + Args: + accumulate_thinking: Whether to accumulate thinking across turns + include_images_in_completion: Whether to include image info in chat_completions + """ + self._trajectory = Trajectory() + self.messages = [] # Store MultimodalMessage objects + self.accumulate_thinking = accumulate_thinking + self.include_images_in_completion = include_images_in_completion + self._current_images = [] # Store images for current problem + + def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs): + """Process environment feedback and update internal state with multimodal support.""" + + # If observation is None, this is a reward update for the existing step + if observation is None: + if self.trajectory.steps: + cur_step = self.get_current_state() + cur_step.reward = reward + cur_step.done = done + cur_step.info = info + return + + # Extract multimodal data from observation + images = [] + if isinstance(observation, dict): + formatted_observation = observation.get("question", str(observation)) + raw_images = observation.get("images", []) + print(f"DEBUG: Observation dict has images: {len(raw_images) if raw_images else 0}") + elif isinstance(observation, str): + formatted_observation = observation + # Check if kwargs contains images + raw_images = kwargs.get("images", []) + print(f"DEBUG: String observation, kwargs has images: {len(raw_images) if raw_images else 0}") + else: + raise ValueError(f"Invalid observation type: {type(observation)}") + + print(f"DEBUG: Raw images count: {len(raw_images) if raw_images else 0}") + if raw_images: + print(f"DEBUG: First raw image type: {type(raw_images[0])}") + if isinstance(raw_images[0], dict): + print(f"DEBUG: First raw image dict keys: {raw_images[0].keys()}") + + decoded_images: list[Image.Image] = [] + # Handle numpy arrays properly + images_to_process = [] + if raw_images is not None: + if hasattr(raw_images, '__len__') and len(raw_images) > 0: + images_to_process = raw_images + + for i, image in enumerate(images_to_process): + pil_image = as_pil_image(image) + print(f"DEBUG: Image {i} decoded successfully: {pil_image is not None}") + if pil_image is None and isinstance(image, str): + try: + pil_image = Image.open(image).convert("RGB") + print(f"DEBUG: Image {i} decoded from file path") + except Exception as e: + print(f"DEBUG: Image {i} failed file decode: {e}") + pil_image = None + if pil_image is not None: + decoded_images.append(pil_image) + print(f"DEBUG: Image {i} added to decoded list, size: {pil_image.size}") + + images = decoded_images + self._current_images = images + print(f"DEBUG: Final decoded images count: {len(images)}") + + # Create multimodal message + multimodal_message = MultimodalMessage( + role="user", + text=formatted_observation, + images=images if images else None + ) + self.messages.append(multimodal_message) + + # Create new step + new_step = Step(observation=formatted_observation) + # Store multimodal info in step's info dict for compatibility + if images: + new_step.info = {"images": images} + + self._trajectory.steps.append(new_step) + + def update_from_model(self, response: str, **kwargs) -> Action: + """ + Updates the agent's internal state based on the model's response. 
+ """ + + # Create assistant message (models typically don't generate images in geo3k) + assistant_message = MultimodalMessage( + role="assistant", + text=response + ) + self.messages.append(assistant_message) + + # Update the latest step + cur_step = self.get_current_state() + cur_step.chat_completions = self.chat_completions + cur_step.model_response = response + + # Parse thinking and action (same as MathAgent) + if response.count("") == 1: + thought, sep, action = response.partition("") + thought = thought + sep + action = Action(action.strip()) + else: + thought = None + action = Action(response.strip()) + + cur_step.thought = thought + cur_step.action = action + + return action + + def reset(self) -> None: + """Reset agent state for new episode (wipes trajectory and messages).""" + self._trajectory = Trajectory() + self.messages = [] + self._current_images = [] + + @property + def chat_completions(self) -> List[Dict[str, str]]: + """Return conversation history for model interaction. + + For multimodal agents, this returns the Verl-compatible format. + The actual multimodal data is provided via get_multimodal_data(). + """ + # Convert multimodal messages to Verl format + conversation = self.get_multimodal_conversation() + + # For backward compatibility, also provide text-only format + completions = [] + for msg in conversation: + completion = {"role": msg["role"]} + + if isinstance(msg.get("content"), list): + # Extract text parts from multimodal content + text_parts = [] + for item in msg["content"]: + if item.get("type") == "text": + text_parts.append(item["text"]) + elif item.get("type") == "image" and self.include_images_in_completion: + text_parts.append(f"[Images: {len(item.get('image', []))} geometry diagrams]") + completion["content"] = " ".join(text_parts) if text_parts else "" + else: + completion["content"] = msg.get("content", "") + + completions.append(completion) + + # Apply thinking accumulation logic (same as MathAgent) + if not self.accumulate_thinking: + for msg in completions[:-1]: + if msg["role"] == "assistant": + _, sep, after = msg["content"].partition("") + if sep: + msg["content"] = after + + return completions + + @property + def trajectory(self) -> Trajectory: + """Return complete interaction trajectory.""" + return self._trajectory + + def get_current_state(self) -> Step: + """Returns the current step/state of the agent.""" + assert self._trajectory.steps, "Trajectory should not be empty when get_current_state is called." 
+ return self._trajectory.steps[-1] + + # Multimodal-specific methods + + def get_multimodal_conversation(self) -> List[Dict[str, Any]]: + """Get the conversation in Verl's multimodal format.""" + return create_multimodal_conversation(self.messages) + + def get_multimodal_data(self) -> Dict[str, List[Any]]: + """Extract all multimodal data for Verl processing.""" + # Collect all images from current trajectory + all_images = [] + all_videos = [] + + # Add images from current observation if any + if self._current_images: + all_images.extend(self._current_images) + + # Add images from all messages in conversation + for msg in self.messages: + if msg.images: + all_images.extend(msg.images) + if msg.videos: + all_videos.extend(msg.videos) + + print(f"DEBUG: get_multimodal_data - total images: {len(all_images)}, total videos: {len(all_videos)}") + if all_images: + print(f"DEBUG: get_multimodal_data - first image type: {type(all_images[0])}") + + # Return in Verl's expected format + return { + "image": all_images, + "video": all_videos + } + + def get_current_images(self) -> List[Union[str, Dict, Image.Image]]: + """Get images from the current problem.""" + return self._current_images + + def has_multimodal_content(self) -> bool: + """Check if the current trajectory contains any multimodal content.""" + return any(msg.images for msg in self.messages if msg.images) + + def get_multimodal_summary(self) -> Dict[str, Any]: + """Get a summary of multimodal content in the trajectory.""" + total_images = sum(len(msg.images) for msg in self.messages if msg.images) + return { + "total_images": total_images, + "has_multimodal": self.has_multimodal_content(), + "current_images": len(self._current_images), + "steps_with_images": sum(1 for step in self._trajectory.steps + if step.info and "images" in step.info) + } diff --git a/rllm/data/__init__.py b/rllm/data/__init__.py index 6a8f1be83..415d75090 100644 --- a/rllm/data/__init__.py +++ b/rllm/data/__init__.py @@ -1,6 +1,15 @@ from rllm.data.dataset import Dataset, DatasetRegistry from rllm.data.dataset_types import Dataset as DatasetEnum from rllm.data.dataset_types import DatasetConfig, Problem, TestDataset, TrainDataset +from rllm.data.multimodal import ( + MultimodalMessage, + create_multimodal_conversation, + extract_multimodal_data, + create_text_message, + create_image_message, + create_video_message, + create_multimodal_message, +) __all__ = [ "TrainDataset", @@ -10,4 +19,12 @@ "DatasetRegistry", "Problem", "DatasetConfig", + # Multimodal support + "MultimodalMessage", + "create_multimodal_conversation", + "extract_multimodal_data", + "create_text_message", + "create_image_message", + "create_video_message", + "create_multimodal_message", ] diff --git a/rllm/data/multimodal.py b/rllm/data/multimodal.py new file mode 100644 index 000000000..6df738aea --- /dev/null +++ b/rllm/data/multimodal.py @@ -0,0 +1,220 @@ +"""Multimodal data utilities for RLLM.""" + +import base64 +from dataclasses import dataclass +from io import BytesIO +from typing import Any, Dict, List, Optional, Union + +from PIL import Image + +# Import Verl's vision utilities - these are the core multimodal processing functions +try: + from verl.utils.dataset.vision_utils import process_image, process_video + from verl.workers.rollout.schemas import AsyncRolloutRequest, Message + VERL_MULTIMODAL_AVAILABLE = True +except ImportError: + VERL_MULTIMODAL_AVAILABLE = False + # Fallback stubs for type hints + class AsyncRolloutRequest: + pass + class Message: + pass + + def 
process_image(image): # type: ignore[override] + return image + + def process_video(video): # type: ignore[override] + return video + + +DATA_URI_PREFIX = "data:image/" + + +def as_pil_image(image: Any) -> Image.Image | None: + """Convert supported payloads (PIL image, data URI, byte dict) to a PIL image.""" + + if hasattr(image, "mode") and hasattr(image, "size"): + return image # Already a PIL image + + if isinstance(image, str) and image.startswith(DATA_URI_PREFIX): + try: + header, encoded = image.split(",", 1) + image_bytes = base64.b64decode(encoded) + return Image.open(BytesIO(image_bytes)).convert("RGB") + except Exception: + return None + + if isinstance(image, dict): + if "bytes" in image and image["bytes"] is not None: + try: + return Image.open(BytesIO(image["bytes"])).convert("RGB") + except Exception: + return None + + # HuggingFace datasets often store images as dicts with "path" pointing to a + # file or a data URI. Handle both cases, falling back to disk loading. + data_str: str | None = None + if "data" in image and isinstance(image["data"], str): + data_str = image["data"] + elif "path" in image and isinstance(image["path"], str): + data_str = image["path"] + + if data_str: + if data_str.startswith(DATA_URI_PREFIX): + try: + _, encoded = data_str.split(",", 1) + image_bytes = base64.b64decode(encoded) + return Image.open(BytesIO(image_bytes)).convert("RGB") + except Exception: + return None + else: + try: + return Image.open(data_str).convert("RGB") + except Exception: + return None + + return None + + +def ensure_multimodal_available(): + """Ensure Verl multimodal capabilities are available.""" + if not VERL_MULTIMODAL_AVAILABLE: + raise ImportError( + "Verl multimodal support not available. Please ensure Verl is properly installed." + ) + + +@dataclass +class MultimodalMessage: + """A message that can contain text, images, and videos. + + This is a lightweight wrapper around Verl's Message format for easier use in RLLM. + """ + + role: str # "user", "assistant", or "system" + text: Optional[str] = None + images: Optional[List[Union[str, Dict, Image.Image]]] = None + videos: Optional[List[Dict]] = None + + def to_verl_message(self) -> Dict[str, Any]: + """Convert to Verl's Message format.""" + ensure_multimodal_available() + + content = [] + + # Add text content + if self.text: + content.append({"type": "text", "text": self.text}) + + # Add image content - process through Verl's utilities + if self.images: + processed_images = [] + for image in self.images: + pil_image = as_pil_image(image) + if pil_image is not None: + processed_images.append(pil_image.convert("RGB")) + else: + processed_images.append(process_image(image)) + content.append({"type": "image", "image": processed_images}) + + # Add video content - process through Verl's utilities + if self.videos: + processed_videos = [] + for video in self.videos: + processed_video = process_video(video) + processed_videos.append(processed_video) + content.append({"type": "video", "video": processed_videos}) + + return { + "role": self.role, + "content": content if content else self.text or "" + } + + +def create_multimodal_conversation(messages: List[MultimodalMessage]) -> List[Dict[str, Any]]: + """Create a conversation from MultimodalMessage objects. 
+ + Args: + messages: List of MultimodalMessage objects + + Returns: + List of messages in Verl format, ready for training + """ + return [msg.to_verl_message() for msg in messages] + + +def extract_multimodal_data(messages: List[Dict[str, Any]]) -> Dict[str, List[Any]]: + """Extract multimodal data from messages for Verl processing. + + This function extracts all images and videos from a conversation, + which is the format expected by Verl's AsyncRolloutRequest. + + Args: + messages: List of messages in Verl format + + Returns: + Dictionary with 'image' and 'video' keys containing processed media + """ + ensure_multimodal_available() + + all_images = [] + all_videos = [] + + for message in messages: + content = message.get("content", []) + if isinstance(content, str): + continue + + for item in content: + if item.get("type") == "image": + images = item.get("image", []) + if isinstance(images, list): + all_images.extend(images) + else: + all_images.append(images) + elif item.get("type") == "video": + videos = item.get("video", []) + if isinstance(videos, list): + all_videos.extend(videos) + else: + all_videos.append(videos) + + return { + "image": all_images, + "video": all_videos + } + + +# Convenience functions for common use cases + +def create_text_message(role: str, text: str) -> MultimodalMessage: + """Create a text-only message.""" + return MultimodalMessage(role=role, text=text) + + +def create_image_message( + role: str, + text: str, + images: List[Union[str, Dict, Image.Image]] +) -> MultimodalMessage: + """Create a message with text and images.""" + return MultimodalMessage(role=role, text=text, images=images) + + +def create_video_message( + role: str, + text: str, + videos: List[Dict] +) -> MultimodalMessage: + """Create a message with text and videos.""" + return MultimodalMessage(role=role, text=text, videos=videos) + + +def create_multimodal_message( + role: str, + text: str, + images: Optional[List[Union[str, Dict, Image.Image]]] = None, + videos: Optional[List[Dict]] = None +) -> MultimodalMessage: + """Create a message with text, images, and/or videos.""" + return MultimodalMessage(role=role, text=text, images=images, videos=videos) diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index c4982a6de..5a26b433f 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@ -40,6 +40,7 @@ def __init__( max_response_length=8192, max_prompt_length=1024, config=None, + processor=None, agent_class=None, env_class=None, agent_args=None, @@ -60,6 +61,7 @@ def __init__( self.config = config self.tokenizer = tokenizer self.engine_name = engine_name + self.processor = processor self.n_parallel_agents = n_parallel_agents self.overlong_filter = overlong_filter @@ -111,13 +113,22 @@ def __init__( config=self.config, rollout_manager=rollout_engine, tokenizer=self.tokenizer, + processor=self.processor, disable_thinking=self.config.rllm.disable_thinking, ) # Create a thread pool executor for environment interactions (i.e. step, reset, close) self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) - async def get_model_response(self, prompt, application_id, **kwargs) -> str: + async def get_model_response( + self, + messages, + application_id, + *, + multimodal_messages=None, + multimodal_data=None, + **kwargs, + ) -> str: """ Compute model response asynchronously based on the engine type. 
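Note: the helpers added in rllm/data/multimodal.py above are what Geo3kAgent and the execution-engine changes below exchange. A minimal usage sketch, not part of the patch, assuming Verl is installed (MultimodalMessage.to_verl_message() and extract_multimodal_data() go through ensure_multimodal_available()) and using Image.new() as a hypothetical stand-in for a geometry diagram:

    from PIL import Image

    from rllm.data.multimodal import (
        MultimodalMessage,
        create_multimodal_conversation,
        extract_multimodal_data,
    )

    diagram = Image.new("RGB", (64, 64))  # stand-in for a dataset image

    messages = [
        MultimodalMessage(role="user", text="Find the measure of angle x.", images=[diagram]),
        MultimodalMessage(role="assistant", text="<think>...</think>\nFinal answer: \\boxed{42}"),
    ]

    # Verl-style conversation: each message's content becomes a list of typed items.
    conversation = create_multimodal_conversation(messages)

    # Flattened media payload in the {"image": [...], "video": [...]} shape that the
    # engine changes below forward as multimodal_data.
    media = extract_multimodal_data(conversation)
    assert len(media["image"]) == 1 and media["video"] == []

Geo3kAgent.get_multimodal_conversation() and get_multimodal_data() produce these same two structures from the agent's stored messages, which is what run_agent_trajectory_async passes into get_model_response in the hunks below.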
@@ -140,12 +151,20 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str: sampling_params.update(kwargs) if self.engine_name == "openai": - output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, **sampling_params) + request_messages = messages + output = await self.rollout_engine.get_model_response(request_messages, application_id=application_id, **sampling_params) return output.text elif self.engine_name == "verl": + request_messages = multimodal_messages or messages meta_data = sampling_params.pop("meta_info", {}) validate = meta_data.get("validate", False) - output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, **sampling_params) + output = await self.rollout_engine.get_model_response( + request_messages, + application_id=application_id, + validate=validate, + multimodal_data=multimodal_data, + **sampling_params, + ) return output.text else: raise NotImplementedError(f"Engine type '{self.engine_name}' not supported") @@ -195,11 +214,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te # Reset agent agent.reset() # Update agent internal state from environment. + # Extract multimodal data from observation and pass to agent + agent_kwargs = {} + if isinstance(observation, dict) and "images" in observation: + agent_kwargs["images"] = observation["images"] + agent.update_from_env( observation=observation, # Raw observation from environment reward=0.0, done=False, info=info, + **agent_kwargs, ) messages = agent.chat_completions prompt_tokens, _ = convert_messages_to_tokens_and_masks(messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=True, contains_generation_msg=True) @@ -212,6 +237,21 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te for step_idx in range(self.max_steps): # Get action from agent prompt_messages = agent.chat_completions.copy() + + multimodal_messages = None + multimodal_data = None + if hasattr(agent, "get_multimodal_data"): + try: + multimodal_data = agent.get_multimodal_data() + has_multimodal = isinstance(multimodal_data, dict) and ( + multimodal_data.get("image") or multimodal_data.get("video") + ) + if has_multimodal and hasattr(agent, "get_multimodal_conversation"): + multimodal_messages = agent.get_multimodal_conversation() + except Exception as exc: + colorful_print(f"Warning: failed to collect multimodal data: {exc}", "yellow") + multimodal_messages = None + multimodal_data = None # Max remaining tokens left for the response # For enforced max prompt at each step, no need to deduct here if not self.enforce_max_prompt_length: @@ -229,7 +269,13 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te kwargs["max_tokens"] = max_tokens start_time = time.time() - response = await self.get_model_response(prompt_messages, application_id, **kwargs) + response = await self.get_model_response( + prompt_messages, + application_id, + multimodal_messages=multimodal_messages, + multimodal_data=multimodal_data, + **kwargs, + ) delta_time = time.time() - start_time llm_time += delta_time total_time += delta_time @@ -267,11 +313,17 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te info["cur_tokens"] = response_token_len # Update agent internal state. 
+ # Extract multimodal data from next observation and pass to agent + next_agent_kwargs = {} + if isinstance(next_observation, dict) and "images" in next_observation: + next_agent_kwargs["images"] = next_observation["images"] + agent.update_from_env( observation=next_observation, reward=reward, done=done, info=info, + **next_agent_kwargs, ) cur_step = agent.get_current_state() @@ -383,6 +435,9 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te "trajectory_reward": trajectory.reward, "idx": env.idx, "chat_completions": agent.chat_completions, + "prompt_messages": prompt_messages, + "multimodal_messages": multimodal_messages, + "multi_modal_data": multimodal_data, "metrics": { # Total number of steps taken in the trajectory "steps": len(trajectory.steps), diff --git a/rllm/engine/rollout/verl_engine.py b/rllm/engine/rollout/verl_engine.py index cc1476e66..f7a318eb9 100644 --- a/rllm/engine/rollout/verl_engine.py +++ b/rllm/engine/rollout/verl_engine.py @@ -1,4 +1,8 @@ import uuid +from datetime import datetime +from pathlib import Path + +import numpy as np from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine from rllm.parser import ChatTemplateParser, ToolParser @@ -6,11 +10,12 @@ class VerlEngine(RolloutEngine): - def __init__(self, config, rollout_manager, tokenizer, **kwargs): + def __init__(self, config, rollout_manager, tokenizer, processor=None, **kwargs): self.config = config self.rollout_manager = rollout_manager self.server_manager = AsyncLLMServerManager(config, rollout_manager.async_llm_servers) self.tokenizer = tokenizer + self.processor = processor # Store processor for multimodal processing self.chat_parser = ChatTemplateParser.get_parser(self.tokenizer, disable_thinking=kwargs.get("disable_thinking", False)) try: @@ -21,10 +26,13 @@ def __init__(self, config, rollout_manager, tokenizer, **kwargs): self.validate = False - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + async def get_model_response(self, messages: list[dict], multimodal_messages: list[dict] | None = None, **kwargs) -> ModelOutput: application_id = kwargs.pop("application_id", str(uuid.uuid4())) validate = self.validate or kwargs.pop("validate", False) + # Extract multimodal data if present + multimodal_data = kwargs.pop("multimodal_data", None) + if validate: sampling_params = dict( temperature=0.0 if self.config.actor_rollout_ref.rollout.val_kwargs.do_sample is False else self.config.actor_rollout_ref.rollout.val_kwargs.temperature, @@ -41,10 +49,164 @@ async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutpu max_tokens = sampling_params.pop("max_tokens", self.config.data.max_response_length) - prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True) - prompt_ids = self.tokenizer.encode(prompt) + base_messages = multimodal_messages if multimodal_messages is not None else messages + + def ensure_text_messages(messages_list: list[dict]) -> list[dict]: + text_only = [] + for msg in messages_list: + new_msg = dict(msg) + content = new_msg.get("content") + if isinstance(content, list): + text_parts = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + text_parts.append(item.get("text", "")) + new_msg["content"] = " ".join(part for part in text_parts if part).strip() + text_only.append(new_msg) + return text_only + + model_inputs = None + images = None + videos = None + + if self.processor is not None: + raw_prompt = 
self.processor.apply_chat_template( + base_messages, + add_generation_prompt=True, + tokenize=False, + ) + + images = multimodal_data.get("image") if isinstance(multimodal_data, dict) else None + if isinstance(images, (list, tuple)) and len(images) == 0: + images = None + videos = multimodal_data.get("video") if isinstance(multimodal_data, dict) else None + if isinstance(videos, (list, tuple)) and len(videos) == 0: + videos = None + + model_inputs = self.processor( + text=[raw_prompt], + images=images, + videos=videos, + return_tensors="pt", + ) + + prompt_ids = model_inputs["input_ids"][0].tolist() + else: + text_messages = ensure_text_messages(base_messages) + raw_prompt = self.chat_parser.parse(text_messages, add_generation_prompt=True, is_first_msg=True) + prompt_ids = self.tokenizer.encode(raw_prompt) - response_ids = await self.server_manager.generate(request_id=application_id, prompt_ids=prompt_ids, sampling_params=sampling_params) + try: + prompt_tokens = self.tokenizer.convert_ids_to_tokens(prompt_ids) + except Exception: # noqa: BLE001 + prompt_tokens = None + + # Debug logging disabled; uncomment for detailed prompt inspection. + # print("\n=== FINAL PROMPT TO MODEL ===") + # print(raw_prompt) + # print(f"Prompt ids ({len(prompt_ids)}): {prompt_ids}") + # if prompt_tokens is not None: + # print(f"Prompt tokens: {prompt_tokens}") + # if model_inputs is not None: + # pixel_values = model_inputs.get("pixel_values") + # if pixel_values is not None: + # try: + # print(f"pixel_values shape: {tuple(pixel_values.shape)}") + # except Exception: # noqa: BLE001 + # print(f"pixel_values type: {type(pixel_values)}") + # image_grid = model_inputs.get("image_grid_thw") + # if image_grid is not None: + # try: + # grid_repr = image_grid.tolist() if hasattr(image_grid, "tolist") else image_grid + # except Exception: # noqa: BLE001 + # grid_repr = image_grid + # print(f"image_grid_thw: {grid_repr}") + # if images: + # payload = images if isinstance(images, (list, tuple)) else [images] + # print(f"image payload count: {len(payload)}") + # for idx, image in enumerate(payload): + # size = getattr(image, "size", None) + # print(f"image[{idx}] type={type(image)} size={size}") + # if videos: + # payload_v = videos if isinstance(videos, (list, tuple)) else [videos] + # print(f"video payload count: {len(payload_v)}") + # print("=== END MODEL PROMPT ===\n") + + if multimodal_data and (multimodal_data.get("image") or multimodal_data.get("video")): + debug_logged = getattr(self, "_geo3k_multimodal_debug", 0) + if debug_logged < 10: + debug_path = Path("outputs/debug_geo3k_multimodal.log") + debug_path.parent.mkdir(parents=True, exist_ok=True) + with debug_path.open("a", encoding="utf-8") as fp: + fp.write("\n" + "=" * 80 + "\n") + fp.write(f"{datetime.now().isoformat()} | multimodal request\n") + fp.write(f"Messages: {base_messages}\n") + fp.write(f"Has images: {len(multimodal_data.get('image', []))}\n") + if self.processor is None: + fp.write("WARNING: No multimodal processor available; falling back to text-only tokenization.\n") + else: + fp.write("Multimodal processor detected; logging tokenizer inputs.\n") + fp.write(f"Raw prompt string: {raw_prompt}\n") + fp.write(f"Prompt id count: {len(prompt_ids)}\n") + fp.write(f"Prompt ids: {prompt_ids}\n") + try: + token_strings = self.tokenizer.convert_ids_to_tokens(prompt_ids) + fp.write(f"Prompt tokens: {token_strings}\n") + except Exception as exc: # noqa: BLE001 + fp.write(f"Failed to convert prompt ids to tokens: {exc}\n") + + model_input_keys = 
sorted(list(model_inputs.keys())) if "model_inputs" in locals() else [] + fp.write(f"Model input keys: {model_input_keys}\n") + + if images: + payload = images if isinstance(images, (list, tuple)) else [images] + fp.write(f"Image payload types: {[type(img).__name__ for img in payload]}\n") + for idx, image in enumerate(payload): + size = getattr(image, "size", None) + fp.write(f"Image[{idx}] info: type={type(image)}, size={size}\n") + self._geo3k_multimodal_debug = debug_logged + 1 + + # Use Verl's generate_sequences for multimodal data + from verl import DataProto + + # Create batch with multimodal data in Verl format + non_tensor_batch = { + "raw_prompt_ids": np.array([prompt_ids], dtype=object), + "raw_prompt": np.array([np.array(base_messages, dtype=object)], dtype=object), + "raw_prompt_text": np.array([raw_prompt], dtype=object), + "multi_modal_data": np.array([multimodal_data], dtype=object), + } + + batch = DataProto.from_dict(non_tensors=non_tensor_batch) + + # Update sampling params in rollout manager config temporarily + original_config = {} + for key, value in sampling_params.items(): + if hasattr(self.rollout_manager.config.actor_rollout_ref.rollout, key): + original_config[key] = getattr(self.rollout_manager.config.actor_rollout_ref.rollout, key) + setattr(self.rollout_manager.config.actor_rollout_ref.rollout, key, value) + + try: + # Use rollout_manager's generate_sequences for multimodal + output_batch = self.rollout_manager.generate_sequences(batch) + + # Extract response_ids from output + if hasattr(output_batch, 'batch') and 'responses' in output_batch.batch: + response_ids = output_batch.batch['responses'][0].tolist() + else: + raise RuntimeError("Failed to get response from multimodal generation") + + finally: + # Restore original config + for key, value in original_config.items(): + setattr(self.rollout_manager.config.actor_rollout_ref.rollout, key, value) + else: + # Original text-only path + response_ids = await self.server_manager.generate( + request_id=application_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, + ) # verl sets max_tokens as max_model_len - len(prompt_ids), where max_model_len is config.data.max_prompt_length + config.data.max_response_length # so we truncate the response to max_tokens if it exceeds max_tokens diff --git a/rllm/environments/base/multi_turn_env.py b/rllm/environments/base/multi_turn_env.py index 5f4aa9aa1..408072d7b 100644 --- a/rllm/environments/base/multi_turn_env.py +++ b/rllm/environments/base/multi_turn_env.py @@ -61,7 +61,9 @@ def step(self, action): # Check if we've reached the maximum number of turns if self.current_turn >= self.max_turns: self.done = True - return {}, reward, self.done, self.task + # For multimodal tasks, preserve task info even when done + final_obs = self.task if isinstance(self.task, dict) else {} + return final_obs, reward, self.done, self.task return next_obs, reward, self.done, self.task diff --git a/rllm/environments/base/single_turn_env.py b/rllm/environments/base/single_turn_env.py index 49e5549a4..a8a16b786 100644 --- a/rllm/environments/base/single_turn_env.py +++ b/rllm/environments/base/single_turn_env.py @@ -37,7 +37,17 @@ def get_reward_and_next_obs(self, task: dict, action: Any) -> tuple[float, dict] """ reward_output = self.reward_fn(task_info=task, action=action) - return reward_output.reward, {} + metadata = getattr(reward_output, "metadata", {}) or {} + data_source = None + if isinstance(task, dict): + data_source = task.get("data_source") + if metadata and (data_source 
== "hiyouga/geometry3k" or metadata.get("data_source") == "hiyouga/geometry3k"): + print(f"DEBUG: GEO3K reward metadata -> {metadata}") + + # For single-turn environments, preserve the original task for next observation + # This ensures multimodal data (like images) is maintained + next_obs = task if isinstance(task, dict) else {} + return reward_output.reward, next_obs @staticmethod def from_dict(env_args: dict) -> "SingleTurnEnvironment": diff --git a/rllm/rewards/math_reward.py b/rllm/rewards/math_reward.py index 4a4c17bbb..a7a462381 100644 --- a/rllm/rewards/math_reward.py +++ b/rllm/rewards/math_reward.py @@ -1,8 +1,9 @@ -""" -This module contains the RewardMathFn class, which evaluates mathematical answers -and assigns rewards based on their correctness. It utilizes a language model to -validate answers when necessary. -""" +"""Reward helpers for math-style tasks, including GEO3K multimodal problems.""" + +from rllm.agents.agent import Action +from pathlib import Path +from datetime import datetime +import re from rllm.globals import THOUGHT_DELIMITER_END from rllm.rewards.math_utils.utils import extract_answer, grade_answer_mathd, grade_answer_sympy @@ -41,11 +42,94 @@ def __call__(self, task_info: dict, action: str) -> RewardOutput: # problem = task_info.get("problem", "") model_response = action - # Handle None or empty response - if model_response is None or model_response == "": + # Extract raw text when an Action wrapper is provided + if isinstance(model_response, Action): + model_response = model_response.action + + # Normalize response into a string for downstream processing + if model_response is None: + print("DEBUG: Empty or None response") + return RewardOutput(reward=self.config.format_error_reward, is_correct=False) + + if not isinstance(model_response, str): + model_response = str(model_response) + + model_response = model_response.strip() + + # Handle empty response after stripping whitespace + if model_response == "": print("DEBUG: Empty or None response") return RewardOutput(reward=self.config.format_error_reward, is_correct=False) + data_source = task_info.get("data_source") + + if data_source == "hiyouga/geometry3k": + model_response = self._normalize_geo3k_response(model_response) + geo3k_reward = None + try: + from verl.utils.reward_score import geo3k as geo3k_reward # type: ignore + except ImportError: + try: + from verl.verl.utils.reward_score import geo3k as geo3k_reward # type: ignore + except ImportError: + module_path = Path(__file__).resolve().parents[2] / "verl" / "verl" / "utils" / "reward_score" / "geo3k.py" + if module_path.exists(): + import importlib.util + + spec = importlib.util.spec_from_file_location("rllm_vendor_verl_geo3k", module_path) + if spec and spec.loader: + geo3k_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(geo3k_module) + geo3k_reward = geo3k_module # type: ignore + if geo3k_reward is None: + print("DEBUG: Failed to import geo3k reward module: module not found") + return RewardOutput(reward=self.config.unk_error_reward, is_correct=False) + + ground_truths = task_info.get("ground_truth") + if ground_truths is None: + return RewardOutput(reward=self.config.unk_error_reward, is_correct=False) + + if isinstance(ground_truths, (list, tuple, set)): + gt_list = [str(gt) for gt in ground_truths] + else: + gt_list = [str(ground_truths)] + + format_score = geo3k_reward.format_reward(model_response) + accuracy_scores = [geo3k_reward.acc_reward(model_response, gt) for gt in gt_list] + score_candidates = 
[geo3k_reward.compute_score(model_response, gt) for gt in gt_list] + + reward = float(max(score_candidates)) if score_candidates else self.config.unk_error_reward + max_accuracy = max(accuracy_scores, default=0.0) + + if format_score < 1.0: + print("DEBUG: GEO3K format violation detected; response missing required structure") + if max_accuracy == 0.0: + print("DEBUG: GEO3K accuracy check failed for all ground truths") + + if format_score < 1.0: + debug_limit = 50 + counter = getattr(self, "_geo3k_debug_logged", 0) + if counter < debug_limit: + debug_path = Path("outputs/debug_geo3k_responses.log") + debug_path.parent.mkdir(parents=True, exist_ok=True) + with debug_path.open("a", encoding="utf-8") as fp: + fp.write("\n" + "=" * 80 + "\n") + fp.write(f"{datetime.now().isoformat()} | format={format_score:.3f} | max_acc={max_accuracy:.3f}\n") + fp.write("Model response:\n") + fp.write(model_response + "\n") + fp.write("Ground truths:\n") + for gt in gt_list: + fp.write(str(gt) + "\n") + self._geo3k_debug_logged = counter + 1 + + metadata = { + "data_source": data_source, + "geo3k_accuracy_scores": accuracy_scores, + "geo3k_format_reward": format_score, + } + + return RewardOutput(reward=reward, metadata=metadata, is_correct=bool(max_accuracy)) + + # Extract solution. if THOUGHT_DELIMITER_END in model_response: model_solution = model_response.split(THOUGHT_DELIMITER_END)[1] @@ -93,6 +177,33 @@ def __call__(self, task_info: dict, action: str) -> RewardOutput: return RewardOutput(reward=self.config.incorrect_reward, is_correct=False) + @staticmethod + def _normalize_geo3k_response(model_response: str) -> str: + response = model_response + + final_answer_pattern = re.compile(r"Final answer\s*:\s*\\boxed\{.*?\}", re.DOTALL) + match = final_answer_pattern.search(response) + if not match: + return response + + final_start = match.start() + prefix = response[:final_start] + suffix = response[final_start:] + + has_think = "<think>" in prefix + has_think_end = "</think>" in prefix + + if has_think and not has_think_end: + prefix = prefix + "</think>\n" + elif not has_think: + cleaned_prefix = prefix.strip() + if cleaned_prefix: + prefix = f"<think>{cleaned_prefix}</think>\n" + else: + prefix = "<think></think>\n" + + return prefix + suffix + def rllm_reward_fn_math(data_source: str, llm_solution: str, ground_truth: str | list[str], extra_info=None, **kwargs): """Evaluates mathematical solutions against ground truth answers.
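For context, a rough illustration of the GEO3K reward path added above, with hypothetical inputs; it assumes the enclosing class is still RewardMathFn (as in the original module) and that verl's geo3k scorer imports as attempted above:

    # _normalize_geo3k_response wraps untagged reasoning that precedes the
    # "Final answer:" line in <think>...</think> so the geo3k format check can pass.
    raw = "The triangle is isosceles, so x = 42.\nFinal answer: \\boxed{42}"
    normalized = RewardMathFn._normalize_geo3k_response(raw)
    # -> "<think>The triangle is isosceles, so x = 42.</think>\nFinal answer: \\boxed{42}"

    # Scoring then mirrors the branch above: format_reward checks the <think>/boxed
    # structure, acc_reward grades the boxed value against each ground truth, and the
    # best compute_score across ground truths becomes the reward.
    from verl.utils.reward_score import geo3k

    print(geo3k.format_reward(normalized), geo3k.acc_reward(normalized, "42"))  # both should be 1.0 here

The format and accuracy components are also surfaced in the RewardOutput metadata (geo3k_format_reward, geo3k_accuracy_scores), which is what the SingleTurnEnvironment debug hook inspects.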
diff --git a/rllm/trainer/config/geo3k_multimodal_trainer.yaml b/rllm/trainer/config/geo3k_multimodal_trainer.yaml new file mode 100644 index 000000000..699a3e704 --- /dev/null +++ b/rllm/trainer/config/geo3k_multimodal_trainer.yaml @@ -0,0 +1,288 @@ +hydra: + searchpath: + - pkg://verl.trainer.config + +defaults: + - ppo_trainer + - _self_ + +# Enable multimodal processing +multimodal: + enable: true + +# Model configuration for multimodal training +actor_rollout_ref: + model: + path: "Qwen/Qwen2.5-VL-7B-Instruct" + trust_remote_code: true + use_shm: false + hybrid_engine: true + actor: + strategy: fsdp + ppo_mini_batch_size: 64 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: 2 + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + policy_loss: + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + entropy_coeff: 0.01 + use_kl_loss: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + optim: + lr: 5.0e-07 + lr_warmup_steps_ratio: 0.1 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + min_lr_ratio: 0.0 + num_cycles: 0.5 + warmup_style: constant + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + ref: + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + fsdp_config: + param_offload: false + reshard_after_forward: true + forward_prefetch: false + wrap_policy: + min_num_params: 0 + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + rollout: + name: sglang + mode: async + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,2048} + response_length: ${oc.select:data.max_response_length,2048} + dtype: bfloat16 + gpu_memory_utilization: 0.85 + ignore_eos: false + enforce_eager: false + free_cache_engine: true + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 256 + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + multi_stage_wake_up: false + multi_turn: + enable: true + max_assistant_turns: 5 + agent: + num_workers: 0 + val_kwargs: + do_sample: true + engine_kwargs: + sglang: + swap_space: null + disable_mm_preprocessor_cache: false + +# Critic configuration +critic: + strategy: fsdp + use_torch_compile: true + critic_loss_coef: 1.0 + 
value_head_mode: separate + use_kl_loss: false + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: 2 + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + shuffle: false + checkpoint: + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + optim: + lr: 5.0e-07 + lr_warmup_steps_ratio: 0.1 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + min_lr_ratio: 0.0 + num_cycles: 0.5 + warmup_style: constant + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + +# Data configuration for multimodal training +data: + max_prompt_length: 2048 + max_response_length: 2048 + train_batch_size: 64 + micro_batch_size_per_gpu: 1 + val_batch_size: 64 + val_micro_batch_size_per_gpu: 1 + train_files: null + val_files: null + dataset: null + gen_batch_size: ${mul:${data.train_batch_size},${rllm.rejection_sample.multiplier}} + return_raw_chat: true + return_multi_modal_inputs: true + trust_remote_code: true + pad_to_length: null + max_epochs: 10 + shuffle_before_pack: true + pack_sequences: false + tokenizer_chat_template: null + response_only_loss_mask: true + normalize_reward: false + filter_length: -1 + max_token_len: null + balance_sequence_length: false + balance_sequence_length_mode: packed_token + balance_sequence_length_sequence_weight: 0.7 + balance_sequence_length_balance_weight: 0.3 + balance_sequence_length_max_ratio: 10.0 + parquet_data_file_fraction: 1.0 + +# RLLM-specific configuration +rllm: + agent: + name: geo3k_agent + max_steps: 10 + trajectory_timeout: null + overlong_filter: false + agent_args: + accumulate_thinking: true + include_images_in_completion: true + engine_args: {} + env: + name: math + env_args: {} + workflow: + use_workflow: true + name: multi_turn_workflow + workflow_args: + max_prompt_length: ${data.max_prompt_length} + max_response_length: ${data.max_response_length} + timeout: 1e6 + gamma: 1.0 + reward_bonus_coeff: 0.0 + accumulate_response_length: true + n_parallel_tasks: 64 + retry_limit: 3 + disable_thinking: false + mask_truncated_samples: false + stepwise_advantage: + enable: false + mode: broadcast + normalize_by_steps: false + compact_filtering: + enable: false + compact_ratio: 0.8 + compact_mode: packed_token + rejection_sample: + enable: false + multiplier: 4 + reward_threshold: 0.0 + keep_best_n: 1 + +# Algorithm configuration +algorithm: + kl_penalty: 0.05 + kl_target: 6.0 + gamma: 1.0 + lam: 0.95 + cliprange: 0.2 + cliprange_value: 0.2 + vf_coef: 0.5 + entropy_bonus: 0.01 + use_kl_in_reward: false + adv_estimator: gae + normalization: + advantage: true + value_target: false + +# Reward model configuration +reward_model: + enable: false + path: null + trust_remote_code: true + +# Trainer configuration +trainer: + n_gpus_per_node: 4 + nnodes: 1 + save_freq: 100 + eval_freq: 50 + logging_freq: 10 + save_dir: ./checkpoints/geo3k_multimodal + project_name: geo3k_multimodal_training + experiment_name: ${oc.env:USER,unknown}_geo3k_${now:%Y%m%d_%H%M%S} + tags: ["multimodal", "geo3k", "qwen2vl"] + disable_fast_tokenizer: false + profile_steps: null + controller_nsight_options: + trace: ["cuda", "osrt", "nvtx"] + output: "./nsight_traces" + force_overwrite: true + 
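  # With profile_steps left at null, the Nsight options above should stay inert;
  # listing explicit step indices (illustrative: profile_steps: [10, 11]) is what
  # would turn tracing on.
  # experiment_name is resolved by Hydra at launch time, e.g. alice_geo3k_20250115_093000
  # when USER=alice (timestamp format %Y%m%d_%H%M%S); the example value is illustrative.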
use_legacy_worker_impl: auto + +# Ray configuration +ray_init: + num_cpus: 16 + timeline_json_file: null diff --git a/rllm/trainer/env_agent_mappings.py b/rllm/trainer/env_agent_mappings.py index 4807464ee..4a1575591 100644 --- a/rllm/trainer/env_agent_mappings.py +++ b/rllm/trainer/env_agent_mappings.py @@ -24,6 +24,7 @@ def safe_import(module_path, class_name): "tool_agent": safe_import("rllm.agents.tool_agent", "ToolAgent"), "sweagent": safe_import("rllm.agents.swe_agent", "SWEAgent"), "math_agent": safe_import("rllm.agents.math_agent", "MathAgent"), + "geo3k_agent": safe_import("rllm.agents.geo3k_agent", "Geo3kAgent"), "code_agent": safe_import("rllm.agents.code_agent", "CompetitionCodingAgent"), } diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py index 734fef051..e0294ef1a 100644 --- a/rllm/trainer/verl/agent_ppo_trainer.py +++ b/rllm/trainer/verl/agent_ppo_trainer.py @@ -14,6 +14,7 @@ from omegaconf import OmegaConf from rllm.engine.agent_execution_engine import AsyncAgentExecutionEngine +from rllm.parser import ChatTemplateParser from verl import DataProto from verl.protocol import pad_dataproto_to_divisor from verl.trainer.ppo.ray_trainer import ( @@ -45,12 +46,109 @@ def __init__( agent_class=None, env_args=None, agent_args=None, + processor=None, ): - super().__init__(config=config, tokenizer=tokenizer, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn) + super().__init__(config=config, tokenizer=tokenizer, processor=processor, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn) self.env_class = env_class self.agent_class = agent_class self.env_args = env_args or {} self.agent_args = agent_args or {} + self.processor = processor + disable_thinking = False + if hasattr(self.config, "rllm") and hasattr(self.config.rllm, "disable_thinking"): + disable_thinking = bool(self.config.rllm.disable_thinking) + self.chat_parser = ChatTemplateParser.get_parser(self.tokenizer, disable_thinking=disable_thinking) + self._mm_debug_prints = 0 + + def _dump_multimodal_debug(self, stage: str, data_proto: DataProto, max_prints: int = 6) -> None: + """Best-effort logging of multimodal payloads for debugging.""" + if self._mm_debug_prints >= max_prints: + return + + self._mm_debug_prints += 1 + + try: + batch_size = len(data_proto) + except Exception: # noqa: BLE001 + batch_size = "unknown" + + print("\n[MM DEBUG] stage=", stage, " batch_size=", batch_size) + try: + print(" non_tensor_keys=", list(data_proto.non_tensor_batch.keys())) + except Exception as exc: # noqa: BLE001 + print(f" failed to list non_tensor_keys: {exc}") + + # Inspect first element when available + try: + if len(data_proto) == 0: + return + except Exception: # noqa: BLE001 + return + + try: + sample = data_proto[0] + except Exception as exc: # noqa: BLE001 + print(f" failed to fetch sample: {exc}") + return + + sample_non_tensor = getattr(sample, "non_tensor_batch", {}) or {} + sample_batch = getattr(sample, "batch", None) + + mm_data = sample_non_tensor.get("multi_modal_data") + mm_inputs = sample_non_tensor.get("multi_modal_inputs") + raw_prompt_text = sample_non_tensor.get("raw_prompt_text") + + if raw_prompt_text: + print(" raw_prompt_text_head=", raw_prompt_text[:160]) + + if isinstance(mm_data, dict): + print(" multi_modal_data 
keys=", list(mm_data.keys())) + for key, value in mm_data.items(): + if isinstance(value, list): + print(f" {key}: list(len={len(value)})") + else: + print(f" {key}: type={type(value)}") + else: + print(" multi_modal_data type=", type(mm_data)) + + if isinstance(mm_inputs, dict) and mm_inputs: + print(" multi_modal_inputs summary:") + for key, value in mm_inputs.items(): + try: + if torch.is_tensor(value): + print(f" {key}: tensor shape={tuple(value.shape)} dtype={value.dtype}") + flat = value.flatten() + preview = flat[: min(10, flat.shape[0])].tolist() + print(f" sample[{key}]: {preview}") + elif isinstance(value, list) and value and torch.is_tensor(value[0]): + shapes = [tuple(v.shape) for v in value] + print(f" {key}: list[tensor] shapes={shapes}") + flat = value[0].flatten() + preview = flat[: min(10, flat.shape[0])].tolist() + print(f" sample[{key}[0]]: {preview}") + else: + print(f" {key}: type={type(value)}") + except Exception as exc: # noqa: BLE001 + print(f" {key}: failed to inspect ({exc})") + else: + print(" multi_modal_inputs type=", type(mm_inputs)) + + if sample_batch is not None: + try: + prompts_tensor = data_proto.batch["prompts"][0] + prompt_ids = prompts_tensor.tolist() if hasattr(prompts_tensor, "tolist") else [] + decoded_prompt = self.tokenizer.decode([pid for pid in prompt_ids if pid != self.tokenizer.pad_token_id]) + tokens_head = prompt_ids[:64] + try: + token_strs = self.tokenizer.convert_ids_to_tokens(tokens_head) + except Exception: # noqa: BLE001 + token_strs = "" + print(" prompt_token_ids_head=", tokens_head) + print(" prompt_token_strs_head=", token_strs) + print(" decoded_prompt_head=", decoded_prompt[:160]) + except Exception as exc: # noqa: BLE001 + print(f" failed to decode prompt tokens: {exc}") + assert self.config.actor_rollout_ref.hybrid_engine, "Only hybrid engine is supported" assert self.config.actor_rollout_ref.rollout.mode == "async", "Only async rollout mode is supported" @@ -68,6 +166,7 @@ def init_workers(self): config=self.config, engine_name="verl", tokenizer=self.tokenizer, + processor=self.processor, model_path=self.config.actor_rollout_ref.model.path, max_steps=self.config.rllm.agent.max_steps, max_response_length=self.config.data.max_response_length, @@ -162,7 +261,17 @@ def fit_agent(self): metrics = {} timing_raw = {} - batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) + non_tensor_keys_to_pop = [ + "raw_prompt", + "raw_prompt_text", + "raw_prompt_ids", + "multi_modal_data", + "multi_modal_inputs", + ] + batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=[key for key in non_tensor_keys_to_pop if key in batch.non_tensor_batch], + ) with marked_timer("step", timing_raw): self.init_envs_and_agents(batch) @@ -175,10 +284,12 @@ def fit_agent(self): final_gen_batch_output.meta_info.pop("repeat_counts", None) # no longer needed after this # batch needs to be padded to divisor of world size, we will pad with everything masked out batch = batch.union(final_gen_batch_output) + self._dump_multimodal_debug("train_union_steps", batch) batch = self._pad_dataproto_to_world_size(batch=batch) else: final_gen_batch_output, generate_metrics = self.generate_agent_trajectory(timing_raw=timing_raw, meta_info=batch.meta_info) batch = batch.union(final_gen_batch_output) + self._dump_multimodal_debug("train_union_traj", batch) metrics.update(generate_metrics) # compute values @@ -425,7 +536,17 @@ def _validate_agent(self): test_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ 
in range(len(test_batch.batch))], dtype=object) n_val_samples = self.config.actor_rollout_ref.rollout.val_kwargs.n test_batch = test_batch.repeat(repeat_times=n_val_samples, interleave=True) - test_batch.pop(["input_ids", "attention_mask", "position_ids"]) # these are not needed for environment based interaction + non_tensor_keys_to_pop = [ + "raw_prompt", + "raw_prompt_text", + "raw_prompt_ids", + "multi_modal_data", + "multi_modal_inputs", + ] + test_batch.pop( + ["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=[key for key in non_tensor_keys_to_pop if key in test_batch.non_tensor_batch], + ) # remove tensors and multimodal caches before agent rollout test_batch.meta_info = { "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, @@ -445,6 +566,7 @@ def _validate_agent(self): test_output_gen_batch, _ = self.generate_agent_trajectory(meta_info=test_batch.meta_info) test_batch = test_batch.union(test_output_gen_batch) + self._dump_multimodal_debug("validate_union", test_batch) reward_tensor = test_batch.batch["token_level_scores"] @@ -567,9 +689,24 @@ def _transform_agent_trajectories(self, trajectories: list[dict]): all_masks_list = [] traj_scores = [] chat_completions = [] + raw_prompt_texts = [] + raw_prompt_ids_list = [] + multi_modal_data_list = [] + multi_modal_inputs_list = [] traj_metrics = [] metrics = {} + def _move_to_cpu(obj): + if torch.is_tensor(obj): + return obj.detach().cpu() + if isinstance(obj, dict): + return {k: _move_to_cpu(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_move_to_cpu(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_move_to_cpu(v) for v in obj) + return obj + for traj in trajectories: prompt_tokens = traj["prompt_tokens"] response_tokens = traj["response_tokens"] @@ -580,6 +717,69 @@ def _transform_agent_trajectories(self, trajectories: list[dict]): all_masks_list.append(traj["response_masks"]) traj_scores.append(traj["trajectory_reward"]) chat_completions.append(traj["chat_completions"]) + prompt_messages = traj.get("prompt_messages") or traj.get("chat_completions") or [] + multimodal_messages = traj.get("multimodal_messages") or [] + + multi_modal_data = traj.get("multi_modal_data") + if not isinstance(multi_modal_data, dict): + multi_modal_data = {} + else: + multi_modal_data = dict(multi_modal_data) + + raw_prompt_text = None + raw_prompt_ids = None + multi_modal_inputs = {} + + if self.processor is not None and multimodal_messages: + try: + raw_prompt_text = self.processor.apply_chat_template( + multimodal_messages, + add_generation_prompt=True, + tokenize=False, + ) + + images = multi_modal_data.get("image") if multi_modal_data else None + if isinstance(images, list) and len(images) == 0: + images = None + videos = multi_modal_data.get("video") if multi_modal_data else None + if isinstance(videos, list) and len(videos) == 0: + videos = None + + model_inputs = self.processor( + text=[raw_prompt_text], + images=images, + videos=videos, + return_tensors="pt", + ) + raw_prompt_ids = model_inputs["input_ids"][0].tolist() + multi_modal_inputs = { + key: _move_to_cpu(value) + for key, value in model_inputs.items() + if key not in {"input_ids", "attention_mask"} + } + except Exception as exc: # noqa: BLE001 + print(f"Warning: failed to build multimodal inputs: {exc}") + raw_prompt_text = None + raw_prompt_ids = None + multi_modal_inputs = {} + + if raw_prompt_text is None and prompt_messages: + try: + raw_prompt_text = self.chat_parser.parse( + prompt_messages, + 
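                        # Text-only fallback: re-render the prompt with the tokenizer's
                        # chat template when the multimodal processor path above was
                        # unavailable or failed; no vision inputs are attached on this branch.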
add_generation_prompt=True, + is_first_msg=True, + ) + except Exception: # noqa: BLE001 + raw_prompt_text = None + + if raw_prompt_ids is None: + raw_prompt_ids = prompt_tokens.tolist() + + raw_prompt_texts.append(raw_prompt_text) + raw_prompt_ids_list.append(raw_prompt_ids) + multi_modal_data_list.append(_move_to_cpu(multi_modal_data)) + multi_modal_inputs_list.append(_move_to_cpu(multi_modal_inputs)) traj_metrics.append(traj["metrics"]) # Flatten traj_metrics into a dict of lists @@ -655,9 +855,20 @@ def _transform_agent_trajectories(self, trajectories: list[dict]): "response_mask": traj_mask, } - self.visualize_trajectory(DataProto.from_dict(tensors=tensor_batch)) + non_tensor_batch = { + "raw_prompt": np.array(raw_prompt_texts, dtype=object), + "raw_prompt_text": np.array(raw_prompt_texts, dtype=object), + "raw_prompt_ids": np.array(raw_prompt_ids_list, dtype=object), + "multi_modal_data": np.array(multi_modal_data_list, dtype=object), + "multi_modal_inputs": np.array(multi_modal_inputs_list, dtype=object), + } - return DataProto.from_dict(tensors=tensor_batch), metrics + traj_dataproto = DataProto.from_dict(tensors=tensor_batch, non_tensors=non_tensor_batch) + + self.visualize_trajectory(traj_dataproto) + self._dump_multimodal_debug("transform_trajectory", traj_dataproto) + + return traj_dataproto, metrics def visualize_trajectory(self, tensor_batch, sample_idx=0, max_samples=1, mask_key="response_mask"): """ @@ -881,6 +1092,7 @@ def _transform_agent_steps(self, steps: list[dict], uids: np.ndarray): sample_indices = np.random.choice(last_step_indices, size=min(2, len(last_step_indices)), replace=False) for idx in sample_indices: self.visualize_trajectory(result, sample_idx=idx, max_samples=1) + self._dump_multimodal_debug("transform_steps", result) return result def _stepwise_advantage_broadcast(self, last_step_batch, other_step_batch): @@ -956,3 +1168,12 @@ def _pad_dataproto_to_world_size(self, batch): batch.non_tensor_batch["is_pad_step"][idx] = True return batch + + def shutdown(self): + """Gracefully release resources used by the async agent execution engine.""" + engine = getattr(self, "agent_execution_engine", None) + if engine is not None and hasattr(engine, "shutdown"): + try: + engine.shutdown() + except Exception as exc: # noqa: BLE001 + print(f"Warning: failed to shutdown agent execution engine: {exc}") diff --git a/rllm/trainer/verl/train_agent_ppo.py b/rllm/trainer/verl/train_agent_ppo.py index f05966fc2..b5668242d 100644 --- a/rllm/trainer/verl/train_agent_ppo.py +++ b/rllm/trainer/verl/train_agent_ppo.py @@ -89,11 +89,20 @@ def run(self, config, workflow_class=None, workflow_args=None, agent_class=None, # Instantiate the tokenizer and processor. from verl.utils import hf_tokenizer + from verl.utils.tokenizer import hf_processor trust_remote_code = config.data.get("trust_remote_code", False) tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - # Used for multimodal LLM, could be None - # processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Enable processor for multimodal models + processor = None + if config.get("multimodal", {}).get("enable", False): + try: + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + print(f"Multimodal processor loaded: {type(processor)}") + except Exception as e: + print(f"Failed to load multimodal processor: {e}") + print("Falling back to text-only training") # Define worker classes based on the actor strategy. 
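    # geo3k_multimodal_trainer.yaml sets actor_rollout_ref.actor.strategy: fsdp, so the
    # FSDP branch below is the one taken for this recipe.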
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: @@ -192,6 +201,7 @@ def run(self, config, workflow_class=None, workflow_args=None, agent_class=None, trainer = AgentPPOTrainer( config=config, tokenizer=tokenizer, + processor=processor, role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls,
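        # The processor is forwarded both to the verl PPO trainer base class (via
        # super().__init__ in AgentPPOTrainer) and to the AsyncAgentExecutionEngine in
        # init_workers, so rollout and PPO updates share the same multimodal preprocessing.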