Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def caller(
system_prompt=self._agent.system_prompt,
messages=self._agent.messages,
tool_config=self._agent.tool_config,
callback_handler=self._agent.callback_handler,
kwargs=kwargs,
)

Expand Down Expand Up @@ -359,7 +358,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
self._start_agent_trace_span(prompt)

try:
events = self._run_loop(callback_handler, prompt, kwargs)
events = self._run_loop(prompt, kwargs)
for event in events:
if "callback" in event:
callback_handler(**event["callback"])
Expand Down Expand Up @@ -441,7 +440,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
self._start_agent_trace_span(prompt)

try:
events = self._run_loop(callback_handler, prompt, kwargs)
events = self._run_loop(prompt, kwargs)
for event in events:
if "callback" in event:
callback_handler(**event["callback"])
Expand All @@ -456,9 +455,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
self._end_agent_trace_span(error=e)
raise

def _run_loop(
self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any]
) -> Generator[dict[str, Any], None, None]:
def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
"""Execute the agent's event loop with the given prompt and parameters."""
try:
# Extract key parameters
Expand All @@ -470,14 +467,12 @@ def _run_loop(
self.messages.append(new_message)

# Execute the event loop cycle with retry logic for context limits
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
yield from self._execute_event_loop_cycle(kwargs)

finally:
self.conversation_manager.apply_management(self)

def _execute_event_loop_cycle(
self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]
) -> Generator[dict[str, Any], None, None]:
def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
"""Execute the event loop cycle with retry logic for context window limits.

This internal method handles the execution of the event loop cycle and implements
Expand All @@ -497,7 +492,6 @@ def _execute_event_loop_cycle(
system_prompt=self.system_prompt,
messages=self.messages, # will be modified by event_loop_cycle
tool_config=self.tool_config,
callback_handler=callback_handler,
tool_handler=self.tool_handler,
tool_execution_handler=self.thread_pool_wrapper,
event_loop_metrics=self.event_loop_metrics,
Expand All @@ -508,7 +502,7 @@ def _execute_event_loop_cycle(
except ContextWindowOverflowException as e:
# Try reducing the context size and retrying
self.conversation_manager.reduce_context(self, e=e)
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
yield from self._execute_event_loop_cycle(kwargs)

def _record_tool_execution(
self,
Expand Down
12 changes: 1 addition & 11 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time
import uuid
from functools import partial
from typing import Any, Callable, Generator, Optional, cast
from typing import Any, Generator, Optional, cast

from opentelemetry import trace

Expand Down Expand Up @@ -40,7 +40,6 @@ def event_loop_cycle(
system_prompt: Optional[str],
messages: Messages,
tool_config: Optional[ToolConfig],
callback_handler: Callable[..., Any],
tool_handler: Optional[ToolHandler],
tool_execution_handler: Optional[ParallelToolExecutorInterface],
event_loop_metrics: EventLoopMetrics,
Expand All @@ -65,7 +64,6 @@ def event_loop_cycle(
system_prompt: System prompt instructions for the model.
messages: Conversation history messages.
tool_config: Configuration for available tools.
callback_handler: Callback for processing events as they happen.
tool_handler: Handler for executing tools.
tool_execution_handler: Optional handler for parallel tool execution.
event_loop_metrics: Metrics tracking object for the event loop.
Expand Down Expand Up @@ -212,7 +210,6 @@ def event_loop_cycle(
messages,
tool_config,
tool_handler,
callback_handler,
tool_execution_handler,
event_loop_metrics,
event_loop_parent_span,
Expand Down Expand Up @@ -258,7 +255,6 @@ def recurse_event_loop(
system_prompt: Optional[str],
messages: Messages,
tool_config: Optional[ToolConfig],
callback_handler: Callable[..., Any],
tool_handler: Optional[ToolHandler],
tool_execution_handler: Optional[ParallelToolExecutorInterface],
event_loop_metrics: EventLoopMetrics,
Expand All @@ -274,7 +270,6 @@ def recurse_event_loop(
system_prompt: System prompt instructions for the model
messages: Conversation history messages
tool_config: Configuration for available tools
callback_handler: Callback for processing events as they happen
tool_handler: Handler for tool execution
tool_execution_handler: Optional handler for parallel tool execution.
event_loop_metrics: Metrics tracking object for the event loop.
Expand Down Expand Up @@ -302,7 +297,6 @@ def recurse_event_loop(
system_prompt=system_prompt,
messages=messages,
tool_config=tool_config,
callback_handler=callback_handler,
tool_handler=tool_handler,
tool_execution_handler=tool_execution_handler,
event_loop_metrics=event_loop_metrics,
Expand All @@ -321,7 +315,6 @@ def _handle_tool_execution(
messages: Messages,
tool_config: ToolConfig,
tool_handler: ToolHandler,
callback_handler: Callable[..., Any],
tool_execution_handler: Optional[ParallelToolExecutorInterface],
event_loop_metrics: EventLoopMetrics,
event_loop_parent_span: Optional[trace.Span],
Expand All @@ -345,7 +338,6 @@ def _handle_tool_execution(
messages (Messages): The conversation history messages.
tool_config (ToolConfig): Configuration for available tools.
tool_handler (ToolHandler): Handler for tool execution.
callback_handler (Callable[..., Any]): Callback for processing events as they happen.
tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution.
event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop.
event_loop_parent_span (Any): Span for the parent of this event loop.
Expand Down Expand Up @@ -374,7 +366,6 @@ def _handle_tool_execution(
system_prompt=system_prompt,
messages=messages,
tool_config=tool_config,
callback_handler=callback_handler,
kwargs=kwargs,
)

Expand Down Expand Up @@ -415,7 +406,6 @@ def _handle_tool_execution(
system_prompt=system_prompt,
messages=messages,
tool_config=tool_config,
callback_handler=callback_handler,
tool_handler=tool_handler,
tool_execution_handler=tool_execution_handler,
event_loop_metrics=event_loop_metrics,
Expand Down
3 changes: 0 additions & 3 deletions src/strands/handlers/tool_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def process(
system_prompt: Optional[str],
messages: Messages,
tool_config: ToolConfig,
callback_handler: Any,
kwargs: dict[str, Any],
) -> Any:
"""Process a tool invocation.
Expand All @@ -47,7 +46,6 @@ def process(
system_prompt: The system prompt for the agent.
messages: The conversation history.
tool_config: Configuration for the tool.
callback_handler: Callback for processing events as they happen.
kwargs: Additional keyword arguments passed to the tool.

Returns:
Expand Down Expand Up @@ -81,7 +79,6 @@ def process(
"system_prompt": system_prompt,
"messages": messages,
"tool_config": tool_config,
"callback_handler": callback_handler,
}
)

