69 changes: 55 additions & 14 deletions src/strands/agent/agent.py
@@ -29,7 +29,7 @@
)

from opentelemetry import trace as trace_api
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from .. import _identifier
from ..event_loop.event_loop import event_loop_cycle
@@ -445,7 +445,7 @@ async def invoke_async(

return cast(AgentResult, event["result"])

def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T:
def structured_output(self, output_model: Type[T], prompt: AgentInput = None, max_retries: int = 0) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -462,19 +462,23 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) ->
- list[ContentBlock]: Multi-modal content blocks
- list[Message]: Complete messages with roles
- None: Use existing conversation history
max_retries: Maximum number of self-healing retry attempts (additional LLM calls)
if validation fails (default: 0).

Raises:
ValueError: If no conversation history or prompt is provided.
"""

def execute() -> T:
return asyncio.run(self.structured_output_async(output_model, prompt))
return asyncio.run(self.structured_output_async(output_model, prompt, max_retries))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()

async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
async def structured_output_async(
self, output_model: Type[T], prompt: AgentInput = None, max_retries: int = 0
) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -487,6 +491,8 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
that the agent will use when responding.
prompt: The prompt to use for the agent (will not be added to conversation history).
max_retries: Maximum number of self-healing retry attempts (additional LLM calls)
if validation fails (default: 0).

Raises:
ValueError: If no conversation history or prompt is provided.
@@ -507,6 +513,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
"gen_ai.agent.name": self.name,
"gen_ai.agent.id": self.agent_id,
"gen_ai.operation.name": "execute_structured_output",
"gen_ai.structured_output.max_retries": max_retries,
}
)
if self.system_prompt:
@@ -519,17 +526,51 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
f"gen_ai.{message['role']}.message",
attributes={"role": message["role"], "content": serialize(message["content"])},
)
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
async for event in events:
if isinstance(event, TypedEvent):
event.prepare(invocation_state={})
if event.is_callback_event:
self.callback_handler(**event.as_dict())

structured_output_span.add_event(
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
)
return event["output"]
last_exception = None
for attempt in range(max_retries + 1):
try:
if attempt > 0:
structured_output_span.add_event(
"gen_ai.structured_output.retry",
attributes={"attempt": attempt, "error": str(last_exception)},
)

events = self.model.structured_output(
output_model, temp_messages, system_prompt=self.system_prompt
)
async for event in events:
if isinstance(event, TypedEvent):
event.prepare(invocation_state={})
if event.is_callback_event:
self.callback_handler(**event.as_dict())

structured_output_span.add_event(
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
)
return event["output"]

except (ValidationError, ValueError) as e:
last_exception = e
if attempt < max_retries:
temp_messages = temp_messages + [
{
"role": "user",
"content": [
{
"text": (
"Try again to generate a structured output. "
f"Your previous attempt failed with this exception: {e}"
)
}
],
}
]
else:
raise

# Should never reach here, but satisfy type checker
raise RuntimeError("Structured output failed after all retry attempts")

finally:
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
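
For context, a minimal usage sketch of the new parameter from a caller's perspective. The `Person` model, the prompt text, the default `Agent()` construction, and the `from strands import Agent` top-level import are illustrative assumptions, not part of this PR:

```python
from pydantic import BaseModel

from strands import Agent  # assumed top-level import for the Strands Agents SDK


class Person(BaseModel):
    """Illustrative output schema (hypothetical, not from this PR)."""

    name: str
    age: int


agent = Agent()

# With max_retries=2, up to two additional LLM calls are made if a response
# fails Pydantic validation; each retry feeds the validation error back to
# the model as a follow-up user message.
person = agent.structured_output(Person, "Jane Doe is 30 years old.", max_retries=2)
print(person.name, person.age)
```

Per the diff, `structured_output_async` accepts the same `max_retries` argument and behaves identically for async callers.
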
71 changes: 71 additions & 0 deletions tests/strands/agent/test_agent.py
@@ -1014,6 +1014,7 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
"gen_ai.agent.name": "Strands Agents",
"gen_ai.agent.id": "default",
"gen_ai.operation.name": "execute_structured_output",
"gen_ai.structured_output.max_retries": 0,
}
)

@@ -1143,6 +1144,76 @@ async def test_agent_structured_output_async(agent, system_prompt, user, agenera
)


def test_agent_structured_output_with_retry_on_validation_error(agent, system_prompt, user, agenerator):
"""Test that structured_output retries on ValidationError."""
from pydantic import ValidationError

# First call raises ValidationError, second call succeeds
call_count = 0

async def mock_structured_output(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ValidationError.from_exception_data("test", [])
else:
async for event in agenerator([{"output": user}]):
yield event

agent.model.structured_output = mock_structured_output

prompt = "Jane Doe is 30 years old and her email is [email protected]"

# Call with max_retries=1
tru_result = agent.structured_output(type(user), prompt, max_retries=1)
exp_result = user
assert tru_result == exp_result
assert call_count == 2 # Should have been called twice


def test_agent_structured_output_with_retry_on_value_error(agent, system_prompt, user, agenerator):
"""Test that structured_output retries on ValueError."""
# First call raises ValueError, second call succeeds
call_count = 0

async def mock_structured_output(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ValueError("No valid tool use found")
else:
async for event in agenerator([{"output": user}]):
yield event

agent.model.structured_output = mock_structured_output

prompt = "Jane Doe is 30 years old and her email is [email protected]"

# Call with max_retries=1
tru_result = agent.structured_output(type(user), prompt, max_retries=1)
exp_result = user
assert tru_result == exp_result
assert call_count == 2 # Should have been called twice


def test_agent_structured_output_retry_exhausted(agent, system_prompt, user):
"""Test that structured_output raises exception after exhausting retries."""
from pydantic import ValidationError

# Always raise ValidationError
async def mock_structured_output(*args, **kwargs):
raise ValidationError.from_exception_data("test", [])
yield # Make it a generator

agent.model.structured_output = mock_structured_output

prompt = "Jane Doe is 30 years old and her email is [email protected]"

# Should raise after max_retries attempts
with pytest.raises(ValidationError):
agent.structured_output(type(user), prompt, max_retries=2)
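
For reference, a standalone sketch (plain Pydantic, no Strands code) of the kind of `ValidationError` the retry path catches; the `Person` model and the bad payload are hypothetical:

```python
from pydantic import BaseModel, ValidationError


class Person(BaseModel):
    name: str
    age: int


try:
    # "thirty" cannot be coerced to int, so Pydantic raises ValidationError.
    Person.model_validate({"name": "Jane Doe", "age": "thirty"})
except ValidationError as e:
    # str(e) is the text the retry loop appends to the follow-up user message:
    # "Your previous attempt failed with this exception: ..."
    print(str(e))
```
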


@pytest.mark.asyncio
async def test_stream_async_returns_all_events(mock_event_loop_cycle, alist):
agent = Agent()