diff --git a/rllm/agents/swe_agent.py b/rllm/agents/swe_agent.py
index 05271d5f..feab672d 100644
--- a/rllm/agents/swe_agent.py
+++ b/rllm/agents/swe_agent.py
@@ -2,28 +2,69 @@
 import logging
 import re
 
-try:
-    from r2egym.agenthub.action import Action as SWEAction
-except ImportError:
-    SWEAction = None
+from r2egym.agenthub.action import Action as SWEAction
 
 from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
 from rllm.agents.system_prompts import SWE_SYSTEM_PROMPT, SWE_SYSTEM_PROMPT_FN_CALL, SWE_USER_PROMPT, SWE_USER_PROMPT_FN_CALL, SWEAGENT_SYSTEM_PROMPT, SWEAGENT_USER_PROMPT
+from rllm.engine.rollout.rollout_engine import ModelOutput
+from rllm.parser.chat_template_parser import ChatTemplateParser
 
 TOKEN_WARNING_THRESHOLD = 28000
 
 
-def parse_oai_response(response):
-    thought = response.choices[0].message.content
-    if not thought:
-        thought = ""
-    try:
-        function_name = response.choices[0].message.tool_calls[0].function.name
-        parameters = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
-        action = SWEAction(function_name, parameters)
-    except Exception:
-        action = SWEAction(function_name="", parameters={})
-    return thought, action
+# Tool schema definitions for each scaffold type.
+# These are imported directly from R2E-Gym.
+
+
+def get_tools_for_scaffold(scaffold: str = "r2egym"):
+    """
+    Get the OpenAI function calling tools schema for a given scaffold.
+
+    Args:
+        scaffold: The scaffold type ("r2egym" or "sweagent")
+
+    Returns:
+        List of tool schemas in OpenAI function calling format
+    """
+    from r2egym.agenthub.tools import (
+        execute_bash_tool,
+        file_editor,
+        finish_tool,
+        r2egym_bash_execute_tool,
+        search_tool,
+        str_replace_editor_tool,
+        submit_tool,
+    )
+
+    if scaffold == "r2egym":
+        return [
+            file_editor,
+            search_tool,
+            r2egym_bash_execute_tool,
+            finish_tool,
+        ]
+    elif scaffold == "sweagent":
+        return [
+            str_replace_editor_tool,
+            execute_bash_tool,
+            submit_tool,
+        ]
+    raise ValueError(f"Invalid scaffold: {scaffold}")
+
+
+def parse_oai_response(response: ModelOutput) -> tuple[str, SWEAction]:
+    if isinstance(response, ModelOutput):
+        content = response.content
+        if len(response.tool_calls) == 0:
+            logger.warning(f"No tool calls found in the ModelOutput. Last 500 chars of the response: ...{response.text[-500:]} Returning empty action.")
+            return content, SWEAction(function_name="", parameters={})
+        if not isinstance(response.tool_calls[0].arguments, dict):
+            logger.warning(f"Arguments is not a dict, got {type(response.tool_calls[0].arguments)}: {response.tool_calls[0].arguments}")
+            response.tool_calls[0].arguments = {}
+        action = SWEAction(function_name=response.tool_calls[0].name, parameters=response.tool_calls[0].arguments)
+        return content, action
+    else:
+        raise ValueError(f"Invalid response type: {type(response)}. Expected ModelOutput object.")
 
 
 def parse_xml_response(response_text: str) -> tuple[str, SWEAction]:
@@ -59,18 +100,22 @@ def parse_xml_response(response_text: str) -> tuple[str, SWEAction]:
 
 
 class SWEAgent(BaseAgent):
-    def __init__(self, use_fn_calling: bool = False, format_model_response: bool = False, scaffold: str = "r2egym"):
-        self.use_fn_calling = use_fn_calling
-        self.format_model_response = format_model_response
+    def __init__(self, use_fn_calling: bool = True, scaffold: str = "r2egym", chat_template_parser: ChatTemplateParser | None = None, **kwargs):
+        self.use_tool_calling = use_fn_calling
         self.scaffold = scaffold
         assert scaffold in ["r2egym", "sweagent"], f"Invalid scaffold: {scaffold}, must be one of ['r2egym', 'sweagent']"
 
-        self.system_prompt = SWE_SYSTEM_PROMPT_FN_CALL if use_fn_calling else SWE_SYSTEM_PROMPT
+        self.system_prompt = SWE_SYSTEM_PROMPT_FN_CALL if self.use_tool_calling else SWE_SYSTEM_PROMPT
         if scaffold == "sweagent":
             self.system_prompt = SWEAGENT_SYSTEM_PROMPT
-        self.user_prompt_template = SWE_USER_PROMPT_FN_CALL if use_fn_calling else SWE_USER_PROMPT
+        self.user_prompt_template = SWE_USER_PROMPT_FN_CALL if self.use_tool_calling else SWE_USER_PROMPT
         if scaffold == "sweagent":
             self.user_prompt_template = SWEAGENT_USER_PROMPT
 
+        self.chat_template_parser = chat_template_parser
+        if self.use_tool_calling:
+            tools_schema = json.dumps(get_tools_for_scaffold(scaffold))
+            self.tools_prompt = self.chat_template_parser.tool_parser.get_tool_prompt(tools_schema)
+
         self._trajectory = Trajectory()
         self.reset()
 
@@ -88,7 +133,7 @@ def process_model_response(self, response: str) -> tuple[str, str]:
             - The action string in XML format
             - The processed response (may be reformatted if self.format_model_response is True)
         """
-        if self.use_fn_calling:
+        if self.use_tool_calling:
             thought, action = parse_oai_response(response)
         else:
             thought, action = parse_xml_response(response)
@@ -130,7 +175,7 @@ def update_from_env(self, observation, reward, done, info):
             self.messages.append({"role": "user", "content": observation})
         self.cur_step = Step(observation=observation)
 
-    def update_from_model(self, response: str, **kwargs):
+    def update_from_model(self, model_output: ModelOutput, **kwargs) -> Action:
         """
         Updates the agent's internal state after an environment step.
 
@@ -138,30 +183,34 @@
         outcome into the agent's learning process.
 
         Args:
-            response (str): The response from the model.
-
+            model_output (ModelOutput): The response from the model.
         Returns:
-            None
+            Action: The action to take.
         """
+        response = model_output.text
        self._trajectory.steps.append(self.cur_step)
-        if self.use_fn_calling:
-            thought, action = parse_oai_response(response)
+
+        if self.use_tool_calling:
+            content, action = parse_oai_response(model_output)
         else:
-            thought, action = parse_xml_response(response)
-        action_str = action.to_xml_string()
+            content, action = parse_xml_response(response)
+        if len(model_output.tool_calls) > 0:
+            action_str = self.chat_template_parser.tool_parser.tool_call_to_str(model_output.tool_calls[0])
+        else:
+            action_str = ""
+        logger.debug(f"update_from_model: action_str: {action_str}")
 
         assert self._trajectory.steps, "Trajectory should not be empty when update_from_model is called."
         # Update Trajectory
         cur_step = self._trajectory.steps[-1]
-        cur_step.thought = thought
-        cur_step.action = action_str
+        cur_step.reasoning = model_output.reasoning
+        cur_step.content = model_output.content
+        cur_step.text = model_output.text
+        cur_step.action = action
         cur_step.model_response = response
 
         # Update Chat Completions
-        if self.format_model_response:
-            self.messages.append({"role": "assistant", "content": f"{thought}\n\n{action_str}"})
-        else:
-            self.messages.append({"role": "assistant", "content": response})
+        self.messages.append({"role": "assistant", "content": response})
         self.step += 1
         return Action(action=cur_step.action)
 
@@ -171,12 +220,11 @@ def get_current_state(self) -> Step:
 
     def reset(self):
         self._trajectory = Trajectory()
-        self.messages = [
-            {
-                "role": "system",
-                "content": self.system_prompt,
-            }
-        ]
+        if self.use_tool_calling:
+            prompt = self.system_prompt + self.tools_prompt
+        else:
+            prompt = self.system_prompt
+        self.messages = [{"role": "system", "content": prompt}]
         self.step = 0
 
     @property
diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py
index fc16b2f9..694eedec 100644
--- a/rllm/engine/agent_execution_engine.py
+++ b/rllm/engine/agent_execution_engine.py
@@ -12,6 +12,7 @@
     convert_messages_to_tokens_and_masks,
     get_recent_assistant_user_messages,
 )
+from rllm.engine.rollout.rollout_engine import ModelOutput
 from rllm.environments.base.base_env import BaseEnv
 from rllm.environments.env_utils import (
     compute_mc_return,
@@ -74,6 +75,7 @@ def __init__(
 
         self.agent_class = agent_class
         self.agent_args = agent_args
+        self.agent_args["chat_template_parser"] = self.chat_parser
         self.env_class = env_class
         self.env_args = env_args
 
@@ -117,7 +119,7 @@ def __init__(
             disable_thinking=self.disable_thinking,
         )
 
-    async def get_model_response(self, prompt, application_id, **kwargs) -> str:
+    async def get_model_response(self, prompt, application_id, **kwargs) -> ModelOutput:
         """
         Compute model response asynchronously based on the engine type.
 
@@ -135,20 +137,18 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str:
         Raises:
             NotImplementedError: If the engine type is not supported
         """
-
         sampling_params = self.sampling_params.copy()
         sampling_params.update(kwargs)
-
         if self.engine_name == "openai":
             output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, enforce_max_prompt_length=False, **sampling_params)
-            return output.text
+            return output
         elif self.engine_name == "verl":
             meta_data = sampling_params.pop("meta_info", {})
             validate = meta_data.get("validate", False)
             output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, enforce_max_prompt_length=False, **sampling_params)
-            return output.text
-        else:
-            raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
+            return output
+
+        raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
 
     def update_envs_and_agents(self, envs, agents):
         """
@@ -227,20 +227,30 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te
                 kwargs["max_tokens"] = max_tokens
 
+            # Pass tool schemas to the rollout engine when the agent uses tool calling.
+            # Note: the same schemas may already be embedded in the system prompt by agent.reset().
+            if hasattr(agent, "get_tools_for_rollout_engine") and callable(agent.get_tools_for_rollout_engine) and hasattr(agent, "use_tool_calling") and agent.use_tool_calling:
+                tools = agent.get_tools_for_rollout_engine()
+                if tools:
+                    kwargs["tools"] = tools
+
             start_time = time.time()
-            response = await self.get_model_response(prompt_messages, application_id, **kwargs)
+            # The rollout engine now returns a ModelOutput instead of raw response text,
+            # so agent.update_from_model can consume the parsed tool calls directly
+            # instead of re-parsing ModelOutput.text.
+            model_output = await self.get_model_response(prompt_messages, application_id, **kwargs)
             delta_time = time.time() - start_time
             llm_time += delta_time
             total_time += delta_time
 
             # Update steps
             prompt_response_pair = {
                 "prompt": self.chat_parser.parse(prompt_messages, add_generation_prompt=True, is_first_msg=True),
-                "response": response,
+                "response": model_output.text,
             }
             episode_steps.append(prompt_response_pair)
 
             # Update agent with model response
-            action: Action = agent.update_from_model(response)
+            action: Action = agent.update_from_model(model_output)
             action = action.action
 
             # Take step in environment using the executor
diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py
index 9cf908b5..789b3336 100644
--- a/rllm/trainer/verl/agent_ppo_trainer.py
+++ b/rllm/trainer/verl/agent_ppo_trainer.py
@@ -14,6 +14,7 @@
 from omegaconf import OmegaConf
 
 from rllm.engine.agent_execution_engine import AsyncAgentExecutionEngine
+from rllm.parser.chat_template_parser import ChatTemplateParser
 from verl import DataProto
 from verl.protocol import pad_dataproto_to_divisor
 from verl.trainer.ppo.core_algos import agg_loss
@@ -91,6 +92,7 @@ def init_envs_and_agents(self, batch):
         env_args = batch.non_tensor_batch["extra_info"].tolist()
 
         full_agent_args = dict(self.config.rllm.agent.get("agent_args", {})) | self.agent_args
+        full_agent_args["chat_template_parser"] = ChatTemplateParser.get_parser(self.tokenizer, self.config.rllm.disable_thinking)
         base_env_args = dict(self.config.rllm.env.get("env_args", {})) | self.env_args
 
         def _create_env(i):
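# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above).
# A minimal, self-contained mock of the tool-calling flow this diff introduces:
# the rollout engine hands the agent a structured model output, and the first
# tool call is mapped onto an action instead of re-parsing raw response text.
# The real classes (ModelOutput, SWEAction, ChatTemplateParser) live in rllm /
# R2E-Gym; the stand-ins below are hypothetical and only mimic the fields that
# the new parse_oai_response / update_from_model code path actually reads.
# ---------------------------------------------------------------------------
from dataclasses import dataclass, field


@dataclass
class FakeToolCall:
    # Mirrors the two attributes read from ModelOutput.tool_calls[0].
    name: str
    arguments: dict


@dataclass
class FakeModelOutput:
    # Mirrors the ModelOutput fields used by SWEAgent.update_from_model.
    text: str
    content: str
    reasoning: str = ""
    tool_calls: list = field(default_factory=list)


def parse_model_output(output: FakeModelOutput) -> tuple[str, str, dict]:
    """Return (content, function_name, parameters), mirroring the diff's parsing logic."""
    if not output.tool_calls:
        # No tool call: fall back to an empty action, as parse_oai_response does.
        return output.content, "", {}
    call = output.tool_calls[0]
    args = call.arguments if isinstance(call.arguments, dict) else {}
    return output.content, call.name, args


if __name__ == "__main__":
    # A model response that requests a (hypothetical) bash tool call.
    output = FakeModelOutput(
        text="<raw assistant text>",
        content="Let me inspect the repository layout first.",
        tool_calls=[FakeToolCall(name="execute_bash", arguments={"cmd": "ls -la"})],
    )
    content, fn, params = parse_model_output(output)
    print(content)      # -> Let me inspect the repository layout first.
    print(fn, params)   # -> execute_bash {'cmd': 'ls -la'}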