Expand Down
7 changes: 4 additions & 3 deletions src/strands/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import base64
import json
import logging
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union

from mistralai import Mistral
from pydantic import BaseModel
Expand Down Expand Up @@ -471,14 +471,15 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:

@override
def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
self,
output_model: Type[T],
prompt: Messages,
) -> Generator[dict[str, Union[T, Any]], None, None]:
"""Get structured output from the model.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
callback_handler: Optional callback handler for processing events.

Returns:
An instance of the output model with the generated data.
Expand Down
2 changes: 0 additions & 2 deletions src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ def process(
system_prompt: Optional[str],
messages: "Messages",
tool_config: ToolConfig,
callback_handler: Any,
kwargs: dict[str, Any],
) -> ToolResult:
"""Process a tool use request and execute the tool.
Expand All @@ -260,7 +259,6 @@ def process(
model: The model being used for the conversation.
system_prompt: The system prompt for the conversation.
tool_config: The tool configuration for the current session.
callback_handler: Callback for processing events as they happen.
kwargs: Additional context-specific arguments.

Returns:
Expand Down
13 changes: 2 additions & 11 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,6 @@ def function(system_prompt: str) -> str:
system_prompt="You are a helpful assistant.",
messages=unittest.mock.ANY,
tool_config=unittest.mock.ANY,
callback_handler=unittest.mock.ANY,
kwargs={"system_prompt": "tool prompt"},
)

Expand Down Expand Up @@ -1075,18 +1074,10 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac
mock_tracer.start_agent_span.return_value = mock_span
mock_get_tracer.return_value = mock_tracer

# Define the side effect to simulate callback handler being called multiple times
def call_callback_handler(*args, **kwargs):
# Extract the callback handler from kwargs
callback_handler = kwargs.get("callback_handler")
# Call the callback handler with different data values
callback_handler(data="First chunk")
callback_handler(data="Second chunk")
callback_handler(data="Final chunk", complete=True)
# Return expected values from event_loop_cycle
def test_event_loop(*args, **kwargs):
yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})}

mock_event_loop_cycle.side_effect = call_callback_handler
mock_event_loop_cycle.side_effect = test_event_loop

# Create agent and make a call
agent = Agent(model=mock_model)
Expand Down
Loading