Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from ..tools.executors._executor import ToolExecutor
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
from ..types._events import InitEventLoopEvent
from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
Expand Down Expand Up @@ -604,7 +605,7 @@ async def _run_loop(
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))

try:
yield {"callback": {"init_event_loop": True, **invocation_state}}
yield InitEventLoopEvent(invocation_state)

for message in messages:
self._append_message(message)
Expand Down
35 changes: 22 additions & 13 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
from ..telemetry.metrics import Trace
from ..telemetry.tracer import get_tracer
from ..tools._validator import validate_and_prepare_tools
from ..types._events import (
EventLoopStopEvent,
EventLoopThrottleEvent,
ForceStopEvent,
ModelMessageEvent,
StartEvent,
StartEventLoopEvent,
ToolResultMessageEvent,
)
from ..types.content import Message
from ..types.exceptions import (
ContextWindowOverflowException,
Expand Down Expand Up @@ -91,8 +100,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes)
invocation_state["event_loop_cycle_trace"] = cycle_trace

yield {"callback": {"start": True}}
yield {"callback": {"start_event_loop": True}}
yield StartEvent()
yield StartEventLoopEvent()

# Create tracer span for this event loop cycle
tracer = get_tracer()
Expand Down Expand Up @@ -175,7 +184,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->

if isinstance(e, ModelThrottledException):
if attempt + 1 == MAX_ATTEMPTS:
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
yield ForceStopEvent(reason=e)
raise e

logger.debug(
Expand All @@ -189,7 +198,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
time.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)

yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}}
yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state)
else:
raise e

Expand All @@ -201,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
# Add the response message to the conversation
agent.messages.append(message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
yield {"callback": {"message": message}}
yield ModelMessageEvent(message=message)

# Update metrics
agent.event_loop_metrics.update_usage(usage)
Expand Down Expand Up @@ -235,8 +244,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
cycle_start_time=cycle_start_time,
invocation_state=invocation_state,
)
async for event in events:
yield event
async for typed_event in events:
yield typed_event

return

Expand Down Expand Up @@ -264,11 +273,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
tracer.end_span_with_error(cycle_span, str(e), e)

# Handle any other exceptions
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
yield ForceStopEvent(reason=e)
logger.exception("cycle failed")
raise EventLoopException(e, invocation_state["request_state"]) from e

yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])


async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
Expand All @@ -295,7 +304,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id)
cycle_trace.add_child(recursive_trace)

yield {"callback": {"start": True}}
yield StartEvent()

events = event_loop_cycle(agent=agent, invocation_state=invocation_state)
async for event in events:
Expand Down Expand Up @@ -339,7 +348,7 @@ async def _handle_tool_execution(
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
if not tool_uses:
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
return

tool_events = agent.tool_executor._execute(
Expand All @@ -358,15 +367,15 @@ async def _handle_tool_execution(

agent.messages.append(tool_result_message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
yield {"callback": {"message": tool_result_message}}
yield ToolResultMessageEvent(message=message)

if cycle_span:
tracer = get_tracer()
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)

if invocation_state["request_state"].get("stop_event_loop", False):
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
return

events = recurse_event_loop(agent=agent, invocation_state=invocation_state)
Expand Down
50 changes: 29 additions & 21 deletions src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
from typing import Any, AsyncGenerator, AsyncIterable, Optional

from ..models.model import Model
from ..types._events import (
ModelStopReason,
ModelStreamChunkEvent,
ModelStreamEvent,
ReasoningSignatureStreamEvent,
ReasoningTextStreamEvent,
TextStreamEvent,
ToolUseStreamEvent,
TypedEvent,
)
from ..types.content import ContentBlock, Message, Messages
from ..types.streaming import (
ContentBlockDeltaEvent,
Expand Down Expand Up @@ -105,7 +115,7 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:

def handle_content_block_delta(
event: ContentBlockDeltaEvent, state: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any]]:
) -> tuple[dict[str, Any], ModelStreamEvent]:
"""Handles content block delta updates by appending text, tool input, or reasoning content to the state.

Args:
Expand All @@ -117,43 +127,41 @@ def handle_content_block_delta(
"""
delta_content = event["delta"]

callback_event = {}
typed_event: ModelStreamEvent = ModelStreamEvent({})

if "toolUse" in delta_content:
if "input" not in state["current_tool_use"]:
state["current_tool_use"]["input"] = ""

state["current_tool_use"]["input"] += delta_content["toolUse"]["input"]
callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]}
typed_event = ToolUseStreamEvent(delta_content, state["current_tool_use"])

elif "text" in delta_content:
state["text"] += delta_content["text"]
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}
typed_event = TextStreamEvent(text=delta_content["text"], delta=delta_content)

elif "reasoningContent" in delta_content:
if "text" in delta_content["reasoningContent"]:
if "reasoningText" not in state:
state["reasoningText"] = ""

state["reasoningText"] += delta_content["reasoningContent"]["text"]
callback_event["callback"] = {
"reasoningText": delta_content["reasoningContent"]["text"],
"delta": delta_content,
"reasoning": True,
}
typed_event = ReasoningTextStreamEvent(
reasoning_text=delta_content["reasoningContent"]["text"],
delta=delta_content,
)

elif "signature" in delta_content["reasoningContent"]:
if "signature" not in state:
state["signature"] = ""

state["signature"] += delta_content["reasoningContent"]["signature"]
callback_event["callback"] = {
"reasoning_signature": delta_content["reasoningContent"]["signature"],
"delta": delta_content,
"reasoning": True,
}
typed_event = ReasoningSignatureStreamEvent(
reasoning_signature=delta_content["reasoningContent"]["signature"],
delta=delta_content,
)

return state, callback_event
return state, typed_event


def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -251,7 +259,7 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
return usage, metrics


async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[dict[str, Any], None]:
async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[TypedEvent, None]:
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.

Args:
Expand All @@ -274,14 +282,14 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
metrics: Metrics = Metrics(latencyMs=0)

async for chunk in chunks:
yield {"callback": {"event": chunk}}
yield ModelStreamChunkEvent(chunk=chunk)
if "messageStart" in chunk:
state["message"] = handle_message_start(chunk["messageStart"], state["message"])
elif "contentBlockStart" in chunk:
state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"])
elif "contentBlockDelta" in chunk:
state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state)
yield callback_event
state, typed_event = handle_content_block_delta(chunk["contentBlockDelta"], state)
yield typed_event
elif "contentBlockStop" in chunk:
state = handle_content_block_stop(state)
elif "messageStop" in chunk:
Expand All @@ -291,15 +299,15 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
elif "redactContent" in chunk:
handle_redact_content(chunk["redactContent"], state)

yield {"stop": (stop_reason, state["message"], usage, metrics)}
yield ModelStopReason(stop_reason=stop_reason, message=state["message"], usage=usage, metrics=metrics)


async def stream_messages(
model: Model,
system_prompt: Optional[str],
messages: Messages,
tool_specs: list[ToolSpec],
) -> AsyncGenerator[dict[str, Any], None]:
) -> AsyncGenerator[TypedEvent, None]:
"""Streams messages to the model and processes the response.

Args:
Expand Down
Loading
Loading