Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ dev = [
"pytest-cov>=6.0.0,<7.0.0",
"pytest-asyncio>=1.0.0,<1.2.0",
"pytest-xdist>=3.0.0,<4.0.0",
"ruff>=0.12.0,<0.13.0",
"ruff>=0.4.4,<0.5.0",
]
docs = [
"sphinx>=5.0.0,<6.0.0",
Expand Down
35 changes: 25 additions & 10 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR
def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
If you pass in a prompt, it will be added to the conversation history along with the structured output result.
If you don't pass in a prompt, it will use only the existing conversation history to respond.

For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
Expand Down Expand Up @@ -470,7 +470,7 @@ def execute() -> T:
async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
If you pass in a prompt, it will be added to the conversation history along with the structured output result.
If you don't pass in a prompt, it will use only the existing conversation history to respond.

For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
Expand All @@ -479,7 +479,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
Args:
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
that the agent will use when responding.
prompt: The prompt to use for the agent (will not be added to conversation history).
prompt: The prompt to use for the agent (will be added to conversation history).

Raises:
ValueError: If no conversation history or prompt is provided.
Expand All @@ -492,7 +492,13 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
if not self.messages and not prompt:
raise ValueError("No conversation history or prompt provided")

temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
# Add prompt to conversation history if provided
if prompt:
prompt_messages = self._convert_prompt_to_messages(prompt)
for message in prompt_messages:
self._append_message(message)

temp_messages: Messages = self.messages

structured_output_span.set_attributes(
{
Expand All @@ -502,24 +508,33 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
"gen_ai.operation.name": "execute_structured_output",
}
)
if self.system_prompt:
structured_output_span.add_event(
"gen_ai.system.message",
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
)
for message in temp_messages:
structured_output_span.add_event(
f"gen_ai.{message['role']}.message",
attributes={"role": message["role"], "content": serialize(message["content"])},
)
if self.system_prompt:
structured_output_span.add_event(
"gen_ai.system.message",
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
)
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
async for event in events:
if "callback" in event:
self.callback_handler(**cast(dict, event["callback"]))
structured_output_span.add_event(
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
)
return event["output"]

# Add structured output result to conversation history
result = event["output"]
assistant_message = {
"role": "assistant",
"content": [{"text": f"Structured output ({output_model.__name__}): {result.model_dump_json()}"}]
}
self._append_message(assistant_message)

return result

finally:
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
Expand Down
18 changes: 9 additions & 9 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult
"""Invoke the graph synchronously."""

def execute() -> GraphResult:
return asyncio.run(self.invoke_async(task))
return asyncio.run(self.invoke_async(task, **kwargs))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
Expand Down Expand Up @@ -420,7 +420,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G
self.node_timeout or "None",
)

await self._execute_graph()
await self._execute_graph(kwargs)

# Set final status based on execution results
if self.state.failed_nodes:
Expand Down Expand Up @@ -450,7 +450,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
# Validate Agent-specific constraints for each node
_validate_node_executor(node.executor)

async def _execute_graph(self) -> None:
async def _execute_graph(self, invocation_state: dict[str, Any]) -> None:
"""Unified execution flow with conditional routing."""
ready_nodes = list(self.entry_points)

Expand All @@ -470,7 +470,7 @@ async def _execute_graph(self) -> None:

# Execute current batch of ready nodes concurrently
tasks = [
asyncio.create_task(self._execute_node(node))
asyncio.create_task(self._execute_node(node, invocation_state))
for node in current_batch
if node not in self.state.completed_nodes
]
Expand Down Expand Up @@ -515,7 +515,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode) -> bool:
)
return False

async def _execute_node(self, node: GraphNode) -> None:
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None:
"""Execute a single node with error handling and timeout protection."""
# Reset the node's state if reset_on_revisit is enabled and it's being revisited
if self.reset_on_revisit and node in self.state.completed_nodes:
Expand All @@ -538,11 +538,11 @@ async def _execute_node(self, node: GraphNode) -> None:
if isinstance(node.executor, MultiAgentBase):
if self.node_timeout is not None:
multi_agent_result = await asyncio.wait_for(
node.executor.invoke_async(node_input),
node.executor.invoke_async(node_input, **invocation_state),
timeout=self.node_timeout,
)
else:
multi_agent_result = await node.executor.invoke_async(node_input)
multi_agent_result = await node.executor.invoke_async(node_input, **invocation_state)

# Create NodeResult with MultiAgentResult directly
node_result = NodeResult(
Expand All @@ -557,11 +557,11 @@ async def _execute_node(self, node: GraphNode) -> None:
elif isinstance(node.executor, Agent):
if self.node_timeout is not None:
agent_response = await asyncio.wait_for(
node.executor.invoke_async(node_input),
node.executor.invoke_async(node_input, **invocation_state),
timeout=self.node_timeout,
)
else:
agent_response = await node.executor.invoke_async(node_input)
agent_response = await node.executor.invoke_async(node_input, **invocation_state)

# Extract metrics from agent response
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
Expand Down
14 changes: 8 additions & 6 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult
"""Invoke the swarm synchronously."""

def execute() -> SwarmResult:
return asyncio.run(self.invoke_async(task))
return asyncio.run(self.invoke_async(task, **kwargs))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
Expand Down Expand Up @@ -272,7 +272,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S
self.execution_timeout,
)

await self._execute_swarm()
await self._execute_swarm(kwargs)
except Exception:
logger.exception("swarm execution failed")
self.state.completion_status = Status.FAILED
Expand Down Expand Up @@ -483,7 +483,7 @@ def _build_node_input(self, target_node: SwarmNode) -> str:

return context_text

async def _execute_swarm(self) -> None:
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
"""Shared execution logic used by execute_async."""
try:
# Main execution loop
Expand Down Expand Up @@ -522,7 +522,7 @@ async def _execute_swarm(self) -> None:
# TODO: Implement cancellation token to stop _execute_node from continuing
try:
await asyncio.wait_for(
self._execute_node(current_node, self.state.task),
self._execute_node(current_node, self.state.task, invocation_state),
timeout=self.node_timeout,
)

Expand Down Expand Up @@ -563,7 +563,9 @@ async def _execute_swarm(self) -> None:
f"{elapsed_time:.2f}",
)

async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -> AgentResult:
async def _execute_node(
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
) -> AgentResult:
"""Execute swarm node."""
start_time = time.time()
node_name = node.node_id
Expand All @@ -583,7 +585,7 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -
# Execute node
result = None
node.reset_executor_state()
result = await node.executor.invoke_async(node_input)
result = await node.executor.invoke_async(node_input, **invocation_state)

execution_time = round((time.time() - start_time) * 1000)

Expand Down
Loading
Loading