From 418622fca47c5ec94bb43242e61d9fe821ff27a7 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 1 Jul 2025 13:57:05 +0000 Subject: [PATCH] stop passing around callback handler --- src/strands/agent/agent.py | 18 ++++------ src/strands/event_loop/event_loop.py | 12 +------ src/strands/handlers/tool_handler.py | 3 -- src/strands/models/mistral.py | 7 ++-- src/strands/types/tools.py | 2 -- tests/strands/agent/test_agent.py | 13 ++----- tests/strands/event_loop/test_event_loop.py | 40 --------------------- tests/strands/handlers/test_tool_handler.py | 2 -- 8 files changed, 13 insertions(+), 84 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 2860fb626..7e809de1e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -133,7 +133,6 @@ def caller( system_prompt=self._agent.system_prompt, messages=self._agent.messages, tool_config=self._agent.tool_config, - callback_handler=self._agent.callback_handler, kwargs=kwargs, ) @@ -359,7 +358,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: self._start_agent_trace_span(prompt) try: - events = self._run_loop(callback_handler, prompt, kwargs) + events = self._run_loop(prompt, kwargs) for event in events: if "callback" in event: callback_handler(**event["callback"]) @@ -441,7 +440,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: self._start_agent_trace_span(prompt) try: - events = self._run_loop(callback_handler, prompt, kwargs) + events = self._run_loop(prompt, kwargs) for event in events: if "callback" in event: callback_handler(**event["callback"]) @@ -456,9 +455,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: self._end_agent_trace_span(error=e) raise - def _run_loop( - self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any] - ) -> Generator[dict[str, Any], None, None]: + def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]: """Execute the agent's event loop with the given prompt and parameters.""" try: # Extract key parameters @@ -470,14 +467,12 @@ def _run_loop( self.messages.append(new_message) # Execute the event loop cycle with retry logic for context limits - yield from self._execute_event_loop_cycle(callback_handler, kwargs) + yield from self._execute_event_loop_cycle(kwargs) finally: self.conversation_manager.apply_management(self) - def _execute_event_loop_cycle( - self, callback_handler: Callable[..., Any], kwargs: dict[str, Any] - ) -> Generator[dict[str, Any], None, None]: + def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements @@ -497,7 +492,6 @@ def _execute_event_loop_cycle( system_prompt=self.system_prompt, messages=self.messages, # will be modified by event_loop_cycle tool_config=self.tool_config, - callback_handler=callback_handler, tool_handler=self.tool_handler, tool_execution_handler=self.thread_pool_wrapper, event_loop_metrics=self.event_loop_metrics, @@ -508,7 +502,7 @@ def _execute_event_loop_cycle( except ContextWindowOverflowException as e: # Try reducing the context size and retrying self.conversation_manager.reduce_context(self, e=e) - yield from self._execute_event_loop_cycle(callback_handler, kwargs) + yield from self._execute_event_loop_cycle(kwargs) def _record_tool_execution( self, diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index bb45358a0..82c3ef176 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -12,7 +12,7 @@ import time import uuid from functools import partial -from typing import Any, Callable, Generator, Optional, cast +from typing import Any, Generator, Optional, cast from opentelemetry import trace @@ -40,7 +40,6 @@ def event_loop_cycle( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Callable[..., Any], tool_handler: Optional[ToolHandler], tool_execution_handler: Optional[ParallelToolExecutorInterface], event_loop_metrics: EventLoopMetrics, @@ -65,7 +64,6 @@ def event_loop_cycle( system_prompt: System prompt instructions for the model. messages: Conversation history messages. tool_config: Configuration for available tools. - callback_handler: Callback for processing events as they happen. tool_handler: Handler for executing tools. tool_execution_handler: Optional handler for parallel tool execution. event_loop_metrics: Metrics tracking object for the event loop. @@ -212,7 +210,6 @@ def event_loop_cycle( messages, tool_config, tool_handler, - callback_handler, tool_execution_handler, event_loop_metrics, event_loop_parent_span, @@ -258,7 +255,6 @@ def recurse_event_loop( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Callable[..., Any], tool_handler: Optional[ToolHandler], tool_execution_handler: Optional[ParallelToolExecutorInterface], event_loop_metrics: EventLoopMetrics, @@ -274,7 +270,6 @@ def recurse_event_loop( system_prompt: System prompt instructions for the model messages: Conversation history messages tool_config: Configuration for available tools - callback_handler: Callback for processing events as they happen tool_handler: Handler for tool execution tool_execution_handler: Optional handler for parallel tool execution. event_loop_metrics: Metrics tracking object for the event loop. @@ -302,7 +297,6 @@ def recurse_event_loop( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=event_loop_metrics, @@ -321,7 +315,6 @@ def _handle_tool_execution( messages: Messages, tool_config: ToolConfig, tool_handler: ToolHandler, - callback_handler: Callable[..., Any], tool_execution_handler: Optional[ParallelToolExecutorInterface], event_loop_metrics: EventLoopMetrics, event_loop_parent_span: Optional[trace.Span], @@ -345,7 +338,6 @@ def _handle_tool_execution( messages (Messages): The conversation history messages. tool_config (ToolConfig): Configuration for available tools. tool_handler (ToolHandler): Handler for tool execution. - callback_handler (Callable[..., Any]): Callback for processing events as they happen. tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution. event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop. event_loop_parent_span (Any): Span for the parent of this event loop. @@ -374,7 +366,6 @@ def _handle_tool_execution( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, kwargs=kwargs, ) @@ -415,7 +406,6 @@ def _handle_tool_execution( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=event_loop_metrics, diff --git a/src/strands/handlers/tool_handler.py b/src/strands/handlers/tool_handler.py index 21bd6c4fc..9d96202b6 100644 --- a/src/strands/handlers/tool_handler.py +++ b/src/strands/handlers/tool_handler.py @@ -34,7 +34,6 @@ def process( system_prompt: Optional[str], messages: Messages, tool_config: ToolConfig, - callback_handler: Any, kwargs: dict[str, Any], ) -> Any: """Process a tool invocation. @@ -47,7 +46,6 @@ def process( system_prompt: The system prompt for the agent. messages: The conversation history. tool_config: Configuration for the tool. - callback_handler: Callback for processing events as they happen. kwargs: Additional keyword arguments passed to the tool. Returns: @@ -81,7 +79,6 @@ def process( "system_prompt": system_prompt, "messages": messages, "tool_config": tool_config, - "callback_handler": callback_handler, } ) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 2726dd348..3d44cbe23 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -6,7 +6,7 @@ import base64 import json import logging -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union from mistralai import Mistral from pydantic import BaseModel @@ -471,14 +471,15 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: @override def structured_output( - self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None + self, + output_model: Type[T], + prompt: Messages, ) -> Generator[dict[str, Union[T, Any]], None, None]: """Get structured output from the model. Args: output_model: The output model to use for the agent. prompt: The prompt messages to use for the agent. - callback_handler: Optional callback handler for processing events. Returns: An instance of the output model with the generated data. diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index ab4b7ca2f..aff22f157 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -249,7 +249,6 @@ def process( system_prompt: Optional[str], messages: "Messages", tool_config: ToolConfig, - callback_handler: Any, kwargs: dict[str, Any], ) -> ToolResult: """Process a tool use request and execute the tool. @@ -260,7 +259,6 @@ def process( model: The model being used for the conversation. system_prompt: The system prompt for the conversation. tool_config: The tool configuration for the current session. - callback_handler: Callback for processing events as they happen. kwargs: Additional context-specific arguments. Returns: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 7100b7c82..5f619c8da 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -795,7 +795,6 @@ def function(system_prompt: str) -> str: system_prompt="You are a helpful assistant.", messages=unittest.mock.ANY, tool_config=unittest.mock.ANY, - callback_handler=unittest.mock.ANY, kwargs={"system_prompt": "tool prompt"}, ) @@ -1075,18 +1074,10 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac mock_tracer.start_agent_span.return_value = mock_span mock_get_tracer.return_value = mock_tracer - # Define the side effect to simulate callback handler being called multiple times - def call_callback_handler(*args, **kwargs): - # Extract the callback handler from kwargs - callback_handler = kwargs.get("callback_handler") - # Call the callback handler with different data values - callback_handler(data="First chunk") - callback_handler(data="Second chunk") - callback_handler(data="Final chunk", complete=True) - # Return expected values from event_loop_cycle + def test_event_loop(*args, **kwargs): yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})} - mock_event_loop_cycle.side_effect = call_callback_handler + mock_event_loop_cycle.side_effect = test_event_loop # Create agent and make a call agent = Agent(model=mock_model) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 46884c64e..0e0b0b682 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -38,11 +38,6 @@ def tool_config(): return {"tools": [{"toolSpec": {"name": "tool_for_testing"}}], "toolChoice": {"auto": {}}} -@pytest.fixture -def callback_handler(): - return unittest.mock.Mock() - - @pytest.fixture def tool_registry(): return ToolRegistry() @@ -111,7 +106,6 @@ def test_event_loop_cycle_text_response( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, ): @@ -125,7 +119,6 @@ def test_event_loop_cycle_text_response( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -148,7 +141,6 @@ def test_event_loop_cycle_text_response_throttling( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, ): @@ -165,7 +157,6 @@ def test_event_loop_cycle_text_response_throttling( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -190,7 +181,6 @@ def test_event_loop_cycle_exponential_backoff( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, ): @@ -211,7 +201,6 @@ def test_event_loop_cycle_exponential_backoff( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -238,7 +227,6 @@ def test_event_loop_cycle_text_response_throttling_exceeded( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, ): @@ -257,7 +245,6 @@ def test_event_loop_cycle_text_response_throttling_exceeded( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -282,7 +269,6 @@ def test_event_loop_cycle_text_response_error( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, ): @@ -294,7 +280,6 @@ def test_event_loop_cycle_text_response_error( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -309,7 +294,6 @@ def test_event_loop_cycle_tool_result( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, tool_stream, @@ -327,7 +311,6 @@ def test_event_loop_cycle_tool_result( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -382,7 +365,6 @@ def test_event_loop_cycle_tool_result_error( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, tool_stream, @@ -395,7 +377,6 @@ def test_event_loop_cycle_tool_result_error( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -410,7 +391,6 @@ def test_event_loop_cycle_tool_result_no_tool_handler( system_prompt, messages, tool_config, - callback_handler, tool_execution_handler, tool_stream, ): @@ -422,7 +402,6 @@ def test_event_loop_cycle_tool_result_no_tool_handler( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=None, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -436,7 +415,6 @@ def test_event_loop_cycle_tool_result_no_tool_config( model, system_prompt, messages, - callback_handler, tool_handler, tool_execution_handler, tool_stream, @@ -449,7 +427,6 @@ def test_event_loop_cycle_tool_result_no_tool_config( system_prompt=system_prompt, messages=messages, tool_config=None, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -464,7 +441,6 @@ def test_event_loop_cycle_stop( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, tool, @@ -491,7 +467,6 @@ def test_event_loop_cycle_stop( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -524,7 +499,6 @@ def test_cycle_exception( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, tool_stream, @@ -540,7 +514,6 @@ def test_cycle_exception( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -560,7 +533,6 @@ def test_event_loop_cycle_creates_spans( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, mock_tracer, @@ -583,7 +555,6 @@ def test_event_loop_cycle_creates_spans( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -607,7 +578,6 @@ def test_event_loop_tracing_with_model_error( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, mock_tracer, @@ -629,7 +599,6 @@ def test_event_loop_tracing_with_model_error( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -649,7 +618,6 @@ def test_event_loop_tracing_with_tool_execution( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, tool_stream, @@ -677,7 +645,6 @@ def test_event_loop_tracing_with_tool_execution( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -699,7 +666,6 @@ def test_event_loop_tracing_with_throttling_exception( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, mock_tracer, @@ -727,7 +693,6 @@ def test_event_loop_tracing_with_throttling_exception( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -750,7 +715,6 @@ def test_event_loop_cycle_with_parent_span( system_prompt, messages, tool_config, - callback_handler, tool_handler, tool_execution_handler, mock_tracer, @@ -772,7 +736,6 @@ def test_event_loop_cycle_with_parent_span( system_prompt=system_prompt, messages=messages, tool_config=tool_config, - callback_handler=callback_handler, tool_handler=tool_handler, tool_execution_handler=tool_execution_handler, event_loop_metrics=EventLoopMetrics(), @@ -794,7 +757,6 @@ def test_request_state_initialization(): system_prompt=MagicMock(), messages=MagicMock(), tool_config=MagicMock(), - callback_handler=MagicMock(), tool_handler=MagicMock(), tool_execution_handler=MagicMock(), event_loop_metrics=EventLoopMetrics(), @@ -814,7 +776,6 @@ def test_request_state_initialization(): system_prompt=MagicMock(), messages=MagicMock(), tool_config=MagicMock(), - callback_handler=MagicMock(), tool_handler=MagicMock(), tool_execution_handler=MagicMock(), event_loop_metrics=EventLoopMetrics(), @@ -855,7 +816,6 @@ def test_prepare_next_cycle_in_tool_execution(model, tool_stream): system_prompt=MagicMock(), messages=MagicMock(), tool_config=MagicMock(), - callback_handler=MagicMock(), tool_handler=MagicMock(), tool_execution_handler=MagicMock(), event_loop_metrics=EventLoopMetrics(), diff --git a/tests/strands/handlers/test_tool_handler.py b/tests/strands/handlers/test_tool_handler.py index 3e263cd9e..4ae59f432 100644 --- a/tests/strands/handlers/test_tool_handler.py +++ b/tests/strands/handlers/test_tool_handler.py @@ -47,7 +47,6 @@ def test_process(tool_handler, tool_use_identity): system_prompt="p1", messages=[], tool_config={}, - callback_handler=unittest.mock.Mock(), kwargs={}, ) exp_result = {"toolUseId": "identity", "status": "success", "content": [{"text": "1"}]} @@ -62,7 +61,6 @@ def test_process_missing_tool(tool_handler): system_prompt="p1", messages=[], tool_config={}, - callback_handler=unittest.mock.Mock(), kwargs={}, ) exp_result = {