126 changes: 126 additions & 0 deletions examples/geo3k/prepare_geo3k_data.py
@@ -0,0 +1,126 @@
#!/usr/bin/env python3
"""Prepare GEO3K multimodal geometry dataset for RLLM training."""

import base64
from io import BytesIO
from typing import Iterable, List

from datasets import load_dataset
from PIL import Image

from rllm.data.dataset import DatasetRegistry

DATA_URI_PREFIX = "data:image/png;base64,"


def _serialize_images(images: Iterable[Image.Image]) -> List[str]:
"""Serialize a list of PIL images into base64 data URIs for Parquet storage."""

    serialized: List[str] = []
for image in images or []:
if not isinstance(image, Image.Image):
continue
buffer = BytesIO()
image.convert("RGB").save(buffer, format="PNG")
encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
serialized.append(f"{DATA_URI_PREFIX}{encoded}")
return serialized
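

# A minimal sketch of the inverse of _serialize_images, assuming the data-URI
# format produced above. RLLM's training/inference code performs its own
# reconstruction of the "images" field; the helper name _deserialize_images is
# illustrative only, not part of the RLLM API.
def _deserialize_images(uris: Iterable[str]) -> List[Image.Image]:
    """Decode base64 PNG data URIs back into PIL images."""
    images: List[Image.Image] = []
    for uri in uris or []:
        if not isinstance(uri, str) or not uri.startswith(DATA_URI_PREFIX):
            continue
        raw = base64.b64decode(uri[len(DATA_URI_PREFIX):])
        images.append(Image.open(BytesIO(raw)).convert("RGB"))
    return images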


def prepare_geo3k_data():
"""
Prepare GEO3K dataset following RLLM conventions.

Returns:
Tuple of (train_dataset, test_dataset) registered with DatasetRegistry
"""
print("📥 Loading GEO3K dataset from HuggingFace...")
data_source = "hiyouga/geometry3k"
dataset = load_dataset(data_source)

train_dataset = dataset["train"]
test_dataset = dataset["test"]

print(f"✅ Dataset loaded:")
print(f" - Train samples: {len(train_dataset)}")
print(f" - Test samples: {len(test_dataset)}")

# Instruction template based on Verl's GEO3K processing
instruction_following = (
"You must strictly follow this answer template: "
"(1) write all reasoning as an internal monologue inside <think>...</think>; "
"(2) after </think>, output exactly one line in the form `Final answer: \\boxed{value}` with the numeric solution."
)

formatting_example = (
"<think>Reason step by step here....</think>\n"
"Final answer: \\boxed{42}"
)
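    # Assumption: the reward function (math_reward_fn in the run/train scripts)
    # extracts and compares the \boxed{...} value, which is why the template
    # above is enforced so strictly.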

    def preprocess_fn(example, idx):  # idx is required by map(with_indices=True)
"""
Preprocess function to convert GEO3K data to RLLM format.

This follows the pattern from verl/examples/data_preprocess/geo3k.py
but adapts it for RLLM's simpler format requirements.
"""
problem = example["problem"]
answer = example["answer"]
images = _serialize_images(example.get("images", []))

# Create the full prompt with instruction following
prompt = problem + "\n\n" + instruction_following + "\nExample format:\n" + formatting_example

# Return RLLM-compatible format
return {
"question": prompt, # RLLM expects 'question' field
"ground_truth": answer, # RLLM expects 'ground_truth' field
"data_source": "hiyouga/geometry3k", # Data source identifier matching verl's reward function
"images": images, # Serialized data URIs; reconstructed during training/inference
"extra_info": {
"original_problem": problem,
"answer": answer,
"has_images": len(images) > 0,
"num_images": len(images),
"formatting_example": formatting_example,
}
}
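
    # For illustration, a preprocessed record looks roughly like this
    # (values here are hypothetical; real problems and answers come from the dataset):
    #   {
    #       "question": "Find x. ...\n\nYou must strictly follow this answer template: ...",
    #       "ground_truth": "12",
    #       "data_source": "hiyouga/geometry3k",
    #       "images": ["data:image/png;base64,iVBORw0KGgo..."],
    #       "extra_info": {...},
    #   }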

print("🔄 Preprocessing datasets...")

# Apply preprocessing to both splits
train_dataset = train_dataset.map(preprocess_fn, with_indices=True)
test_dataset = test_dataset.map(preprocess_fn, with_indices=True)

# Register datasets with RLLM DatasetRegistry
print("📋 Registering datasets with RLLM...")

train_dataset = DatasetRegistry.register_dataset("geo3k_train", train_dataset, "train")
test_dataset = DatasetRegistry.register_dataset("geo3k_test", test_dataset, "test")

print("✅ Datasets registered:")
print(" - geo3k_train (training data)")
print(" - geo3k_test (evaluation data)")

return train_dataset, test_dataset


if __name__ == "__main__":
train_dataset, test_dataset = prepare_geo3k_data()
print(f"\n📊 Dataset Statistics:")
print(f" - Training samples: {len(train_dataset)}")
print(f" - Test samples: {len(test_dataset)}")

# Show a sample to verify format
sample = train_dataset[0]
print(f"\n📋 Sample data format:")
print(f" - Question length: {len(sample['question'])} chars")
print(f" - Has images: {len(sample['images']) > 0}")
print(f" - Number of images: {len(sample['images'])}")
print(f" - Ground truth: {sample['ground_truth'][:100]}...")

print(f"\n🎉 GEO3K dataset preparation complete!")
print(f"\nNext steps:")
print(f"1. Configure multimodal model in training config")
print(f"2. Run training: python train_geo3k_agent.py")
print(f"3. Test inference: python run_geo3k_agent.py")
168 changes: 168 additions & 0 deletions examples/geo3k/run_geo3k_agent.py
@@ -0,0 +1,168 @@
#!/usr/bin/env python3
"""
Inference script for GEO3K multimodal geometry agent.

This script tests the trained multimodal agent on GEO3K geometry problems
and evaluates its performance on problems with images.
"""

import asyncio
import os

from transformers import AutoTokenizer

from rllm.agents import Geo3kAgent
from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_execution_engine import AgentExecutionEngine
from rllm.environments.base.single_turn_env import SingleTurnEnvironment
from rllm.rewards.reward_fn import math_reward_fn
from rllm.utils import compute_pass_at_k


def main():
"""
Main inference function for testing GEO3K multimodal agent.
"""
print("🚀 GEO3K Multimodal Agent Inference")
print("=" * 40)

# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Configuration
n_parallel_agents = 16 # Adjust based on your GPU memory
model_name = "Qwen/Qwen2-VL-2B-Instruct" # Default multimodal model

print(f"🔧 Configuration:")
print(f" - Model: {model_name}")
print(f" - Parallel agents: {n_parallel_agents}")

# Initialize tokenizer
print("📝 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Agent configuration for multimodal geometry reasoning
agent_args = {
"accumulate_thinking": True,
"include_images_in_completion": True, # Include image info in completions for debugging
}

# Environment configuration
env_args = {
"reward_fn": math_reward_fn,
}

# Sampling parameters for inference
sampling_params = {
"temperature": 0.7, # Slightly higher for reasoning diversity
"top_p": 0.95,
"max_tokens": 2048,
"model": model_name
}
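
    # Keep max_tokens here consistent with max_response_length passed to the
    # engine below (both 2048 in this example).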

print("🔧 Initializing agent execution engine...")

# Initialize execution engine
# Note: You can switch between "openai" and "verl" engine
# For multimodal models, "verl" engine with SGLang backend is recommended
engine = AgentExecutionEngine(
agent_class=Geo3kAgent,
agent_args=agent_args,
env_class=SingleTurnEnvironment,
env_args=env_args,
engine_name="openai", # Can be "openai" or "verl"
rollout_engine_args={
"base_url": "http://localhost:30000/v1", # SGLang server URL
"api_key": "None"
},
tokenizer=tokenizer,
sampling_params=sampling_params,
max_response_length=2048,
max_prompt_length=2048,
n_parallel_agents=n_parallel_agents,
)
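    # Note: base_url above assumes an OpenAI-compatible server (e.g., SGLang or
    # vLLM) is already listening on localhost:30000; see the launch commands in
    # the exception handler below if no server is running.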

# Load test dataset
print("📊 Loading test dataset...")
try:
test_dataset = DatasetRegistry.load_dataset("geo3k_test", "test")
print(f"✅ Test dataset loaded: {len(test_dataset)} samples")
except Exception as e:
print(f"❌ Dataset not found: {e}")
print("🔄 Preparing dataset...")
from prepare_geo3k_data import prepare_geo3k_data

_, test_dataset = prepare_geo3k_data()

test_data = test_dataset.get_data()

# Take a smaller subset for quick testing
test_samples = 50 # Adjust as needed
subset_size = min(test_samples, len(test_data))
test_subset = test_data[:subset_size]
print(f"🎯 Testing on {subset_size} samples")

# Check multimodal content statistics
multimodal_count = sum(1 for sample in test_subset if sample.get("images"))
print(f"📸 Multimodal samples: {multimodal_count}/{subset_size}")

# Repeat samples for pass@k evaluation
n_repeats = 4 # Number of attempts per problem
tasks = [sample for sample in test_subset for _ in range(n_repeats)]
print(f"🔄 Total tasks (with repeats): {len(tasks)}")

# Show sample problem
if not test_subset:
print("⚠️ No samples found in test dataset. Exiting.")
return

sample = test_subset[0]
print(f"\n📋 Sample Problem:")
question_preview = sample.get("question", "")[:200]
ground_truth_preview = str(sample.get("ground_truth", ""))[:100]
has_images = bool(sample.get("images"))
print(f" Question: {question_preview}...")
print(f" Ground Truth: {ground_truth_preview}...")
print(f" Has Images: {has_images}")

# Run inference
print("\n🚀 Starting inference...")
print("⏳ This may take a while depending on the number of samples and model speed...")

try:
results = asyncio.run(engine.execute_tasks(tasks))
print(f"\n✅ Inference completed!")
print(f"📊 Results: {len(results)} task completions")

# Compute and display pass@k metrics
print("\n📈 Computing Pass@K metrics...")
pass_at_k_results = compute_pass_at_k(results)

# Display results
print(f"\n🎯 Performance Summary:")
for k, score in pass_at_k_results.items():
print(f" - Pass@{k}: {score:.3f}")

# Show some example results
print(f"\n📝 Example Results:")
for i, result in enumerate(results[:3]): # Show first 3 results
reward = result.get('reward', 0)
success = "✅" if reward > 0.5 else "❌"
print(f" {success} Task {i+1}: Reward = {reward:.3f}")

except Exception as e:
print(f"❌ Inference failed: {e}")
print("💡 Make sure your model server is running:")
print(" SGLang: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-2B-Instruct")
print(" vLLM: python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2-VL-2B-Instruct")
raise

print(f"\n🎉 GEO3K inference completed!")
print(f"💡 For better performance, consider:")
print(f" - Using a larger multimodal model (e.g., Qwen2-VL-7B)")
print(f" - Fine-tuning on GEO3K training data")
print(f" - Adjusting sampling parameters")


if __name__ == "__main__":
main()