Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 43 additions & 22 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ..models.model import Model
from ..session.session_manager import SessionManager
from ..telemetry.metrics import EventLoopMetrics
from ..telemetry.tracer import get_tracer
from ..telemetry.tracer import get_tracer, serialize
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
from ..types.content import ContentBlock, Message, Messages
Expand Down Expand Up @@ -445,27 +445,48 @@ async def structured_output_async(
ValueError: If no conversation history or prompt is provided.
"""
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))

try:
if not self.messages and not prompt:
raise ValueError("No conversation history or prompt provided")

# Create temporary messages array if prompt is provided
if prompt:
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
temp_messages = self.messages + [{"role": "user", "content": content}]
else:
temp_messages = self.messages

events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
async for event in events:
if "callback" in event:
self.callback_handler(**cast(dict, event["callback"]))

return event["output"]

finally:
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
with self.tracer.tracer.start_as_current_span(
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
) as structured_output_span:
try:
if not self.messages and not prompt:
raise ValueError("No conversation history or prompt provided")
# Create temporary messages array if prompt is provided
if prompt:
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
temp_messages = self.messages + [{"role": "user", "content": content}]
else:
temp_messages = self.messages

structured_output_span.set_attributes(
{
"gen_ai.system": "strands-agents",
"gen_ai.agent.name": self.name,
"gen_ai.agent.id": self.agent_id,
"gen_ai.operation.name": "execute_structured_output",
}
)
for message in temp_messages:
structured_output_span.add_event(
f"gen_ai.{message['role']}.message",
attributes={"role": message["role"], "content": serialize(message["content"])},
)
if self.system_prompt:
structured_output_span.add_event(
"gen_ai.system.message",
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
)
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
async for event in events:
if "callback" in event:
self.callback_handler(**cast(dict, event["callback"]))
structured_output_span.add_event(
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
)
return event["output"]

finally:
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
Expand Down
39 changes: 39 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,14 @@ def test_agent_callback_handler_custom_handler_used():


def test_agent_structured_output(agent, system_prompt, user, agenerator):
# Setup mock tracer and span
mock_strands_tracer = unittest.mock.MagicMock()
mock_otel_tracer = unittest.mock.MagicMock()
mock_span = unittest.mock.MagicMock()
mock_strands_tracer.tracer = mock_otel_tracer
mock_otel_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span
agent.tracer = mock_strands_tracer

agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))

prompt = "Jane Doe is 30 years old and her email is [email protected]"
Expand All @@ -999,8 +1007,34 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
)

mock_span.set_attributes.assert_called_once_with(
{
"gen_ai.system": "strands-agents",
"gen_ai.agent.name": "Strands Agents",
"gen_ai.agent.id": "default",
"gen_ai.operation.name": "execute_structured_output",
}
)

mock_span.add_event.assert_any_call(
"gen_ai.user.message",
attributes={"role": "user", "content": '[{"text": "Jane Doe is 30 years old and her email is [email protected]"}]'},
)

mock_span.add_event.assert_called_with(
"gen_ai.choice",
attributes={"message": json.dumps(user.model_dump())},
)


def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, agenerator):
# Setup mock tracer and span
mock_strands_tracer = unittest.mock.MagicMock()
mock_otel_tracer = unittest.mock.MagicMock()
mock_span = unittest.mock.MagicMock()
mock_strands_tracer.tracer = mock_otel_tracer
mock_otel_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span
agent.tracer = mock_strands_tracer
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))

prompt = [
Expand Down Expand Up @@ -1030,6 +1064,11 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a
type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt
)

mock_span.add_event.assert_called_with(
"gen_ai.choice",
attributes={"message": json.dumps(user.model_dump())},
)


@pytest.mark.asyncio
async def test_agent_structured_output_in_async_context(agent, user, agenerator):
Expand Down
Loading