From 8351f66d0214c19f21e77546f3f35eb4bd805141 Mon Sep 17 00:00:00 2001 From: LianShuQuan <1286152658@qq.com> Date: Sun, 26 Oct 2025 22:36:10 +0800 Subject: [PATCH 1/7] fix controlling the n_parallel_agents and the concurrent env operations --- rllm/engine/agent_execution_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index fc16b2f90..0da37c0c4 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@ -117,6 +117,7 @@ def __init__( disable_thinking=self.disable_thinking, ) + async def get_model_response(self, prompt, application_id, **kwargs) -> str: """ Compute model response asynchronously based on the engine type. From 080027bfd7bce3f3d494d9e356b0068cecbe37c7 Mon Sep 17 00:00:00 2001 From: LianShuQuan <1286152658@qq.com> Date: Mon, 27 Oct 2025 12:51:18 +0800 Subject: [PATCH 2/7] applied pre-commit, fixed unused-import --- rllm/engine/agent_execution_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index 0da37c0c4..fc16b2f90 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@ -117,7 +117,6 @@ def __init__( disable_thinking=self.disable_thinking, ) - async def get_model_response(self, prompt, application_id, **kwargs) -> str: """ Compute model response asynchronously based on the engine type. 
From 7d50d009fd0edd7e5984c24bb395a150e6712fe5 Mon Sep 17 00:00:00 2001 From: LianShuQuan <1286152658@qq.com> Date: Tue, 4 Nov 2025 16:51:16 +0800 Subject: [PATCH 3/7] tool call for SWEAgent --- rllm/agents/swe_agent.py | 131 ++++++++++++++++++-------- rllm/engine/agent_execution_engine.py | 7 ++ 2 files changed, 97 insertions(+), 41 deletions(-) diff --git a/rllm/agents/swe_agent.py b/rllm/agents/swe_agent.py index 05271d5fe..40ec89eb1 100644 --- a/rllm/agents/swe_agent.py +++ b/rllm/agents/swe_agent.py @@ -2,28 +2,69 @@ import logging import re -try: - from r2egym.agenthub.action import Action as SWEAction -except ImportError: - SWEAction = None +from r2egym.agenthub.action import Action as SWEAction from rllm.agents.agent import Action, BaseAgent, Step, Trajectory from rllm.agents.system_prompts import SWE_SYSTEM_PROMPT, SWE_SYSTEM_PROMPT_FN_CALL, SWE_USER_PROMPT, SWE_USER_PROMPT_FN_CALL, SWEAGENT_SYSTEM_PROMPT, SWEAGENT_USER_PROMPT +from rllm.engine.rollout.rollout_engine import ModelOutput +from rllm.parser.chat_template_parser import ChatTemplateParser TOKEN_WARNING_THRESHOLD = 28000 -def parse_oai_response(response): - thought = response.choices[0].message.content - if not thought: - thought = "" - try: - function_name = response.choices[0].message.tool_calls[0].function.name - parameters = json.loads(response.choices[0].message.tool_calls[0].function.arguments) - action = SWEAction(function_name, parameters) - except Exception: - action = SWEAction(function_name="", parameters={}) - return thought, action +# Mapping of scaffold types to their tool schema definitions +# These are imported directly from R2E-Gym + + +def get_tools_for_scaffold(scaffold: str = "sweagile"): + """ + Get the OpenAI function calling tools schema for a given scaffold. 
+ + Args: + scaffold: The scaffold type ("r2egym", "sweagent", or "sweagile") + + Returns: + List of tool schemas in OpenAI function calling format + """ + from r2egym.agenthub.tools import ( + execute_bash_tool, + file_editor, + finish_tool, + r2egym_bash_execute_tool, + search_tool, + str_replace_editor_tool, + submit_tool, + ) + + if scaffold == "r2egym": + return [ + file_editor, + search_tool, + r2egym_bash_execute_tool, + finish_tool, + ] + elif scaffold == "sweagent": + return [ + str_replace_editor_tool, + execute_bash_tool, + submit_tool, + ] + raise ValueError(f"Invalid scaffold: {scaffold}") + + +def parse_oai_response(response: ModelOutput) -> tuple[str, SWEAction]: + if isinstance(response, ModelOutput): + content = response.content + if len(response.tool_calls) == 0: + logger.warning(f"No tool calls found in the ModelOutput. Last 500 chars of the response: ...{response.text[-500:]} Returning empty action.") + return content, SWEAction(function_name="", parameters={}) + if not isinstance(response.tool_calls[0].arguments, dict): + logger.warning(f"Arguments is not a dict, got {type(response.tool_calls[0].arguments)}: {response.tool_calls[0].arguments}") + response.tool_calls[0].arguments = {} + action = SWEAction(function_name=response.tool_calls[0].name, parameters=response.tool_calls[0].arguments) + return content, action + else: + raise ValueError(f"Invalid response type: {type(response)}. 
Expected ChatCompletion or ModelOutput object.") def parse_xml_response(response_text: str) -> tuple[str, SWEAction]: @@ -59,18 +100,23 @@ def parse_xml_response(response_text: str) -> tuple[str, SWEAction]: class SWEAgent(BaseAgent): - def __init__(self, use_fn_calling: bool = False, format_model_response: bool = False, scaffold: str = "r2egym"): - self.use_fn_calling = use_fn_calling - self.format_model_response = format_model_response + def __init__(self, use_tool_calling: bool = True, scaffold: str = "r2egym", chat_template_parser: ChatTemplateParser = None, accumulate_reasoning: bool = False, **kwargs): + self.use_tool_calling = use_tool_calling self.scaffold = scaffold + self.accumulate_reasoning = accumulate_reasoning assert scaffold in ["r2egym", "sweagent"], f"Invalid scaffold: {scaffold}, must be one of ['r2egym', 'sweagent']" - self.system_prompt = SWE_SYSTEM_PROMPT_FN_CALL if use_fn_calling else SWE_SYSTEM_PROMPT + self.system_prompt = SWE_SYSTEM_PROMPT_FN_CALL if use_tool_calling else SWE_SYSTEM_PROMPT if scaffold == "sweagent": self.system_prompt = SWEAGENT_SYSTEM_PROMPT - self.user_prompt_template = SWE_USER_PROMPT_FN_CALL if use_fn_calling else SWE_USER_PROMPT + self.user_prompt_template = SWE_USER_PROMPT_FN_CALL if use_tool_calling else SWE_USER_PROMPT if scaffold == "sweagent": self.user_prompt_template = SWEAGENT_USER_PROMPT + self.chat_template_parser = chat_template_parser + if self.use_tool_calling: + tools_schema = json.dumps(get_tools_for_scaffold(scaffold)) + self.tools_prompt = self.chat_template_parser.tool_parser.get_tool_prompt(tools_schema) + self._trajectory = Trajectory() self.reset() @@ -88,7 +134,7 @@ def process_model_response(self, response: str) -> tuple[str, str]: - The action string in XML format - The processed response (may be reformatted if self.format_model_response is True) """ - if self.use_fn_calling: + if self.use_tool_calling: thought, action = parse_oai_response(response) else: thought, action = 
parse_xml_response(response) @@ -130,7 +176,7 @@ def update_from_env(self, observation, reward, done, info): self.messages.append({"role": "user", "content": observation}) self.cur_step = Step(observation=observation) - def update_from_model(self, response: str, **kwargs): + def update_from_model(self, model_output: ModelOutput, **kwargs) -> Action: """ Updates the agent's internal state after an environment step. @@ -138,30 +184,34 @@ def update_from_model(self, response: str, **kwargs): outcome into the agent's learning process. Args: - response (str): The response from the model. - + model_output ModelOutput: The response from the model. Returns: - None + Action: The action to take. """ + response = model_output.text self._trajectory.steps.append(self.cur_step) - if self.use_fn_calling: - thought, action = parse_oai_response(response) + + if self.use_tool_calling: + content, action = parse_oai_response(model_output) else: - thought, action = parse_xml_response(response) - action_str = action.to_xml_string() + content, action = parse_xml_response(response) + if len(model_output.tool_calls) > 0: + action_str = self.chat_template_parser.tool_parser.tool_call_to_str(model_output.tool_calls[0]) + else: + action_str = "" + logger.debug(f"update_from_model: action_str: {action_str}") assert self._trajectory.steps, "Trajectory should not be empty when update_from_model is called." 
# Update Trajectory cur_step = self._trajectory.steps[-1] - cur_step.thought = thought - cur_step.action = action_str + cur_step.reasoning = model_output.reasoning + cur_step.content = model_output.content + cur_step.text = model_output.text + cur_step.action = action cur_step.model_response = response # Update Chat Completions - if self.format_model_response: - self.messages.append({"role": "assistant", "content": f"{thought}\n\n{action_str}"}) - else: - self.messages.append({"role": "assistant", "content": response}) + self.messages.append({"role": "assistant", "content": response}) self.step += 1 return Action(action=cur_step.action) @@ -171,12 +221,11 @@ def get_current_state(self) -> Step: def reset(self): self._trajectory = Trajectory() - self.messages = [ - { - "role": "system", - "content": self.system_prompt, - } - ] + if self.use_tool_calling: + prompt = self.system_prompt + self.tools_prompt + else: + prompt = self.system_prompt + self.messages = [{"role": "system", "content": prompt}] self.step = 0 @property diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index fc16b2f90..a203d3090 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@ -227,6 +227,13 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te kwargs["max_tokens"] = max_tokens + # Add tools for rollout_engine if agent provides them (for use_tool_calling) + # May be duplicated if already added tool to sys prompt in agent.reset() + if hasattr(agent, "get_tools_for_rollout_engine") and callable(agent.get_tools_for_rollout_engine) and hasattr(agent, "use_tool_calling") and agent.use_tool_calling: + tools = agent.get_tools_for_rollout_engine() + if tools: + kwargs["tools"] = tools + start_time = time.time() response = await self.get_model_response(prompt_messages, application_id, **kwargs) delta_time = time.time() - start_time From 278ad67031e32ac5244c2e4291c7b603a84a40fd Mon Sep 17 
00:00:00 2001 From: LianShuQuan <1286152658@qq.com> Date: Tue, 4 Nov 2025 17:17:57 +0800 Subject: [PATCH 4/7] get_model_response() should return ModelOutput --- rllm/engine/agent_execution_engine.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index a203d3090..d904f6714 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@ -12,6 +12,7 @@ convert_messages_to_tokens_and_masks, get_recent_assistant_user_messages, ) +from rllm.engine.rollout.rollout_engine import ModelOutput from rllm.environments.base.base_env import BaseEnv from rllm.environments.env_utils import ( compute_mc_return, @@ -117,7 +118,7 @@ def __init__( disable_thinking=self.disable_thinking, ) - async def get_model_response(self, prompt, application_id, **kwargs) -> str: + async def get_model_response(self, prompt, application_id, **kwargs) -> ModelOutput: """ Compute model response asynchronously based on the engine type. 
@@ -135,20 +136,18 @@ async def get_model_response(self, prompt, application_id, **kwargs) -> str: Raises: NotImplementedError: If the engine type is not supported """ - sampling_params = self.sampling_params.copy() sampling_params.update(kwargs) - if self.engine_name == "openai": output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, enforce_max_prompt_length=False, **sampling_params) - return output.text + return output elif self.engine_name == "verl": meta_data = sampling_params.pop("meta_info", {}) validate = meta_data.get("validate", False) output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, enforce_max_prompt_length=False, **sampling_params) - return output.text - else: - raise NotImplementedError(f"Engine type '{self.engine_name}' not supported") + return output + + raise NotImplementedError(f"Engine type '{self.engine_name}' not supported") def update_envs_and_agents(self, envs, agents): """ @@ -235,19 +234,22 @@ async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Te kwargs["tools"] = tools start_time = time.time() - response = await self.get_model_response(prompt_messages, application_id, **kwargs) + # Return ModelOutput instead of response text, for agent.update_from_model + # response text -> ModelOutput in rollout_engines + # So no need to get ModelOutput.text and re-parse it in agent.update_from_model + model_output = await self.get_model_response(prompt_messages, application_id, **kwargs) delta_time = time.time() - start_time llm_time += delta_time total_time += delta_time # Update steps prompt_response_pair = { "prompt": self.chat_parser.parse(prompt_messages, add_generation_prompt=True, is_first_msg=True), - "response": response, + "response": model_output.text, } episode_steps.append(prompt_response_pair) # Update agent with model response - action: Action = agent.update_from_model(response) + action: Action = 
agent.update_from_model(model_output) action = action.action # Take step in environment using the executor From 3bad0644a0f4b91df157cf5eca3d0824d8c62f5c Mon Sep 17 00:00:00 2001 From: LianShuQuan <1286152658@qq.com> Date: Tue, 4 Nov 2025 17:35:21 +0800 Subject: [PATCH 5/7] 1 --- rllm/agents/swe_agent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rllm/agents/swe_agent.py b/rllm/agents/swe_agent.py index 40ec89eb1..6bbab8ad2 100644 --- a/rllm/agents/swe_agent.py +++ b/rllm/agents/swe_agent.py @@ -100,10 +100,9 @@ def parse_xml_response(response_text: str) -> tuple[str, SWEAction]: class SWEAgent(BaseAgent): - def __init__(self, use_tool_calling: bool = True, scaffold: str = "r2egym", chat_template_parser: ChatTemplateParser = None, accumulate_reasoning: bool = False, **kwargs): + def __init__(self, use_tool_calling: bool = True, scaffold: str = "r2egym", chat_template_parser: ChatTemplateParser = None, **kwargs): self.use_tool_calling = use_tool_calling self.scaffold = scaffold - self.accumulate_reasoning = accumulate_reasoning assert scaffold in ["r2egym", "sweagent"], f"Invalid scaffold: {scaffold}, must be one of ['r2egym', 'sweagent']" self.system_prompt = SWE_SYSTEM_PROMPT_FN_CALL if use_tool_calling else SWE_SYSTEM_PROMPT if scaffold == "sweagent": From 45942e5b65c18cbfa19fcc9793363781182319cd Mon Sep 17 00:00:00 2001 From: LianShuQuan <1286152658@qq.com> Date: Tue, 4 Nov 2025 17:41:53 +0800 Subject: [PATCH 6/7] use_tool_calling maybe more precise. 
Still backward-compatible with use_fn_calling outside SWEAgent --- rllm/agents/swe_agent.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rllm/agents/swe_agent.py b/rllm/agents/swe_agent.py index 6bbab8ad2..feab672d6 100644 --- a/rllm/agents/swe_agent.py +++ b/rllm/agents/swe_agent.py @@ -100,14 +100,14 @@ def parse_xml_response(response_text: str) -> tuple[str, SWEAction]: class SWEAgent(BaseAgent): - def __init__(self, use_tool_calling: bool = True, scaffold: str = "r2egym", chat_template_parser: ChatTemplateParser = None, **kwargs): - self.use_tool_calling = use_tool_calling + def __init__(self, use_fn_calling: bool = True, scaffold: str = "r2egym", chat_template_parser: ChatTemplateParser = None, **kwargs): + self.use_tool_calling = use_fn_calling self.scaffold = scaffold assert scaffold in ["r2egym", "sweagent"], f"Invalid scaffold: {scaffold}, must be one of ['r2egym', 'sweagent']" - self.system_prompt = SWE_SYSTEM_PROMPT_FN_CALL if use_tool_calling else SWE_SYSTEM_PROMPT + self.system_prompt = SWE_SYSTEM_PROMPT_FN_CALL if self.use_tool_calling else SWE_SYSTEM_PROMPT if scaffold == "sweagent": self.system_prompt = SWEAGENT_SYSTEM_PROMPT - self.user_prompt_template = SWE_USER_PROMPT_FN_CALL if use_tool_calling else SWE_USER_PROMPT + self.user_prompt_template = SWE_USER_PROMPT_FN_CALL if self.use_tool_calling else SWE_USER_PROMPT if scaffold == "sweagent": self.user_prompt_template = SWEAGENT_USER_PROMPT From 1a0b9825d96d4abe9c8bd3e2bb6b81575732645d Mon Sep 17 00:00:00 2001 From: LianShuQuan <1286152658@qq.com> Date: Tue, 4 Nov 2025 17:51:41 +0800 Subject: [PATCH 7/7] add chat_template_parser for agent --- rllm/engine/agent_execution_engine.py | 1 + rllm/trainer/verl/agent_ppo_trainer.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/rllm/engine/agent_execution_engine.py b/rllm/engine/agent_execution_engine.py index d904f6714..694eedeca 100644 --- a/rllm/engine/agent_execution_engine.py +++ b/rllm/engine/agent_execution_engine.py @@
-75,6 +75,7 @@ def __init__( self.agent_class = agent_class self.agent_args = agent_args + self.agent_args["chat_template_parser"] = self.chat_parser self.env_class = env_class self.env_args = env_args diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py index 9cf908b57..789b3336f 100644 --- a/rllm/trainer/verl/agent_ppo_trainer.py +++ b/rllm/trainer/verl/agent_ppo_trainer.py @@ -14,6 +14,7 @@ from omegaconf import OmegaConf from rllm.engine.agent_execution_engine import AsyncAgentExecutionEngine +from rllm.parser.chat_template_parser import ChatTemplateParser from verl import DataProto from verl.protocol import pad_dataproto_to_divisor from verl.trainer.ppo.core_algos import agg_loss @@ -91,6 +92,7 @@ def init_envs_and_agents(self, batch): env_args = batch.non_tensor_batch["extra_info"].tolist() full_agent_args = dict(self.config.rllm.agent.get("agent_args", {})) | self.agent_args + full_agent_args["chat_template_parser"] = ChatTemplateParser.get_parser(self.tokenizer, self.config.rllm.disable_thinking) base_env_args = dict(self.config.rllm.env.get("env_args", {})) | self.env_args def _create_env(i):