From 3ae330b649724c9c6ddb677d1e4283997c26a634 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 1 Jul 2025 15:12:53 +0000 Subject: [PATCH 1/4] executor - run tools - yield --- src/strands/event_loop/event_loop.py | 5 +- src/strands/tools/executor.py | 153 ++++++++++++--------------- src/strands/types/event_loop.py | 3 + tests/strands/tools/test_executor.py | 134 +++-------------------- 4 files changed, 86 insertions(+), 209 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 82c3ef176..f33d259bd 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -12,7 +12,7 @@ import time import uuid from functools import partial -from typing import Any, Generator, Optional, cast +from typing import Any, Generator, Optional from opentelemetry import trace @@ -369,11 +369,10 @@ def _handle_tool_execution( kwargs=kwargs, ) - run_tools( + yield from run_tools( handler=tool_handler_process, tool_uses=tool_uses, event_loop_metrics=event_loop_metrics, - request_state=cast(Any, kwargs["request_state"]), invalid_tool_use_ids=invalid_tool_use_ids, tool_results=tool_results, cycle_trace=cycle_trace, diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index c90202393..fb0fda049 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -1,9 +1,10 @@ """Tool execution functionality for the event loop.""" import logging +import queue +import threading import time -from concurrent.futures import TimeoutError -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Generator, Optional from opentelemetry import trace @@ -19,127 +20,107 @@ def run_tools( handler: Callable[[ToolUse], ToolResult], - tool_uses: List[ToolUse], + tool_uses: list[ToolUse], event_loop_metrics: EventLoopMetrics, - request_state: Any, - invalid_tool_use_ids: List[str], - tool_results: List[ToolResult], + invalid_tool_use_ids: list[str], + tool_results: list[ToolResult], cycle_trace: Trace, parent_span: Optional[trace.Span] = None, parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None, -) -> bool: +) -> Generator[dict[str, Any], None, None]: """Execute tools either in parallel or sequentially. Args: handler: Tool handler processing function. tool_uses: List of tool uses to execute. event_loop_metrics: Metrics collection object. - request_state: Current request state. invalid_tool_use_ids: List of invalid tool use IDs. tool_results: List to populate with tool results. cycle_trace: Parent trace for the current cycle. parent_span: Parent span for the current cycle. parallel_tool_executor: Optional executor for parallel processing. - Returns: - bool: True if any tool failed, False otherwise. + Yields: + Events of the tool invocations. Tool results are appended to `tool_results`. 
""" - def _handle_tool_execution(tool: ToolUse) -> Tuple[bool, Optional[ToolResult]]: - result = None - tool_succeeded = False - + def handle(tool: ToolUse) -> Generator[dict[str, ToolResult], None, None]: tracer = get_tracer() tool_call_span = tracer.start_tool_call_span(tool, parent_span) - try: - if "toolUseId" not in tool or tool["toolUseId"] not in invalid_tool_use_ids: - tool_name = tool["name"] - tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) - tool_start_time = time.time() - result = handler(tool) - tool_success = result.get("status") == "success" - if tool_success: - tool_succeeded = True - - tool_duration = time.time() - tool_start_time - message = Message(role="user", content=[{"toolResult": result}]) - event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) - cycle_trace.add_child(tool_trace) - - if tool_call_span: - tracer.end_tool_call_span(tool_call_span, result) - except Exception as e: - if tool_call_span: - tracer.end_span_with_error(tool_call_span, str(e), e) - - return tool_succeeded, result - - any_tool_failed = False + tool_name = tool["name"] + tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) + tool_start_time = time.time() + + result = handler(tool) + yield {"result": result} + + tool_success = result.get("status") == "success" + tool_duration = time.time() - tool_start_time + message = Message(role="user", content=[{"toolResult": result}]) + event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) + cycle_trace.add_child(tool_trace) + + if tool_call_span: + tracer.end_tool_call_span(tool_call_span, result) + + def work( + tool: ToolUse, + worker_id: int, + worker_queue: queue.Queue, + worker_event: threading.Event, + worker_lock: threading.Lock, + ) -> None: + for event in handle(tool): + worker_queue.put((worker_id, event)) + worker_event.wait() + + with worker_lock: + tool_results.append(event["result"]) + + tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] + if parallel_tool_executor: logger.debug( "tool_count=<%s>, tool_executor=<%s> | executing tools in parallel", len(tool_uses), type(parallel_tool_executor).__name__, ) - # Submit all tasks with their associated tools - future_to_tool = { - parallel_tool_executor.submit(_handle_tool_execution, tool_use): tool_use for tool_use in tool_uses - } + + worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue() + worker_events = [threading.Event() for _ in range(len(tool_uses))] + worker_lock = threading.Lock() + + workers = [ + parallel_tool_executor.submit( + work, tool_use, worker_id, worker_queue, worker_events[worker_id], worker_lock + ) + for worker_id, tool_use in enumerate(tool_uses) + ] logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses)) - # Collect results truly in parallel using the provided executor's as_completed method - completed_results = [] - try: - for future in parallel_tool_executor.as_completed(future_to_tool): - try: - succeeded, result = future.result() - if result is not None: - completed_results.append(result) - if not succeeded: - any_tool_failed = True - except Exception as e: - tool = future_to_tool[future] - logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], e) - any_tool_failed = True - except TimeoutError: - logger.error("timeout_seconds=<%s> | parallel tool execution timed out", parallel_tool_executor.timeout) - # 
Process any completed tasks - for future in future_to_tool: - if future.done(): # type: ignore - try: - succeeded, result = future.result(timeout=0) - if result is not None: - completed_results.append(result) - except Exception as tool_e: - tool = future_to_tool[future] - logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], tool_e) - else: - # This future didn't complete within the timeout - tool = future_to_tool[future] - logger.debug("tool_name=<%s> | tool execution timed out", tool["name"]) - - any_tool_failed = True - - # Add completed results to tool_results - tool_results.extend(completed_results) + while not all(worker.done() for worker in workers): + if not worker_queue.empty(): + worker_id, event = worker_queue.get() + if "callback" in event: + yield event + worker_events[worker_id].set() + else: # Sequential execution fallback for tool_use in tool_uses: - succeeded, result = _handle_tool_execution(tool_use) - if result is not None: - tool_results.append(result) - if not succeeded: - any_tool_failed = True + for event in handle(tool_use): + if "callback" in event: + yield event - return any_tool_failed + tool_results.append(event["result"]) def validate_and_prepare_tools( message: Message, - tool_uses: List[ToolUse], - tool_results: List[ToolResult], - invalid_tool_use_ids: List[str], + tool_uses: list[ToolUse], + tool_results: list[ToolResult], + invalid_tool_use_ids: list[str], ) -> None: """Validate tool uses and prepare them for execution. diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index bbf4df95b..08ad8dc0d 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -65,6 +65,9 @@ def result(self, timeout: Optional[int] = None) -> Any: Any: The result of the asynchronous operation. 
""" + def done(self) -> bool: + """Returns true if future is done executing.""" + @runtime_checkable class ParallelToolExecutorInterface(Protocol): diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index 4b2387923..f730f473d 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -54,11 +54,6 @@ def event_loop_metrics(): return strands.telemetry.metrics.EventLoopMetrics() -@pytest.fixture -def request_state(): - return {} - - @pytest.fixture def invalid_tool_use_ids(request): return request.param if hasattr(request, "param") else [] @@ -92,24 +87,22 @@ def test_run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parallel_tool_executor, ) - assert not failed + list(stream) tru_results = tool_results exp_results = [ @@ -132,24 +125,22 @@ def test_run_tools_invalid_tool( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parallel_tool_executor, ) - assert failed + list(stream) tru_results = tool_results exp_results = [] @@ -162,24 +153,22 @@ def test_run_tools_failed_tool( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parallel_tool_executor, ) - assert failed + list(stream) tru_results = tool_results exp_results = [ @@ -222,23 +211,21 @@ def test_run_tools_sequential( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, None, # parallel_tool_executor ) - assert failed + list(stream) tru_results = tool_results exp_results = [ @@ -311,7 +298,6 @@ def test_run_tools_creates_and_ends_span_on_success( tool_uses, mock_metrics_client, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, @@ -329,17 +315,17 @@ def test_run_tools_creates_and_ends_span_on_success( tool_results = [] # Run the tool - strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parent_span, parallel_tool_executor, ) + list(stream) # Verify span was created with the parent span mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) @@ -359,7 +345,6 @@ def test_run_tools_creates_and_ends_span_on_failure( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, @@ -377,17 +362,17 @@ def test_run_tools_creates_and_ends_span_on_failure( tool_results = [] # 
Run the tool - strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parent_span, parallel_tool_executor, ) + list(stream) # Verify span was created with the parent span mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) @@ -399,96 +384,6 @@ def test_run_tools_creates_and_ends_span_on_failure( assert args[1]["status"] == "failed" -@unittest.mock.patch("strands.tools.executor.get_tracer") -def test_run_tools_handles_exception_in_tool_execution( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - request_state, - invalid_tool_use_ids, - cycle_trace, - parallel_tool_executor, -): - """Test that run_tools properly handles exceptions during tool execution.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Make the tool handler throw an exception - exception = ValueError("Test tool execution error") - mock_handler = unittest.mock.MagicMock(side_effect=exception) - - tool_results = [] - - # Run the tool - the exception should be caught inside run_tools and not propagate - # because of the try-except block in the new implementation - failed = strands.tools.executor.run_tools( - mock_handler, - tool_uses, - event_loop_metrics, - request_state, - invalid_tool_use_ids, - tool_results, - cycle_trace, - None, - parallel_tool_executor, - ) - - # Tool execution should have failed - assert failed - - # Verify span was created - mock_tracer.start_tool_call_span.assert_called_once() - - # Verify span was ended with the error - mock_tracer.end_span_with_error.assert_called_once_with(mock_span, str(exception), exception) - - -@unittest.mock.patch("strands.tools.executor.get_tracer") -def test_run_tools_with_invalid_tool_use_id_still_creates_span( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - request_state, - cycle_trace, - parallel_tool_executor, -): - """Test that run_tools creates a span even when the tool use ID is invalid.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Mark the tool use ID as invalid - invalid_tool_use_ids = [tool_uses[0]["toolUseId"]] - - tool_results = [] - - # Run the tool - strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - request_state, - invalid_tool_use_ids, - tool_results, - cycle_trace, - None, - parallel_tool_executor, - ) - - # Verify span was created - mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], None) - - # Verify span was ended even though the tool wasn't executed - mock_tracer.end_tool_call_span.assert_called_once() - - @unittest.mock.patch("strands.tools.executor.get_tracer") @pytest.mark.parametrize( ("tool_uses", "invalid_tool_use_ids"), @@ -516,7 +411,6 @@ def test_run_tools_parallel_execution_with_spans( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, @@ -535,17 +429,17 @@ def test_run_tools_parallel_execution_with_spans( tool_results = [] # Run the tools - strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, 
event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parent_span, parallel_tool_executor, ) + list(stream) # Verify spans were created for both tools assert mock_tracer.start_tool_call_span.call_count == 2 From 8f258714f0e4a59705ec26b566c91540edbab0bd Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 1 Jul 2025 23:23:30 +0000 Subject: [PATCH 2/4] tool_call_span exists | yield events as given --- src/strands/tools/executor.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index fb0fda049..50f921c82 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -61,8 +61,7 @@ def handle(tool: ToolUse) -> Generator[dict[str, ToolResult], None, None]: event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) cycle_trace.add_child(tool_trace) - if tool_call_span: - tracer.end_tool_call_span(tool_call_span, result) + tracer.end_tool_call_span(tool_call_span, result) def work( tool: ToolUse, @@ -102,16 +101,14 @@ def work( while not all(worker.done() for worker in workers): if not worker_queue.empty(): worker_id, event = worker_queue.get() - if "callback" in event: - yield event + yield event worker_events[worker_id].set() else: # Sequential execution fallback for tool_use in tool_uses: for event in handle(tool_use): - if "callback" in event: - yield event + yield event tool_results.append(event["result"]) From a46039ee3e896904fd2aec4f3fd9396497dd015a Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 1 Jul 2025 23:25:11 +0000 Subject: [PATCH 3/4] tool_call_span None check --- src/strands/tools/executor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index 50f921c82..f7f6ffe7b 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -61,7 +61,8 @@ def handle(tool: ToolUse) -> Generator[dict[str, ToolResult], None, None]: event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) cycle_trace.add_child(tool_trace) - tracer.end_tool_call_span(tool_call_span, result) + if tool_call_span: + tracer.end_tool_call_span(tool_call_span, result) def work( tool: ToolUse, From d44b78ce5d2a16fcf94bb286e4ee06963e86f5de Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 2 Jul 2025 13:24:52 +0000 Subject: [PATCH 4/4] return tool result --- src/strands/tools/executor.py | 38 ++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/strands/tools/executor.py b/src/strands/tools/executor.py index f7f6ffe7b..912283d1f 100644 --- a/src/strands/tools/executor.py +++ b/src/strands/tools/executor.py @@ -4,7 +4,7 @@ import queue import threading import time -from typing import Any, Callable, Generator, Optional +from typing import Any, Callable, Generator, Optional, cast from opentelemetry import trace @@ -44,7 +44,7 @@ def run_tools( Events of the tool invocations. Tool results are appended to `tool_results`. 
""" - def handle(tool: ToolUse) -> Generator[dict[str, ToolResult], None, None]: + def handle(tool: ToolUse) -> Generator[dict[str, Any], None, ToolResult]: tracer = get_tracer() tool_call_span = tracer.start_tool_call_span(tool, parent_span) @@ -53,7 +53,7 @@ def handle(tool: ToolUse) -> Generator[dict[str, ToolResult], None, None]: tool_start_time = time.time() result = handler(tool) - yield {"result": result} + yield {"result": result} # Placeholder until handler becomes a generator from which we can yield from tool_success = result.get("status") == "success" tool_duration = time.time() - tool_start_time @@ -64,19 +64,24 @@ def handle(tool: ToolUse) -> Generator[dict[str, ToolResult], None, None]: if tool_call_span: tracer.end_tool_call_span(tool_call_span, result) + return result + def work( tool: ToolUse, worker_id: int, worker_queue: queue.Queue, worker_event: threading.Event, - worker_lock: threading.Lock, - ) -> None: - for event in handle(tool): - worker_queue.put((worker_id, event)) - worker_event.wait() + ) -> ToolResult: + events = handle(tool) + + while True: + try: + event = next(events) + worker_queue.put((worker_id, event)) + worker_event.wait() - with worker_lock: - tool_results.append(event["result"]) + except StopIteration as stop: + return cast(ToolResult, stop.value) tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] @@ -89,12 +94,9 @@ def work( worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue() worker_events = [threading.Event() for _ in range(len(tool_uses))] - worker_lock = threading.Lock() workers = [ - parallel_tool_executor.submit( - work, tool_use, worker_id, worker_queue, worker_events[worker_id], worker_lock - ) + parallel_tool_executor.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id]) for worker_id, tool_use in enumerate(tool_uses) ] logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses)) @@ -105,13 +107,13 @@ def work( yield event worker_events[worker_id].set() + tool_results.extend([worker.result() for worker in workers]) + else: # Sequential execution fallback for tool_use in tool_uses: - for event in handle(tool_use): - yield event - - tool_results.append(event["result"]) + result = yield from handle(tool_use) + tool_results.append(result) def validate_and_prepare_tools(