Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.models import Model
from ..types.tools import ToolConfig
from ..types.tools import ToolConfig, ToolResult
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .conversation_manager import (
Expand Down Expand Up @@ -182,7 +182,7 @@ def caller(**kwargs: Any) -> Any:
}

# Execute the tool
tool_result = self._agent.tool_handler.process(
events = self._agent.tool_handler.process(
tool=tool_use,
model=self._agent.model,
system_prompt=self._agent.system_prompt,
Expand All @@ -194,6 +194,7 @@ def caller(**kwargs: Any) -> Any:
agent=self._agent,
**handler_kwargs,
)
tool_result = list(events)[-1]

if record_direct_tool_call:
# Create a record of this tool execution in the message history
Expand Down Expand Up @@ -576,7 +577,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
def _record_tool_execution(
self,
tool: Dict[str, Any],
tool_result: Dict[str, Any],
tool_result: ToolResult,
user_message_override: Optional[str],
messages: List[Dict[str, Any]],
) -> None:
Expand Down
5 changes: 2 additions & 3 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time
import uuid
from functools import partial
from typing import Any, Callable, Generator, Optional, cast
from typing import Any, Callable, Generator, Optional

from ..telemetry.metrics import EventLoopMetrics, Trace
from ..telemetry.tracer import get_tracer
Expand Down Expand Up @@ -352,11 +352,10 @@ def _handle_tool_execution(
**kwargs,
)

run_tools(
yield from run_tools(
handler=tool_handler_process,
tool_uses=tool_uses,
event_loop_metrics=event_loop_metrics,
request_state=cast(Any, kwargs["request_state"]),
invalid_tool_use_ids=invalid_tool_use_ids,
tool_results=tool_results,
cycle_trace=cycle_trace,
Expand Down
20 changes: 11 additions & 9 deletions src/strands/handlers/tool_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""This module provides handlers for managing tool invocations."""

import logging
from typing import Any, List, Optional
from typing import Any, Generator, Optional, Union

from ..tools.registry import ToolRegistry
from ..types.models import Model
Expand Down Expand Up @@ -49,11 +49,11 @@ def process(
*,
model: Model,
system_prompt: Optional[str],
messages: List[Any],
messages: list[Any],
tool_config: Any,
callback_handler: Any,
**kwargs: Any,
) -> Any:
) -> Generator[Union[ToolResult, Any], None, None]:
"""Process a tool invocation.

Looks up the tool in the registry and invokes it with the provided parameters.
Expand All @@ -67,10 +67,10 @@ def process(
callback_handler: Callback for processing events as they happen.
**kwargs: Additional keyword arguments passed to the tool.

Returns:
The result of the tool invocation, or an error response if the tool fails or is not found.
Yields:
Events of the tool invocation. The final event is always the tool result.
"""
logger.debug("tool=<%s> | invoking", tool)
logger.debug("tool=<%s> | streaming", tool)
tool_use_id = tool["toolUseId"]
tool_name = tool["name"]

Expand All @@ -86,11 +86,13 @@ def process(
tool_name,
list(self.tool_registry.registry.keys()),
)
return {
yield {
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Unknown tool: {tool_name}"}],
}
return

# Add standard arguments to kwargs for Python tools
kwargs.update(
{
Expand All @@ -102,11 +104,11 @@ def process(
}
)

return tool_func.invoke(tool, **kwargs)
yield from tool_func.stream(tool, **kwargs)

except Exception as e:
logger.exception("tool_name=<%s> | failed to process tool", tool_name)
return {
yield {
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error: {str(e)}"}],
Expand Down
52 changes: 32 additions & 20 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Optional,
ParamSpec,
Expand Down Expand Up @@ -119,7 +119,7 @@ def _create_input_model(self) -> Type[BaseModel]:
Returns:
A Pydantic BaseModel class customized for the function's parameters.
"""
field_definitions: Dict[str, Any] = {}
field_definitions: dict[str, Any] = {}

for name, param in self.signature.parameters.items():
# Skip special parameters
Expand Down Expand Up @@ -179,7 +179,7 @@ def extract_metadata(self) -> ToolSpec:

return tool_spec

def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None:
"""Clean up Pydantic schema to match Strands' expected format.

Pydantic's JSON schema output includes several elements that aren't needed for Strands Agent tools and could
Expand Down Expand Up @@ -227,7 +227,7 @@ def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
if key in prop_schema:
del prop_schema[key]

def validate_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
"""Validate input data using the Pydantic model.

This method ensures that the input data meets the expected schema before it's passed to the actual function. It
Expand Down Expand Up @@ -353,12 +353,15 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
# This block is only for backwards compatability so we cast as any for now
logger.warning(
"issue=<%s> | "
"passing tool use into a function instead of using .invoke will be removed in a future release",
"passing tool use into a function instead of using .stream will be removed in a future release",
"https://github.com/strands-agents/sdk-python/pull/258",
)
tool_use = cast(Any, args[0])

return cast(R, self.invoke(tool_use, **kwargs))
events = self.stream(tool_use, **kwargs)
result = list(events)[-1]

return cast(R, result)

return self.original_function(*args, **kwargs)

Expand Down Expand Up @@ -389,7 +392,8 @@ def tool_type(self) -> str:
"""
return "function"

def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
@override
def stream(self, tool: ToolUse, *args: Any, **kwargs: Any) -> Generator[Union[ToolResult, Any], None, None]:
"""Invoke the tool with a tool use specification.

This method handles tool use invocations from a Strands Agent. It validates the input,
Expand All @@ -408,8 +412,8 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
*args: Additional positional arguments (not typically used).
**kwargs: Additional keyword arguments, may include 'agent' reference.

Returns:
A standardized tool result dictionary with status and content.
Yields:
Events of the tool invocation. The final event is always the tool result.
"""
# This is a tool use call - process accordingly
tool_use = tool
Expand All @@ -424,27 +428,35 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
if "agent" in kwargs and "agent" in self._metadata.signature.parameters:
validated_input["agent"] = kwargs.get("agent")

# User will need to piece together a tool result themselves if using a generator
if inspect.isgeneratorfunction(self.original_function):
validated_input["tool_use_id"] = tool_use_id

# We get "too few arguments here" but because that's because fof the way we're calling it
result = self.original_function(**validated_input) # type: ignore
if inspect.isgenerator(result):
yield from result
return

# FORMAT THE RESULT for Strands Agent
if isinstance(result, dict) and "status" in result and "content" in result:
# Result is already in the expected format, just add toolUseId
result["toolUseId"] = tool_use_id
return cast(ToolResult, result)
else:
# Wrap any other return value in the standard format
# Always include at least one content item for consistency
return {
"toolUseId": tool_use_id,
"status": "success",
"content": [{"text": str(result)}],
}
yield result
return

# Wrap any other return value in the standard format
# Always include at least one content item for consistency
yield {
"toolUseId": tool_use_id,
"status": "success",
"content": [{"text": str(result)}],
}

except ValueError as e:
# Special handling for validation errors
error_msg = str(e)
return {
yield {
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error: {error_msg}"}],
Expand All @@ -453,7 +465,7 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
# Return error result with exception details for any other error
error_type = type(e).__name__
error_msg = str(e)
return {
yield {
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error: {error_type} - {error_msg}"}],
Expand Down
Loading
Loading