def run_tools(
    handler: Callable[[ToolUse], ToolResult],
    tool_uses: list[ToolUse],
    event_loop_metrics: EventLoopMetrics,
    invalid_tool_use_ids: list[str],
    tool_results: list[ToolResult],
    cycle_trace: Trace,
    parent_span: Optional[trace.Span] = None,
    parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None,
) -> Generator[dict[str, Any], None, None]:
    """Execute tools either in parallel or sequentially, streaming their events.

    Tool uses whose `toolUseId` is listed in `invalid_tool_use_ids` are filtered out before execution, so no span,
    trace, or metric is recorded for them.

    Args:
        handler: Tool handler processing function.
        tool_uses: List of tool uses to execute.
        event_loop_metrics: Metrics collection object.
        invalid_tool_use_ids: List of invalid tool use IDs.
        tool_results: List to populate with tool results.
        cycle_trace: Parent trace for the current cycle.
        parent_span: Parent span for the current cycle.
        parallel_tool_executor: Optional executor for parallel processing.

    Yields:
        Events of the tool invocations. Tool results are appended to `tool_results`.

    Raises:
        Exception: Whatever `handler` raises. The tool call span is ended with the error before the exception
            propagates. In parallel mode the exception surfaces from `worker.result()` (assumes the executor's
            futures re-raise like `concurrent.futures.Future` - TODO confirm for custom executors).
    """
    # Upper bound on how long the consumer blocks on the queue before re-checking whether all workers are done.
    poll_interval_seconds = 0.01

    def handle(tool: ToolUse) -> Generator[dict[str, Any], None, ToolResult]:
        """Run one tool, yield its events, record metrics/trace, and return its result."""
        tracer = get_tracer()
        tool_call_span = tracer.start_tool_call_span(tool, parent_span)

        try:
            tool_name = tool["name"]
            tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
            tool_start_time = time.time()

            result = handler(tool)
            yield {"result": result}  # Placeholder until handler becomes a generator we can yield from

            tool_success = result.get("status") == "success"
            tool_duration = time.time() - tool_start_time
            message = Message(role="user", content=[{"toolResult": result}])
            event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
            cycle_trace.add_child(tool_trace)

        except Exception as e:
            # Close the span so it is not leaked, then let the caller see the original failure.
            if tool_call_span:
                tracer.end_span_with_error(tool_call_span, str(e), e)
            raise

        if tool_call_span:
            tracer.end_tool_call_span(tool_call_span, result)

        return result

    def work(
        tool: ToolUse,
        worker_id: int,
        worker_queue: queue.Queue,
        worker_event: threading.Event,
    ) -> ToolResult:
        """Drive `handle` on a worker thread, handing each event to the consumer and waiting for its ack."""
        events = handle(tool)

        try:
            while True:
                event = next(events)
                worker_queue.put((worker_id, event))
                worker_event.wait()
                # Re-arm the event. Without this, every wait after the first ack returns immediately and the
                # worker could finish while its later events are still queued, causing them to be dropped.
                worker_event.clear()

        except StopIteration as stop:
            return cast(ToolResult, stop.value)

    tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]

    if parallel_tool_executor:
        logger.debug(
            "tool_count=<%s>, tool_executor=<%s> | executing tools in parallel",
            len(tool_uses),
            type(parallel_tool_executor).__name__,
        )

        worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue()
        worker_events = [threading.Event() for _ in tool_uses]

        workers = [
            parallel_tool_executor.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id])
            for worker_id, tool_use in enumerate(tool_uses)
        ]
        logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))

        while not all(worker.done() for worker in workers):
            # Block briefly instead of spinning on `empty()`: a busy loop burns a core and competes with the
            # worker threads for the GIL while the tools are running.
            try:
                worker_id, event = worker_queue.get(timeout=poll_interval_seconds)
            except queue.Empty:
                continue

            yield event
            # Ack: lets the producing worker advance to its next event (or finish).
            worker_events[worker_id].set()

        tool_results.extend([worker.result() for worker in workers])

    else:
        # Sequential execution fallback
        for tool_use in tool_uses:
            result = yield from handle(tool_use)
            tool_results.append(result)
