Skip to content

Commit add2579

Browse files
Conversation: make state.events.append a default callback; remove manual appends in CodeActAgent (#57)
Co-authored-by: openhands <[email protected]>
1 parent 3b8b501 commit add2579

File tree

3 files changed

+68
-12
lines changed

3 files changed

+68
-12
lines changed

openhands/core/agent/codeact_agent/codeact_agent.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ def init_state(
5252
if len(messages) == 0:
5353
# Prepare system message
5454
event = SystemPromptEvent(source="agent", system_prompt=self.system_message, tools=[t.to_openai_tool() for t in self.tools.values()])
55-
# TODO: maybe we should combine this into on_event?
56-
state.events.append(event)
5755
on_event(event)
5856

5957
def step(
@@ -99,15 +97,13 @@ def step(
9997
if action_event is None:
10098
continue
10199
action_events.append(action_event)
102-
state.events.append(action_event)
103100

104101
for action_event in action_events:
105102
self._execute_action_events(state, action_event, on_event=on_event)
106103
else:
107104
logger.info("LLM produced a message response - awaits user input")
108105
state.agent_finished = True
109106
msg_event = MessageEvent(source="agent", llm_message=message)
110-
state.events.append(msg_event)
111107
on_event(msg_event)
112108

113109
def _get_action_events(
@@ -131,7 +127,6 @@ def _get_action_events(
131127
err = f"Tool '{tool_name}' not found. Available: {list(self.tools.keys())}"
132128
logger.error(err)
133129
event = AgentErrorEvent(error=err)
134-
state.events.append(event)
135130
on_event(event)
136131
state.agent_finished = True
137132
return
@@ -142,7 +137,6 @@ def _get_action_events(
142137
except (json.JSONDecodeError, ValidationError) as e:
143138
err = f"Error validating args {tool_call.function.arguments} for tool '{tool.name}': {e}"
144139
event = AgentErrorEvent(error=err)
145-
state.events.append(event)
146140
on_event(event)
147141
return
148142

@@ -172,5 +166,4 @@ def _execute_action_events(self, state: ConversationState, action_event: ActionE
172166
# Set conversation state
173167
if tool.name == FinishTool.name:
174168
state.agent_finished = True
175-
state.events.append(obs_event)
176169
return obs_event

openhands/core/conversation/conversation.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,19 @@ class Conversation:
3030
def __init__(self, agent: "AgentBase", callbacks: list[ConversationCallbackType] | None = None, max_iteration_per_run: int = 500, env_context: EnvContext | None = None):
3131
"""Initialize the conversation."""
3232
self._visualizer = ConversationVisualizer()
33-
# Compose multiple callbacks if a list is provided
34-
self._on_event = compose_callbacks([self._visualizer.on_event] + (callbacks if callbacks else []))
35-
self.max_iteration_per_run = max_iteration_per_run
36-
3733
self.agent = agent
3834
self.state = ConversationState()
35+
36+
# Default callback: persist every event to state
37+
def _append_event(e):
38+
self.state.events.append(e)
39+
40+
# Compose callbacks; default appender runs last to keep agent-emitted event order (on_event then persist)
41+
composed_list = [self._visualizer.on_event] + (callbacks if callbacks else []) + [_append_event]
42+
self._on_event = compose_callbacks(composed_list)
43+
44+
self.max_iteration_per_run = max_iteration_per_run
45+
3946
with self.state:
4047
self.agent.init_state(self.state, on_event=self._on_event)
4148

@@ -65,7 +72,6 @@ def send_message(self, message: Message) -> None:
6572
pass
6673

6774
user_msg_event = MessageEvent(source="user", llm_message=message, activated_microagents=activated_microagents)
68-
self.state.events.append(user_msg_event)
6975
self._on_event(user_msg_event)
7076

7177
def run(self) -> None:
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
from typing import List
3+
from unittest.mock import MagicMock
4+
5+
from openhands.core.agent.base import AgentBase
6+
from openhands.core.conversation import Conversation
7+
from openhands.core.conversation.state import ConversationState
8+
from openhands.core.conversation.types import ConversationCallbackType
9+
from openhands.core.event.llm_convertible import MessageEvent, SystemPromptEvent
10+
from openhands.core.llm import Message, TextContent
11+
12+
13+
class DummyAgent(AgentBase):
14+
def __init__(self):
15+
super().__init__(llm=MagicMock(name="LLM"), tools=[])
16+
self.prompt_manager = MagicMock()
17+
18+
def init_state(self, state: ConversationState, on_event: ConversationCallbackType) -> None:
19+
event = SystemPromptEvent(source="agent", system_prompt=TextContent(text="dummy"), tools=[])
20+
on_event(event)
21+
22+
def step(self, state: ConversationState, on_event: ConversationCallbackType) -> None:
23+
on_event(MessageEvent(source="agent", llm_message=Message(role="assistant", content=[TextContent(text="ok")])) )
24+
25+
26+
def test_default_callback_appends_on_init():
27+
agent = DummyAgent()
28+
events_seen: List[str] = []
29+
30+
convo = Conversation(agent=agent, callbacks=[lambda e: events_seen.append(e.id)])
31+
32+
assert len(convo.state.events) == 1
33+
assert isinstance(convo.state.events[0], SystemPromptEvent)
34+
assert convo.state.events[0].id in events_seen
35+
36+
37+
def test_send_message_appends_once():
38+
agent = DummyAgent()
39+
seen_ids: List[str] = []
40+
41+
def user_cb(event):
42+
seen_ids.append(event.id)
43+
44+
convo = Conversation(agent=agent, callbacks=[user_cb])
45+
46+
convo.send_message(Message(role="user", content=[TextContent(text="hi")]))
47+
48+
# Now we should have two events: initial system prompt and the user message
49+
assert len(convo.state.events) == 2
50+
assert isinstance(convo.state.events[-1], MessageEvent)
51+
52+
# Ensure the user message event is appended exactly once in state
53+
last_id = convo.state.events[-1].id
54+
assert sum(1 for e in convo.state.events if e.id == last_id) == 1
55+
56+
# Ensure callback saw both events
57+
assert set(seen_ids) == {e.id for e in convo.state.events}

0 commit comments

Comments
 (0)