Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time
import uuid
from functools import partial
from typing import Any, Generator, Optional, cast
from typing import Any, Generator, Optional

from opentelemetry import trace

Expand Down Expand Up @@ -369,11 +369,10 @@ def _handle_tool_execution(
kwargs=kwargs,
)

run_tools(
yield from run_tools(
handler=tool_handler_process,
tool_uses=tool_uses,
event_loop_metrics=event_loop_metrics,
request_state=cast(Any, kwargs["request_state"]),
invalid_tool_use_ids=invalid_tool_use_ids,
tool_results=tool_results,
cycle_trace=cycle_trace,
Expand Down
155 changes: 68 additions & 87 deletions src/strands/tools/executor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Tool execution functionality for the event loop."""

import logging
import queue
import threading
import time
from concurrent.futures import TimeoutError
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Generator, Optional, cast

from opentelemetry import trace

Expand All @@ -19,127 +20,107 @@

def run_tools(
handler: Callable[[ToolUse], ToolResult],
tool_uses: List[ToolUse],
tool_uses: list[ToolUse],
event_loop_metrics: EventLoopMetrics,
request_state: Any,
invalid_tool_use_ids: List[str],
tool_results: List[ToolResult],
invalid_tool_use_ids: list[str],
tool_results: list[ToolResult],
cycle_trace: Trace,
parent_span: Optional[trace.Span] = None,
parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None,
) -> bool:
) -> Generator[dict[str, Any], None, None]:
"""Execute tools either in parallel or sequentially.

Args:
handler: Tool handler processing function.
tool_uses: List of tool uses to execute.
event_loop_metrics: Metrics collection object.
request_state: Current request state.
invalid_tool_use_ids: List of invalid tool use IDs.
tool_results: List to populate with tool results.
cycle_trace: Parent trace for the current cycle.
parent_span: Parent span for the current cycle.
parallel_tool_executor: Optional executor for parallel processing.

Returns:
bool: True if any tool failed, False otherwise.
Yields:
Events of the tool invocations. Tool results are appended to `tool_results`.
"""

def _handle_tool_execution(tool: ToolUse) -> Tuple[bool, Optional[ToolResult]]:
result = None
tool_succeeded = False

def handle(tool: ToolUse) -> Generator[dict[str, Any], None, ToolResult]:
tracer = get_tracer()
tool_call_span = tracer.start_tool_call_span(tool, parent_span)

try:
if "toolUseId" not in tool or tool["toolUseId"] not in invalid_tool_use_ids:
tool_name = tool["name"]
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
tool_start_time = time.time()
result = handler(tool)
tool_success = result.get("status") == "success"
if tool_success:
tool_succeeded = True

tool_duration = time.time() - tool_start_time
message = Message(role="user", content=[{"toolResult": result}])
event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
cycle_trace.add_child(tool_trace)

if tool_call_span:
tracer.end_tool_call_span(tool_call_span, result)
except Exception as e:
if tool_call_span:
tracer.end_span_with_error(tool_call_span, str(e), e)

return tool_succeeded, result

any_tool_failed = False
tool_name = tool["name"]
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
tool_start_time = time.time()

result = handler(tool)
yield {"result": result}  # Placeholder until handler becomes a generator that we can yield from

tool_success = result.get("status") == "success"
tool_duration = time.time() - tool_start_time
message = Message(role="user", content=[{"toolResult": result}])
event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
cycle_trace.add_child(tool_trace)

if tool_call_span:
tracer.end_tool_call_span(tool_call_span, result)

return result

def work(
tool: ToolUse,
worker_id: int,
worker_queue: queue.Queue,
worker_event: threading.Event,
) -> ToolResult:
events = handle(tool)

while True:
try:
event = next(events)
worker_queue.put((worker_id, event))
worker_event.wait()

except StopIteration as stop:
return cast(ToolResult, stop.value)

tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]

if parallel_tool_executor:
logger.debug(
"tool_count=<%s>, tool_executor=<%s> | executing tools in parallel",
len(tool_uses),
type(parallel_tool_executor).__name__,
)
# Submit all tasks with their associated tools
future_to_tool = {
parallel_tool_executor.submit(_handle_tool_execution, tool_use): tool_use for tool_use in tool_uses
}

worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue()
worker_events = [threading.Event() for _ in range(len(tool_uses))]

workers = [
parallel_tool_executor.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id])
for worker_id, tool_use in enumerate(tool_uses)
]
logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))

# Collect results truly in parallel using the provided executor's as_completed method
completed_results = []
try:
for future in parallel_tool_executor.as_completed(future_to_tool):
try:
succeeded, result = future.result()
if result is not None:
completed_results.append(result)
if not succeeded:
any_tool_failed = True
except Exception as e:
tool = future_to_tool[future]
logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], e)
any_tool_failed = True
except TimeoutError:
logger.error("timeout_seconds=<%s> | parallel tool execution timed out", parallel_tool_executor.timeout)
# Process any completed tasks
for future in future_to_tool:
if future.done(): # type: ignore
try:
succeeded, result = future.result(timeout=0)
if result is not None:
completed_results.append(result)
except Exception as tool_e:
tool = future_to_tool[future]
logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], tool_e)
else:
# This future didn't complete within the timeout
tool = future_to_tool[future]
logger.debug("tool_name=<%s> | tool execution timed out", tool["name"])

any_tool_failed = True

# Add completed results to tool_results
tool_results.extend(completed_results)
while not all(worker.done() for worker in workers):
if not worker_queue.empty():
worker_id, event = worker_queue.get()
yield event
worker_events[worker_id].set()

tool_results.extend([worker.result() for worker in workers])

else:
# Sequential execution fallback
for tool_use in tool_uses:
succeeded, result = _handle_tool_execution(tool_use)
if result is not None:
tool_results.append(result)
if not succeeded:
any_tool_failed = True

return any_tool_failed
result = yield from handle(tool_use)
tool_results.append(result)


def validate_and_prepare_tools(
message: Message,
tool_uses: List[ToolUse],
tool_results: List[ToolResult],
invalid_tool_use_ids: List[str],
tool_uses: list[ToolUse],
tool_results: list[ToolResult],
invalid_tool_use_ids: list[str],
) -> None:
"""Validate tool uses and prepare them for execution.

Expand Down
3 changes: 3 additions & 0 deletions src/strands/types/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def result(self, timeout: Optional[int] = None) -> Any:
Any: The result of the asynchronous operation.
"""

def done(self) -> bool:
"""Return True if the future is done executing."""


@runtime_checkable
class ParallelToolExecutorInterface(Protocol):
Expand Down
Loading