130 changes: 89 additions & 41 deletions rllm/agents/swe_agent.py
@@ -2,28 +2,69 @@
import logging
import re

try:
from r2egym.agenthub.action import Action as SWEAction
except ImportError:
SWEAction = None
from r2egym.agenthub.action import Action as SWEAction

from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
from rllm.agents.system_prompts import SWE_SYSTEM_PROMPT, SWE_SYSTEM_PROMPT_FN_CALL, SWE_USER_PROMPT, SWE_USER_PROMPT_FN_CALL, SWEAGENT_SYSTEM_PROMPT, SWEAGENT_USER_PROMPT
from rllm.engine.rollout.rollout_engine import ModelOutput
from rllm.parser.chat_template_parser import ChatTemplateParser

TOKEN_WARNING_THRESHOLD = 28000


def parse_oai_response(response):
thought = response.choices[0].message.content
if not thought:
thought = ""
try:
function_name = response.choices[0].message.tool_calls[0].function.name
parameters = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
action = SWEAction(function_name, parameters)
except Exception:
action = SWEAction(function_name="", parameters={})
return thought, action
# Mapping of scaffold types to their tool schema definitions
# These are imported directly from R2E-Gym


def get_tools_for_scaffold(scaffold: str = "r2egym"):
"""
Get the OpenAI function calling tools schema for a given scaffold.

Args:
scaffold: The scaffold type ("r2egym" or "sweagent")

Returns:
List of tool schemas in OpenAI function calling format
"""
from r2egym.agenthub.tools import (
execute_bash_tool,
file_editor,
finish_tool,
r2egym_bash_execute_tool,
search_tool,
str_replace_editor_tool,
submit_tool,
)

if scaffold == "r2egym":
return [
file_editor,
search_tool,
r2egym_bash_execute_tool,
finish_tool,
]
elif scaffold == "sweagent":
return [
str_replace_editor_tool,
execute_bash_tool,
submit_tool,
]
raise ValueError(f"Invalid scaffold: {scaffold}")


def parse_oai_response(response: ModelOutput) -> tuple[str, SWEAction]:
if isinstance(response, ModelOutput):
content = response.content
if len(response.tool_calls) == 0:
logger.warning(f"No tool calls found in the ModelOutput. Last 500 chars of the response: ...{response.text[-500:]} Returning empty action.")
return content, SWEAction(function_name="", parameters={})
if not isinstance(response.tool_calls[0].arguments, dict):
logger.warning(f"Arguments is not a dict, got {type(response.tool_calls[0].arguments)}: {response.tool_calls[0].arguments}")
response.tool_calls[0].arguments = {}
action = SWEAction(function_name=response.tool_calls[0].name, parameters=response.tool_calls[0].arguments)
return content, action
else:
raise ValueError(f"Invalid response type: {type(response)}. Expected ChatCompletion or ModelOutput object.")


def parse_xml_response(response_text: str) -> tuple[str, SWEAction]:
@@ -59,18 +100,22 @@ def parse_xml_response(response_text: str) -> tuple[str, SWEAction]:


class SWEAgent(BaseAgent):
def __init__(self, use_fn_calling: bool = False, format_model_response: bool = False, scaffold: str = "r2egym"):
self.use_fn_calling = use_fn_calling
self.format_model_response = format_model_response
def __init__(self, use_fn_calling: bool = True, scaffold: str = "r2egym", chat_template_parser: ChatTemplateParser | None = None, **kwargs):
self.use_tool_calling = use_fn_calling
self.scaffold = scaffold
assert scaffold in ["r2egym", "sweagent"], f"Invalid scaffold: {scaffold}, must be one of ['r2egym', 'sweagent']"
self.system_prompt = SWE_SYSTEM_PROMPT_FN_CALL if use_fn_calling else SWE_SYSTEM_PROMPT
self.system_prompt = SWE_SYSTEM_PROMPT_FN_CALL if self.use_tool_calling else SWE_SYSTEM_PROMPT
if scaffold == "sweagent":
self.system_prompt = SWEAGENT_SYSTEM_PROMPT
self.user_prompt_template = SWE_USER_PROMPT_FN_CALL if use_fn_calling else SWE_USER_PROMPT
self.user_prompt_template = SWE_USER_PROMPT_FN_CALL if self.use_tool_calling else SWE_USER_PROMPT
if scaffold == "sweagent":
self.user_prompt_template = SWEAGENT_USER_PROMPT

self.chat_template_parser = chat_template_parser
if self.use_tool_calling:
tools_schema = json.dumps(get_tools_for_scaffold(scaffold))
self.tools_prompt = self.chat_template_parser.tool_parser.get_tool_prompt(tools_schema)

self._trajectory = Trajectory()
self.reset()
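# Construction sketch (hypothetical values; in practice the parser is injected by the
# execution engine or trainer, as the later hunks in this PR show):
#   parser = ChatTemplateParser.get_parser(tokenizer, disable_thinking)
#   agent = SWEAgent(use_fn_calling=True, scaffold="r2egym", chat_template_parser=parser)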

@@ -88,7 +133,7 @@ def process_model_response(self, response: str) -> tuple[str, str]:
- The action string in XML format
- The processed response (may be reformatted if self.format_model_response is True)
"""
if self.use_fn_calling:
if self.use_tool_calling:
thought, action = parse_oai_response(response)
else:
thought, action = parse_xml_response(response)
@@ -130,38 +175,42 @@ def update_from_env(self, observation, reward, done, info):
self.messages.append({"role": "user", "content": observation})
self.cur_step = Step(observation=observation)

def update_from_model(self, response: str, **kwargs):
def update_from_model(self, model_output: ModelOutput, **kwargs) -> Action:
"""
Updates the agent's internal state after an environment step.

This function is called during environment interaction to incorporate the latest action's
outcome into the agent's learning process.

Args:
response (str): The response from the model.

model_output (ModelOutput): The response from the model.
Returns:
None
Action: The action to take.
"""
response = model_output.text
self._trajectory.steps.append(self.cur_step)
if self.use_fn_calling:
thought, action = parse_oai_response(response)

if self.use_tool_calling:
content, action = parse_oai_response(model_output)
else:
thought, action = parse_xml_response(response)
action_str = action.to_xml_string()
content, action = parse_xml_response(response)
if len(model_output.tool_calls) > 0:
action_str = self.chat_template_parser.tool_parser.tool_call_to_str(model_output.tool_calls[0])
else:
action_str = ""
logger.debug(f"update_from_model: action_str: {action_str}")
assert self._trajectory.steps, "Trajectory should not be empty when update_from_model is called."

# Update Trajectory
cur_step = self._trajectory.steps[-1]
cur_step.thought = thought
cur_step.action = action_str
cur_step.reasoning = model_output.reasoning
cur_step.content = model_output.content
cur_step.text = model_output.text
cur_step.action = action
cur_step.model_response = response

# Update Chat Completions
if self.format_model_response:
self.messages.append({"role": "assistant", "content": f"{thought}\n\n{action_str}"})
else:
self.messages.append({"role": "assistant", "content": response})
self.messages.append({"role": "assistant", "content": response})
self.step += 1
return Action(action=cur_step.action)
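# Downstream (see run_agent_trajectory_async in agent_execution_engine.py):
#   action = agent.update_from_model(model_output).action
# and the unwrapped action is passed to the environment step via the executor.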

@@ -171,12 +220,11 @@ def get_current_state(self) -> Step:

def reset(self):
self._trajectory = Trajectory()
self.messages = [
{
"role": "system",
"content": self.system_prompt,
}
]
if self.use_tool_calling:
prompt = self.system_prompt + self.tools_prompt
else:
prompt = self.system_prompt
self.messages = [{"role": "system", "content": prompt}]
self.step = 0

@property
30 changes: 20 additions & 10 deletions rllm/engine/agent_execution_engine.py
@@ -12,6 +12,7 @@
convert_messages_to_tokens_and_masks,
get_recent_assistant_user_messages,
)
from rllm.engine.rollout.rollout_engine import ModelOutput
from rllm.environments.base.base_env import BaseEnv
from rllm.environments.env_utils import (
compute_mc_return,
@@ -74,6 +75,7 @@ def __init__(

self.agent_class = agent_class
self.agent_args = agent_args
self.agent_args["chat_template_parser"] = self.chat_parser
self.env_class = env_class
self.env_args = env_args

@@ -117,7 +119,7 @@ def __init__(
disable_thinking=self.disable_thinking,
)

async def get_model_response(self, prompt, application_id, **kwargs) -> str:
async def get_model_response(self, prompt, application_id, **kwargs) -> ModelOutput:
"""
Compute model response asynchronously based on the engine type.

@@ -135,20 +137,18 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str:
Raises:
NotImplementedError: If the engine type is not supported
"""

sampling_params = self.sampling_params.copy()
sampling_params.update(kwargs)

if self.engine_name == "openai":
output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, enforce_max_prompt_length=False, **sampling_params)
return output.text
return output
elif self.engine_name == "verl":
meta_data = sampling_params.pop("meta_info", {})
validate = meta_data.get("validate", False)
output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, enforce_max_prompt_length=False, **sampling_params)
return output.text
else:
raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
return output

raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
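# Call-site sketch (mirrors run_agent_trajectory_async below; sampling kwargs such as
# max_tokens and tools are forwarded through **kwargs):
#   model_output = await self.get_model_response(prompt_messages, application_id, **kwargs)
#   text, tool_calls = model_output.text, model_output.tool_calls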

def update_envs_and_agents(self, envs, agents):
"""
@@ -227,20 +227,30 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te

kwargs["max_tokens"] = max_tokens

# Add tools for rollout_engine if agent provides them (for use_tool_calling)
# Tools may be duplicated if they were already added to the system prompt in agent.reset()
if hasattr(agent, "get_tools_for_rollout_engine") and callable(agent.get_tools_for_rollout_engine) and hasattr(agent, "use_tool_calling") and agent.use_tool_calling:
tools = agent.get_tools_for_rollout_engine()
if tools:
kwargs["tools"] = tools

start_time = time.time()
response = await self.get_model_response(prompt_messages, application_id, **kwargs)
# get_model_response now returns a ModelOutput rather than raw response text,
# so agent.update_from_model can consume the structured output (content, tool
# calls, reasoning) directly instead of re-parsing ModelOutput.text.
model_output = await self.get_model_response(prompt_messages, application_id, **kwargs)
delta_time = time.time() - start_time
llm_time += delta_time
total_time += delta_time
# Update steps
prompt_response_pair = {
"prompt": self.chat_parser.parse(prompt_messages, add_generation_prompt=True, is_first_msg=True),
"response": response,
"response": model_output.text,
}
episode_steps.append(prompt_response_pair)

# Update agent with model response
action: Action = agent.update_from_model(response)
action: Action = agent.update_from_model(model_output)
action = action.action

# Take step in environment using the executor
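The tool-wiring block above checks for an optional get_tools_for_rollout_engine hook on the agent, which this diff does not define. A minimal sketch of what SWEAgent could expose, assuming it simply reuses get_tools_for_scaffold:

    def get_tools_for_rollout_engine(self):
        # Hypothetical helper (not part of this diff): hand the scaffold's tool schemas
        # to the rollout engine only when tool calling is enabled.
        return get_tools_for_scaffold(self.scaffold) if self.use_tool_calling else None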
2 changes: 2 additions & 0 deletions rllm/trainer/verl/agent_ppo_trainer.py
@@ -14,6 +14,7 @@
from omegaconf import OmegaConf

from rllm.engine.agent_execution_engine import AsyncAgentExecutionEngine
from rllm.parser.chat_template_parser import ChatTemplateParser
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor
from verl.trainer.ppo.core_algos import agg_loss
@@ -91,6 +92,7 @@ def init_envs_and_agents(self, batch):
env_args = batch.non_tensor_batch["extra_info"].tolist()

full_agent_args = dict(self.config.rllm.agent.get("agent_args", {})) | self.agent_args
full_agent_args["chat_template_parser"] = ChatTemplateParser.get_parser(self.tokenizer, self.config.rllm.disable_thinking)
base_env_args = dict(self.config.rllm.env.get("env_args", {})) | self.env_args

def _create_env(i):
Expand Down