From 0a869282b23325bfc904aa10aeb8a0ac91ec830b Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Tue, 26 Aug 2025 09:25:06 -0400 Subject: [PATCH 1/4] feat: Implement typed events internally Step 1/N for implementing typed-events; first just preserve the existing behaviors with no changes to the public api. A follow-up change will update how we invoke callbacks and pass invocation_state around, while this one just adds typed classes for events internally. --- src/strands/agent/agent.py | 3 +- src/strands/event_loop/event_loop.py | 31 ++- src/strands/event_loop/streaming.py | 27 +- src/strands/types/_events.py | 238 ++++++++++++++++++ .../strands/agent/hooks/test_agent_events.py | 159 ++++++++++++ tests/strands/event_loop/test_streaming.py | 9 + 6 files changed, 445 insertions(+), 22 deletions(-) create mode 100644 src/strands/types/_events.py create mode 100644 tests/strands/agent/hooks/test_agent_events.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e2aed9d2b..8233c4bfe 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -50,6 +50,7 @@ from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher +from ..types._events import InitEventLoopEvent from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException @@ -604,7 +605,7 @@ async def _run_loop( self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: - yield {"callback": {"init_event_loop": True, **invocation_state}} + yield InitEventLoopEvent(invocation_state) for message in messages: self._append_message(message) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 524ecc3e8..84d3fea78 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -25,6 +25,15 @@ from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer from ..tools._validator import validate_and_prepare_tools +from ..types._events import ( + EventLoopStopEvent, + EventLoopThrottleEvent, + ForceStopEvent, + ModelMessageEvent, + StartEvent, + StartEventLoopEvent, + ToolResultMessageEvent, +) from ..types.content import Message from ..types.exceptions import ( ContextWindowOverflowException, @@ -91,8 +100,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes) invocation_state["event_loop_cycle_trace"] = cycle_trace - yield {"callback": {"start": True}} - yield {"callback": {"start_event_loop": True}} + yield StartEvent() + yield StartEventLoopEvent() # Create tracer span for this event loop cycle tracer = get_tracer() @@ -175,7 +184,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> if isinstance(e, ModelThrottledException): if attempt + 1 == MAX_ATTEMPTS: - yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}} + yield ForceStopEvent(reason=e) raise e logger.debug( @@ -189,7 +198,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> time.sleep(current_delay) current_delay = min(current_delay * 2, MAX_DELAY) - yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}} + yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state) else: raise e @@ -201,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Add the response message to the conversation agent.messages.append(message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield {"callback": {"message": message}} + yield ModelMessageEvent(message=message) # Update metrics agent.event_loop_metrics.update_usage(usage) @@ -264,11 +273,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> tracer.end_span_with_error(cycle_span, str(e), e) # Handle any other exceptions - yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}} + yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e - yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: @@ -295,7 +304,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id) cycle_trace.add_child(recursive_trace) - yield {"callback": {"start": True}} + yield StartEvent() events = event_loop_cycle(agent=agent, invocation_state=invocation_state) async for event in events: @@ -339,7 +348,7 @@ async def _handle_tool_execution( validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] if not tool_uses: - yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) return tool_events = agent.tool_executor._execute( @@ -358,7 +367,7 @@ async def _handle_tool_execution( agent.messages.append(tool_result_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) - yield {"callback": {"message": tool_result_message}} + yield ToolResultMessageEvent(message=message) if cycle_span: tracer = get_tracer() @@ -366,7 +375,7 @@ async def _handle_tool_execution( if invocation_state["request_state"].get("stop_event_loop", False): agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) return events = recurse_event_loop(agent=agent, invocation_state=invocation_state) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 1f8c260a4..d21afe09b 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -5,6 +5,13 @@ from typing import Any, AsyncGenerator, AsyncIterable, Optional from ..models.model import Model +from ..types._events import ( + ModelStopReason, + ModelStreamChunkEvent, + ModelStreamEvent, + ToolUseStreamEvent, + TypedEvent, +) from ..types.content import ContentBlock, Message, Messages from ..types.streaming import ( ContentBlockDeltaEvent, @@ -105,7 +112,7 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: def handle_content_block_delta( event: ContentBlockDeltaEvent, state: dict[str, Any] -) -> tuple[dict[str, Any], dict[str, Any]]: +) -> tuple[dict[str, Any], ModelStreamEvent]: """Handles content block delta updates by appending text, tool input, or reasoning content to the state. Args: @@ -117,18 +124,18 @@ def handle_content_block_delta( """ delta_content = event["delta"] - callback_event = {} + typed_event: ModelStreamEvent = ModelStreamEvent({}) if "toolUse" in delta_content: if "input" not in state["current_tool_use"]: state["current_tool_use"]["input"] = "" state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] - callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} + typed_event = ToolUseStreamEvent(delta_content, state["current_tool_use"]) elif "text" in delta_content: state["text"] += delta_content["text"] - callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} + typed_event["callback"] = {"data": delta_content["text"], "delta": delta_content} elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -136,7 +143,7 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - callback_event["callback"] = { + typed_event["callback"] = { "reasoningText": delta_content["reasoningContent"]["text"], "delta": delta_content, "reasoning": True, @@ -147,13 +154,13 @@ def handle_content_block_delta( state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - callback_event["callback"] = { + typed_event["callback"] = { "reasoning_signature": delta_content["reasoningContent"]["signature"], "delta": delta_content, "reasoning": True, } - return state, callback_event + return state, typed_event def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: @@ -251,7 +258,7 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: return usage, metrics -async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[dict[str, Any], None]: +async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[TypedEvent, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: @@ -274,7 +281,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d metrics: Metrics = Metrics(latencyMs=0) async for chunk in chunks: - yield {"callback": {"event": chunk}} + yield ModelStreamChunkEvent(chunk=chunk) if "messageStart" in chunk: state["message"] = handle_message_start(chunk["messageStart"], state["message"]) elif "contentBlockStart" in chunk: @@ -291,7 +298,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d elif "redactContent" in chunk: handle_redact_content(chunk["redactContent"], state) - yield {"stop": (stop_reason, state["message"], usage, metrics)} + yield ModelStopReason(stop_reason=stop_reason, message=state["message"], usage=usage, metrics=metrics) async def stream_messages( diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py new file mode 100644 index 000000000..1bddc5877 --- /dev/null +++ b/src/strands/types/_events.py @@ -0,0 +1,238 @@ +"""event system for the Strands Agents framework. + +This module defines the event types that are emitted during agent execution, +providing a structured way to observe to different events of the event loop and +agent lifecycle. +""" + +from typing import TYPE_CHECKING, Any + +from ..telemetry import EventLoopMetrics +from .content import Message +from .event_loop import Metrics, StopReason, Usage +from .streaming import ContentBlockDelta, StreamEvent + +if TYPE_CHECKING: + pass + + +class TypedEvent(dict): + """Base class for all typed events in the agent system.""" + + def __init__(self, data: dict[str, Any] | None = None) -> None: + """Initialize the typed event with optional data. + + Args: + data: Optional dictionary of event data to initialize with + """ + super().__init__(data or {}) + + +class InitEventLoopEvent(TypedEvent): + """Event emitted at the very beginning of agent execution. + + This event is fired before any processing begins and provides access to the + initial invocation state. + + Args: + invocation_state: The invocation state passed into the request + """ + + def __init__(self, invocation_state: dict) -> None: + """Initialize the event loop initialization event.""" + super().__init__({"callback": {"init_event_loop": True, **invocation_state}}) + + +class StartEvent(TypedEvent): + """Event emitted at the start of each event loop cycle. + + !!deprecated!! + Use StartEventLoopEvent instead. + + This event events the beginning of a new processing cycle within the agent's + event loop. It's fired before model invocation and tool execution begin. + """ + + def __init__(self) -> None: + """Initialize the event loop start event.""" + super().__init__({"callback": {"start": True}}) + + +class StartEventLoopEvent(TypedEvent): + """Event emitted when the event loop cycle begins processing. + + This event is fired after StartEvent and indicates that the event loop + has begun its core processing logic, including model invocation preparation. + """ + + def __init__(self) -> None: + """Initialize the event loop processing start event.""" + super().__init__({"callback": {"start_event_loop": True}}) + + +class ModelStreamChunkEvent(TypedEvent): + """Event emitted during model response streaming for each raw chunk.""" + + def __init__(self, chunk: StreamEvent) -> None: + """Initialize with streaming delta data from the model. + + Args: + chunk: Incremental streaming data from the model response + """ + super().__init__({"callback": {"event": chunk}}) + + +class ModelStreamEvent(TypedEvent): + """Event emitted during model response streaming. + + This event is fired when the model produces streaming output during response + generation. + """ + + def __init__(self, delta_data: dict[str, Any]) -> None: + """Initialize with streaming delta data from the model. + + Args: + delta_data: Incremental streaming data from the model response + """ + super().__init__(delta_data) + + +class ToolUseStreamEvent(ModelStreamEvent): + """Event emitted during tool use input streaming.""" + + def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None: + """Initialize with delta and current tool use state.""" + super().__init__({"callback": {"delta": delta, "current_tool_use": current_tool_use}}) + + +class TextStreamEvent(ModelStreamEvent): + """Event emitted during text content streaming.""" + + def __init__(self, delta: ContentBlockDelta, text: str) -> None: + """Initialize with delta and text content.""" + super().__init__({"callback": {"data": text, "delta": delta}}) + + +class ReasoningTextStreamEvent(ModelStreamEvent): + """Event emitted during reasoning text streaming.""" + + def __init__(self, delta: ContentBlockDelta, reasoning_text: str | None) -> None: + """Initialize with delta and reasoning text.""" + super().__init__({"callback": {"reasoningText": reasoning_text, "delta": delta, "reasoning": True}}) + + +class ReasoningSignatureStreamEvent(ModelStreamEvent): + """Event emitted during reasoning signature streaming.""" + + def __init__(self, delta: ContentBlockDelta, reasoning_signature: str | None) -> None: + """Initialize with delta and reasoning signature.""" + super().__init__({"callback": {"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}}) + + +class ModelStopReason(TypedEvent): + """Event emitted during reasoning signature streaming.""" + + def __init__( + self, + stop_reason: StopReason, + message: Message, + usage: Usage, + metrics: Metrics, + ) -> None: + """Initialize with the final execution results. + + Args: + stop_reason: Why the agent execution stopped + message: Final message from the model + usage: Usage information from the model + metrics: Execution metrics and performance data + """ + super().__init__({"stop": (stop_reason, message, usage, metrics)}) + + +class EventLoopStopEvent(TypedEvent): + """Event emitted when the agent execution completes normally.""" + + def __init__( + self, + stop_reason: StopReason, + message: Message, + metrics: "EventLoopMetrics", + request_state: Any, + ) -> None: + """Initialize with the final execution results. + + Args: + stop_reason: Why the agent execution stopped + message: Final message from the model + metrics: Execution metrics and performance data + request_state: Final state of the agent execution + """ + super().__init__({"stop": (stop_reason, message, metrics, request_state)}) + + +class EventLoopThrottleEvent(TypedEvent): + """Event emitted when the event loop is throttled due to rate limiting.""" + + def __init__(self, delay: int, invocation_state: dict[str, Any]) -> None: + """Initialize with the throttle delay duration. + + Args: + delay: Delay in seconds before the next retry attempt + invocation_state: The invocation state passed into the request + """ + super().__init__({"callback": {"event_loop_throttled_delay": delay, **invocation_state}}) + + +class ModelMessageEvent(TypedEvent): + """Event emitted when the model invocation has completed. + + This event is fired whenever the model generates a response message that + gets added to the conversation history. + """ + + def __init__(self, message: Message) -> None: + """Initialize with the model-generated message. + + Args: + message: The response message from the model + """ + super().__init__({"callback": {"message": message}}) + + +class ToolResultMessageEvent(TypedEvent): + """Event emitted when tool results are formatted as a message. + + This event is fired when tool execution results are converted into a + message format to be added to the conversation history. It provides + access to the formatted message containing tool results. + """ + + def __init__(self, message: Any) -> None: + """Initialize with the model-generated message. + + Args: + message: Message containing tool results for conversation history + """ + super().__init__({"callback": {"message": message}}) + + +class ForceStopEvent(TypedEvent): + """Event emitted when the agent execution is forcibly stopped, either by a tool or by an exception.""" + + def __init__(self, reason: str | Exception) -> None: + """Initialize with the reason for forced stop. + + Args: + reason: String description or exception that caused the forced stop + """ + super().__init__( + { + "callback": { + "force_stop": True, + "force_stop_reason": str(reason), + # "force_stop_reason_exception": reason if reason and isinstance(reason, Exception) else MISSING, + } + } + ) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py new file mode 100644 index 000000000..d63dd97d4 --- /dev/null +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -0,0 +1,159 @@ +import unittest.mock +from unittest.mock import ANY, MagicMock, call + +import pytest + +import strands +from strands import Agent +from strands.agent import AgentResult +from strands.types._events import TypedEvent +from strands.types.exceptions import ModelThrottledException +from tests.fixtures.mocked_model_provider import MockedModelProvider + + +@pytest.fixture +def mock_time(): + with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock: + yield mock + + +@pytest.mark.asyncio +async def test_stream_async_e2e(alist, mock_time): + @strands.tool + def fake_tool(agent: Agent): + return "Done!" + + mock_provider = MockedModelProvider( + [ + {"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}, + {"role": "assistant", "content": [{"text": "Okay invoking tool!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"name": "fake_tool", "toolUseId": "123", "input": {}}}], + }, + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, + ] + ) + model = MagicMock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + mock_provider.stream([]), + ] + + mock_callback = unittest.mock.Mock() + agent = Agent(model=model, tools=[fake_tool], callback_handler=mock_callback) + + stream = agent.stream_async("Do the stuff", arg1=1013) + + # Base object with common properties + throttle_props = { + "agent": ANY, + "event_loop_cycle_id": ANY, + "event_loop_cycle_span": ANY, + "event_loop_cycle_trace": ANY, + "arg1": 1013, + "request_state": {}, + } + + tru_events = await alist(stream) + exp_events = [ + {"arg1": 1013, "init_event_loop": True}, + {"start": True}, + {"start_event_loop": True}, + {"event_loop_throttled_delay": 8, **throttle_props}, + {"event_loop_throttled_delay": 16, **throttle_props}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"redactContent": {"redactUserContentMessage": "BLOCKED!"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "INPUT BLOCKED!"}}}}, + { + "agent": ANY, + "arg1": 1013, + "data": "INPUT BLOCKED!", + "delta": {"text": "INPUT BLOCKED!"}, + "event_loop_cycle_id": ANY, + "event_loop_cycle_span": ANY, + "event_loop_cycle_trace": ANY, + "request_state": {}, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, + {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}}, + { + "result": AgentResult( + stop_reason="guardrail_intervened", + message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}, + metrics=ANY, + state={}, + ), + }, + ] + assert tru_events == exp_events + + exp_calls = [call(**event) for event in exp_events] + act_calls = mock_callback.call_args_list + assert act_calls == exp_calls + + # Ensure that all events coming out of the agent are *not* typed events + typed_events = [event for event in tru_events if isinstance(event, TypedEvent)] + assert typed_events == [] + + +@pytest.mark.asyncio +async def test_event_loop_cycle_text_response_throttling_early_end( + agenerator, + alist, + mock_time, +): + model = MagicMock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ModelThrottledException("ThrottlingException | ConverseStream"), + ] + + mock_callback = unittest.mock.Mock() + with pytest.raises(ModelThrottledException): + agent = Agent(model=model, callback_handler=mock_callback) + + # Because we're throwing an exception, we manually collect the items here + tru_events = [] + stream = agent.stream_async("Do the stuff", arg1=1013) + async for event in stream: + tru_events.append(event) + + # Base object with common properties + common_props = { + "agent": ANY, + "event_loop_cycle_id": ANY, + "event_loop_cycle_span": ANY, + "event_loop_cycle_trace": ANY, + "arg1": 1013, + "request_state": {}, + } + + exp_events = [ + {"init_event_loop": True, "arg1": 1013}, + {"start": True}, + {"start_event_loop": True}, + {"event_loop_throttled_delay": 8, **common_props}, + {"event_loop_throttled_delay": 16, **common_props}, + {"event_loop_throttled_delay": 32, **common_props}, + {"event_loop_throttled_delay": 64, **common_props}, + {"event_loop_throttled_delay": 128, **common_props}, + {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"}, + ] + + assert tru_events == exp_events + + exp_calls = [call(**event) for event in exp_events] + act_calls = mock_callback.call_args_list + assert act_calls == exp_calls + + # Ensure that all events coming out of the agent are *not* typed events + typed_events = [event for event in tru_events if isinstance(event, TypedEvent)] + assert typed_events == [] diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 7760c498a..fd9548dae 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -4,6 +4,7 @@ import strands import strands.event_loop +from strands.types._events import TypedEvent from strands.types.streaming import ( ContentBlockDeltaEvent, ContentBlockStartEvent, @@ -562,6 +563,10 @@ async def test_process_stream(response, exp_events, agenerator, alist): tru_events = await alist(stream) assert tru_events == exp_events + # Ensure that we're getting typed events coming out of process_stream + non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)] + assert non_typed_events == [] + @pytest.mark.asyncio async def test_stream_messages(agenerator, alist): @@ -624,3 +629,7 @@ async def test_stream_messages(agenerator, alist): None, "test prompt", ) + + # Ensure that we're getting typed events coming out of process_stream + non_typed_events = [event for event in tru_events if not isinstance(event, TypedEvent)] + assert non_typed_events == [] From 2946104f308ce5e749a5dafad86559f4a69f781a Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Tue, 26 Aug 2025 14:38:44 -0400 Subject: [PATCH 2/4] Adjust test to include all calls --- tests/strands/agent/test_agent.py | 129 ++++++++++++++++-------------- 1 file changed, 70 insertions(+), 59 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 67ea5940a..a4a8af09a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -668,62 +668,71 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): ) agent("test") - callback_handler.assert_has_calls( - [ - unittest.mock.call(init_event_loop=True), - unittest.mock.call(start=True), - unittest.mock.call(start_event_loop=True), - unittest.mock.call( - event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}} - ), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), - unittest.mock.call( - agent=agent, - current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, - delta={"toolUse": {"input": '{"value"}'}}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - request_state={}, - ), - unittest.mock.call(event={"contentBlockStop": {}}), - unittest.mock.call(event={"contentBlockStart": {"start": {}}}), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), - unittest.mock.call( - agent=agent, - delta={"reasoningContent": {"text": "value"}}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - reasoning=True, - reasoningText="value", - request_state={}, - ), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), - unittest.mock.call( - agent=agent, - delta={"reasoningContent": {"signature": "value"}}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - reasoning=True, - reasoning_signature="value", - request_state={}, - ), - unittest.mock.call(event={"contentBlockStop": {}}), - unittest.mock.call(event={"contentBlockStart": {"start": {}}}), - unittest.mock.call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), - unittest.mock.call( - agent=agent, - data="value", - delta={"text": "value"}, - event_loop_cycle_id=unittest.mock.ANY, - event_loop_cycle_span=unittest.mock.ANY, - event_loop_cycle_trace=unittest.mock.ANY, - request_state={}, - ), - unittest.mock.call(event={"contentBlockStop": {}}), - unittest.mock.call( + assert callback_handler.call_args_list == [ + unittest.mock.call(init_event_loop=True), + unittest.mock.call(start=True), + unittest.mock.call(start_event_loop=True), + unittest.mock.call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), + unittest.mock.call( + agent=agent, + current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, + delta={"toolUse": {"input": '{"value"}'}}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + request_state={}, + ), + unittest.mock.call(event={"contentBlockStop": {}}), + unittest.mock.call(event={"contentBlockStart": {"start": {}}}), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), + unittest.mock.call( + agent=agent, + delta={"reasoningContent": {"text": "value"}}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + reasoning=True, + reasoningText="value", + request_state={}, + ), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), + unittest.mock.call( + agent=agent, + delta={"reasoningContent": {"signature": "value"}}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + reasoning=True, + reasoning_signature="value", + request_state={}, + ), + unittest.mock.call(event={"contentBlockStop": {}}), + unittest.mock.call(event={"contentBlockStart": {"start": {}}}), + unittest.mock.call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), + unittest.mock.call( + agent=agent, + data="value", + delta={"text": "value"}, + event_loop_cycle_id=unittest.mock.ANY, + event_loop_cycle_span=unittest.mock.ANY, + event_loop_cycle_trace=unittest.mock.ANY, + request_state={}, + ), + unittest.mock.call(event={"contentBlockStop": {}}), + unittest.mock.call( + message={ + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, + {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, + {"text": "value"}, + ], + }, + ), + unittest.mock.call( + result=AgentResult( + stop_reason="end_turn", message={ "role": "assistant", "content": [ @@ -732,9 +741,11 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): {"text": "value"}, ], }, - ), - ], - ) + metrics=unittest.mock.ANY, + state={}, + ) + ), + ] @pytest.mark.asyncio From 6927a26b2dbd1cdc6042a200fa95a5aea91f80cc Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Tue, 26 Aug 2025 15:06:22 -0400 Subject: [PATCH 3/4] Add missing model events --- src/strands/event_loop/streaming.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index d21afe09b..7507c6d75 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -9,6 +9,9 @@ ModelStopReason, ModelStreamChunkEvent, ModelStreamEvent, + ReasoningSignatureStreamEvent, + ReasoningTextStreamEvent, + TextStreamEvent, ToolUseStreamEvent, TypedEvent, ) @@ -135,7 +138,7 @@ def handle_content_block_delta( elif "text" in delta_content: state["text"] += delta_content["text"] - typed_event["callback"] = {"data": delta_content["text"], "delta": delta_content} + typed_event = TextStreamEvent(text=delta_content["text"], delta=delta_content) elif "reasoningContent" in delta_content: if "text" in delta_content["reasoningContent"]: @@ -143,22 +146,20 @@ def handle_content_block_delta( state["reasoningText"] = "" state["reasoningText"] += delta_content["reasoningContent"]["text"] - typed_event["callback"] = { - "reasoningText": delta_content["reasoningContent"]["text"], - "delta": delta_content, - "reasoning": True, - } + typed_event = ReasoningTextStreamEvent( + reasoning_text=delta_content["reasoningContent"]["text"], + delta=delta_content, + ) elif "signature" in delta_content["reasoningContent"]: if "signature" not in state: state["signature"] = "" state["signature"] += delta_content["reasoningContent"]["signature"] - typed_event["callback"] = { - "reasoning_signature": delta_content["reasoningContent"]["signature"], - "delta": delta_content, - "reasoning": True, - } + typed_event = ReasoningSignatureStreamEvent( + reasoning_signature=delta_content["reasoningContent"]["signature"], + delta=delta_content, + ) return state, typed_event @@ -287,8 +288,8 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[T elif "contentBlockStart" in chunk: state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"]) elif "contentBlockDelta" in chunk: - state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state) - yield callback_event + state, typed_event = handle_content_block_delta(chunk["contentBlockDelta"], state) + yield typed_event elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: @@ -306,7 +307,7 @@ async def stream_messages( system_prompt: Optional[str], messages: Messages, tool_specs: list[ToolSpec], -) -> AsyncGenerator[dict[str, Any], None]: +) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. Args: From f40e7a5c41b383d03b35837af7311320abade957 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Tue, 26 Aug 2025 15:59:03 -0400 Subject: [PATCH 4/4] Fix type error --- src/strands/event_loop/event_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 84d3fea78..a166902eb 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -244,8 +244,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_start_time=cycle_start_time, invocation_state=invocation_state, ) - async for event in events: - yield event + async for typed_event in events: + yield typed_event return