""" + def done(self) -> bool: + """Returns true if future is done executing.""" + @runtime_checkable class ParallelToolExecutorInterface(Protocol): diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index 4b2387923..f730f473d 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -54,11 +54,6 @@ def event_loop_metrics(): return strands.telemetry.metrics.EventLoopMetrics() -@pytest.fixture -def request_state(): - return {} - - @pytest.fixture def invalid_tool_use_ids(request): return request.param if hasattr(request, "param") else [] @@ -92,24 +87,22 @@ def test_run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parallel_tool_executor, ) - assert not failed + list(stream) tru_results = tool_results exp_results = [ @@ -132,24 +125,22 @@ def test_run_tools_invalid_tool( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parallel_tool_executor, ) - assert failed + list(stream) tru_results = tool_results exp_results = [] @@ -162,24 +153,22 @@ def test_run_tools_failed_tool( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, 
cycle_trace, parallel_tool_executor, ) - assert failed + list(stream) tru_results = tool_results exp_results = [ @@ -222,23 +211,21 @@ def test_run_tools_sequential( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, ): tool_results = [] - failed = strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, None, # parallel_tool_executor ) - assert failed + list(stream) tru_results = tool_results exp_results = [ @@ -311,7 +298,6 @@ def test_run_tools_creates_and_ends_span_on_success( tool_uses, mock_metrics_client, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, @@ -329,17 +315,17 @@ def test_run_tools_creates_and_ends_span_on_success( tool_results = [] # Run the tool - strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parent_span, parallel_tool_executor, ) + list(stream) # Verify span was created with the parent span mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) @@ -359,7 +345,6 @@ def test_run_tools_creates_and_ends_span_on_failure( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, @@ -377,17 +362,17 @@ def test_run_tools_creates_and_ends_span_on_failure( tool_results = [] # Run the tool - strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parent_span, parallel_tool_executor, ) + list(stream) # Verify span was created with the parent span mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], parent_span) @@ -399,96 +384,6 @@ def 
test_run_tools_creates_and_ends_span_on_failure( assert args[1]["status"] == "failed" -@unittest.mock.patch("strands.tools.executor.get_tracer") -def test_run_tools_handles_exception_in_tool_execution( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - request_state, - invalid_tool_use_ids, - cycle_trace, - parallel_tool_executor, -): - """Test that run_tools properly handles exceptions during tool execution.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - mock_get_tracer.return_value = mock_tracer - - # Make the tool handler throw an exception - exception = ValueError("Test tool execution error") - mock_handler = unittest.mock.MagicMock(side_effect=exception) - - tool_results = [] - - # Run the tool - the exception should be caught inside run_tools and not propagate - # because of the try-except block in the new implementation - failed = strands.tools.executor.run_tools( - mock_handler, - tool_uses, - event_loop_metrics, - request_state, - invalid_tool_use_ids, - tool_results, - cycle_trace, - None, - parallel_tool_executor, - ) - - # Tool execution should have failed - assert failed - - # Verify span was created - mock_tracer.start_tool_call_span.assert_called_once() - - # Verify span was ended with the error - mock_tracer.end_span_with_error.assert_called_once_with(mock_span, str(exception), exception) - - -@unittest.mock.patch("strands.tools.executor.get_tracer") -def test_run_tools_with_invalid_tool_use_id_still_creates_span( - mock_get_tracer, - tool_handler, - tool_uses, - event_loop_metrics, - request_state, - cycle_trace, - parallel_tool_executor, -): - """Test that run_tools creates a span even when the tool use ID is invalid.""" - # Setup mock tracer and span - mock_tracer = unittest.mock.MagicMock() - mock_span = unittest.mock.MagicMock() - mock_tracer.start_tool_call_span.return_value = mock_span - 
mock_get_tracer.return_value = mock_tracer - - # Mark the tool use ID as invalid - invalid_tool_use_ids = [tool_uses[0]["toolUseId"]] - - tool_results = [] - - # Run the tool - strands.tools.executor.run_tools( - tool_handler, - tool_uses, - event_loop_metrics, - request_state, - invalid_tool_use_ids, - tool_results, - cycle_trace, - None, - parallel_tool_executor, - ) - - # Verify span was created - mock_tracer.start_tool_call_span.assert_called_once_with(tool_uses[0], None) - - # Verify span was ended even though the tool wasn't executed - mock_tracer.end_tool_call_span.assert_called_once() - - @unittest.mock.patch("strands.tools.executor.get_tracer") @pytest.mark.parametrize( ("tool_uses", "invalid_tool_use_ids"), @@ -516,7 +411,6 @@ def test_run_tools_parallel_execution_with_spans( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, cycle_trace, parallel_tool_executor, @@ -535,17 +429,17 @@ def test_run_tools_parallel_execution_with_spans( tool_results = [] # Run the tools - strands.tools.executor.run_tools( + stream = strands.tools.executor.run_tools( tool_handler, tool_uses, event_loop_metrics, - request_state, invalid_tool_use_ids, tool_results, cycle_trace, parent_span, parallel_tool_executor, ) + list(stream) # Verify spans were created for both tools assert mock_tracer.start_tool_call_span.call_count == 2