diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index bb602d66b..4579ebacf 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -514,8 +514,11 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu ) events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) async for event in events: - if "callback" in event: - self.callback_handler(**cast(dict, event["callback"])) + if isinstance(event, TypedEvent): + event.prepare(invocation_state={}) + if event.is_callback_event: + self.callback_handler(**event.as_dict()) + structured_output_span.add_event( "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())} ) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 01bfc5409..ea2c817fc 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -3,10 +3,12 @@ from unittest.mock import ANY, MagicMock, call import pytest +from pydantic import BaseModel import strands from strands import Agent from strands.agent import AgentResult +from strands.models import BedrockModel from strands.types._events import TypedEvent from strands.types.exceptions import ModelThrottledException from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -440,3 +442,132 @@ async def test_event_loop_cycle_text_response_throttling_early_end( # Ensure that all events coming out of the agent are *not* typed events typed_events = [event for event in tru_events if isinstance(event, TypedEvent)] assert typed_events == [] + + +@pytest.mark.asyncio +async def test_structured_output(agenerator): + # we use bedrock here as it uses the tool implementation + model = BedrockModel() + model.stream = MagicMock() + model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": {"toolUse": {"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", "name": "Person"}}, + "contentBlockIndex": 0, + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"na'}}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'me"'}}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ': "J'}}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'ohn"'}}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ', "age": 3'}}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "1}"}}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 407, "outputTokens": 53, "totalTokens": 460}, + "metrics": {"latencyMs": 1572}, + } + }, + ] + ) + + mock_callback = unittest.mock.Mock() + agent = Agent(model=model, callback_handler=mock_callback) + + class Person(BaseModel): + name: str + age: float + + await agent.structured_output_async(Person, "John is 31") + + exp_events = [ + { + "event": { + "contentBlockStart": { + "start": {"toolUse": {"toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", "name": "Person"}}, + "contentBlockIndex": 0, + } + } + }, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 0}}}, + { + "delta": {"toolUse": {"input": ""}}, + "current_tool_use": { + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", + "name": "Person", + "input": {"name": "John", "age": 31}, + }, + }, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"na'}}, "contentBlockIndex": 0}}}, + { + "delta": {"toolUse": {"input": '{"na'}}, + "current_tool_use": { + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", + "name": "Person", + "input": {"name": "John", "age": 31}, + }, + }, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'me"'}}, "contentBlockIndex": 0}}}, + { + "delta": {"toolUse": {"input": 'me"'}}, + "current_tool_use": { + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", + "name": "Person", + "input": {"name": "John", "age": 31}, + }, + }, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ': "J'}}, "contentBlockIndex": 0}}}, + { + "delta": {"toolUse": {"input": ': "J'}}, + "current_tool_use": { + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", + "name": "Person", + "input": {"name": "John", "age": 31}, + }, + }, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'ohn"'}}, "contentBlockIndex": 0}}}, + { + "delta": {"toolUse": {"input": 'ohn"'}}, + "current_tool_use": { + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", + "name": "Person", + "input": {"name": "John", "age": 31}, + }, + }, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ', "age": 3'}}, "contentBlockIndex": 0}}}, + { + "delta": {"toolUse": {"input": ', "age": 3'}}, + "current_tool_use": { + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", + "name": "Person", + "input": {"name": "John", "age": 31}, + }, + }, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "1}"}}, "contentBlockIndex": 0}}}, + { + "delta": {"toolUse": {"input": "1}"}}, + "current_tool_use": { + "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", + "name": "Person", + "input": {"name": "John", "age": 31}, + }, + }, + {"event": {"contentBlockStop": {"contentBlockIndex": 0}}}, + {"event": {"messageStop": {"stopReason": "tool_use"}}}, + { + "event": { + "metadata": { + "usage": {"inputTokens": 407, "outputTokens": 53, "totalTokens": 460}, + "metrics": {"latencyMs": 1572}, + } + } + }, + ] + + exp_calls = [call(**event) for event in exp_events] + act_calls = mock_callback.call_args_list + assert act_calls == exp_calls