Skip to content

Commit 7c5dbc8

Browse files
committed
interrupts - swarm - from agent
1 parent b4efc9d commit 7c5dbc8

File tree

5 files changed

+111
-31
lines changed

5 files changed

+111
-31
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,8 @@ async def _handle_tool_execution(
483483

484484
if interrupts:
485485
# Session state stored on AfterInvocationEvent.
486-
agent._interrupt_state.activate(context={"tool_use_message": message, "tool_results": tool_results})
486+
agent._interrupt_state.context = {"tool_use_message": message, "tool_results": tool_results}
487+
agent._interrupt_state.activate()
487488

488489
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
489490
yield EventLoopStopEvent(

src/strands/interrupt.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,8 @@ class _InterruptState:
5353
context: dict[str, Any] = field(default_factory=dict)
5454
activated: bool = False
5555

56-
def activate(self, context: dict[str, Any] | None = None) -> None:
57-
"""Activate the interrupt state.
58-
59-
Args:
60-
context: Context associated with the interrupt event.
61-
"""
62-
self.context = context or {}
56+
def activate(self) -> None:
57+
"""Activate the interrupt state."""
6358
self.activated = True
6459

6560
def deactivate(self) -> None:

src/strands/multiagent/base.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from abc import ABC, abstractmethod
99
from dataclasses import dataclass, field
1010
from enum import Enum
11-
from typing import Any, AsyncIterator, Union
11+
from typing import Any, AsyncIterator, Sequence, Union
1212

1313
from .._async import run_async
1414
from ..agent import AgentResult
15+
from ..interrupt import Interrupt
1516
from ..types.content import ContentBlock
1617
from ..types.event_loop import Metrics, Usage
1718

@@ -25,16 +26,12 @@ class Status(Enum):
2526
EXECUTING = "executing"
2627
COMPLETED = "completed"
2728
FAILED = "failed"
29+
INTERRUPTED = "interrupted"
2830

2931

3032
@dataclass
3133
class NodeResult:
32-
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.
33-
34-
The status field represents the semantic outcome of the node's work:
35-
- COMPLETED: The node's task was successfully accomplished
36-
- FAILED: The node's task failed or produced an error
37-
"""
34+
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results."""
3835

3936
# Core result data - single AgentResult, nested MultiAgentResult, or Exception
4037
result: Union[AgentResult, "MultiAgentResult", Exception]
@@ -47,6 +44,7 @@ class NodeResult:
4744
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
4845
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
4946
execution_count: int = 0
47+
interrupts: Sequence[Interrupt] | None = None
5048

5149
def get_agent_results(self) -> list[AgentResult]:
5250
"""Get all AgentResult objects from this node, flattened if nested."""
@@ -78,6 +76,7 @@ def to_dict(self) -> dict[str, Any]:
7876
"accumulated_usage": self.accumulated_usage,
7977
"accumulated_metrics": self.accumulated_metrics,
8078
"execution_count": self.execution_count,
79+
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
8180
}
8281

8382
@classmethod
@@ -99,6 +98,11 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
9998

10099
usage = _parse_usage(data.get("accumulated_usage", {}))
101100
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
101+
102+
# Parse interrupts
103+
interrupts = []
104+
for interrupt_data in data.get("interrupts", []):
105+
interrupts.append(Interrupt.from_dict(interrupt_data))
102106

103107
return cls(
104108
result=result,
@@ -107,6 +111,7 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
107111
accumulated_usage=usage,
108112
accumulated_metrics=metrics,
109113
execution_count=int(data.get("execution_count", 0)),
114+
interrupts=interrupts,
110115
)
111116

112117

@@ -125,6 +130,7 @@ class MultiAgentResult:
125130
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
126131
execution_count: int = 0
127132
execution_time: int = 0
133+
interrupts: list[Interrupt] = field(default_factory=list)
128134

129135
@classmethod
130136
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":

src/strands/multiagent/swarm.py

Lines changed: 92 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MultiAgentInitializedEvent,
3434
)
3535
from ..hooks import HookProvider, HookRegistry
36+
from ..interrupt import Interrupt, _InterruptState
3637
from ..session import SessionManager
3738
from ..telemetry import get_tracer
3839
from ..tools.decorator import tool
@@ -44,6 +45,7 @@
4445
MultiAgentResultEvent,
4546
)
4647
from ..types.content import ContentBlock, Messages
48+
from ..types.interrupt import InterruptResponseContent
4749
from ..types.event_loop import Metrics, Usage
4850
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
4951

@@ -145,7 +147,7 @@ class SwarmState:
145147
"""Current state of swarm execution."""
146148

147149
current_node: SwarmNode | None # The agent currently executing
148-
task: str | list[ContentBlock] # The original task from the user that is being executed
150+
task: str | list[ContentBlock] | list[InterruptResponseContent] # The original task from the user that is being executed
149151
completion_status: Status = Status.PENDING # Current swarm execution status
150152
shared_context: SharedContext = field(default_factory=SharedContext) # Context shared between agents
151153
node_history: list[SwarmNode] = field(default_factory=list) # Complete history of agents that have executed
@@ -255,11 +257,14 @@ def __init__(
255257

256258
self.shared_context = SharedContext()
257259
self.nodes: dict[str, SwarmNode] = {}
260+
258261
self.state = SwarmState(
259262
current_node=None, # Placeholder, will be set properly
260263
task="",
261264
completion_status=Status.PENDING,
262265
)
266+
self._interrupt_state = _InterruptState()
267+
263268
self.tracer = get_tracer()
264269

265270
self.session_manager = session_manager
@@ -277,7 +282,9 @@ def __init__(
277282
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
278283

279284
def __call__(
280-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
285+
self,
286+
task: str | list[ContentBlock] | list[InterruptResponseContent],
287+
invocation_state: dict[str, Any] | None = None, **kwargs: Any,
281288
) -> SwarmResult:
282289
"""Invoke the swarm synchronously.
283290
@@ -292,7 +299,9 @@ def __call__(
292299
return run_async(lambda: self.invoke_async(task, invocation_state))
293300

294301
async def invoke_async(
295-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
302+
self,
303+
task: str | list[ContentBlock] | list[InterruptResponseContent],
304+
invocation_state: dict[str, Any] | None = None, **kwargs: Any,
296305
) -> SwarmResult:
297306
"""Invoke the swarm asynchronously.
298307
@@ -316,7 +325,9 @@ async def invoke_async(
316325
return cast(SwarmResult, final_event["result"])
317326

318327
async def stream_async(
319-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
328+
self,
329+
task: str | list[ContentBlock] | list[InterruptResponseContent],
330+
invocation_state: dict[str, Any] | None = None, **kwargs: Any,
320331
) -> AsyncIterator[dict[str, Any]]:
321332
"""Stream events during swarm execution.
322333
@@ -334,6 +345,8 @@ async def stream_async(
334345
- multi_agent_node_stop: When a node stops execution
335346
- result: Final swarm result
336347
"""
348+
self._interrupt_state.resume(task)
349+
337350
if invocation_state is None:
338351
invocation_state = {}
339352

@@ -644,6 +657,36 @@ def _build_node_input(self, target_node: SwarmNode) -> str:
644657

645658
return context_text
646659

660+
def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> Any:
661+
"""Activate the interrupt state.
662+
663+
Args:
664+
node: The interrupted node.
665+
interrupts: The interrupts raised by the user. Not read by this method yet; only the commented-out event construction refers to it.
666+
667+
Returns:
668+
None for now. Returning a MultiAgentNodeInterruptEvent is intended, but that return statement is still commented out.
669+
"""
670+
671+
logger.debug("node=<%s> | node interrupted", node.node_id)
672+
self.state.completion_status = Status.INTERRUPTED
673+
674+
self._interrupt_state.context.update(
675+
{
676+
node.node_id: {
677+
"interrupt_state": node.executor._interrupt_state.to_dict(),
678+
"state": node.executor.state.get(),
679+
"messages": node.executor.messages
680+
}
681+
}
682+
)
683+
self._interrupt_state.activate()
684+
685+
# return MultiAgentNodeInterruptEvent(
686+
# node_id=node.node_id,
687+
# interrupts=interrupts,
688+
# )
689+
647690
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
648691
"""Execute swarm and yield TypedEvent objects."""
649692
try:
@@ -680,9 +723,13 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
680723

681724
# TODO: Implement cancellation token to stop _execute_node from continuing
682725
try:
683-
await self.hooks.invoke_callbacks_async(
684-
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
726+
_, interrupts = await self.hooks.invoke_callbacks_async(
727+
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
685728
)
729+
if interrupts:
730+
yield self._activate_interrupt(current_node, interrupts)
731+
break
732+
686733
node_stream = self._stream_with_timeout(
687734
self._execute_node(current_node, self.state.task, invocation_state),
688735
self.node_timeout,
@@ -691,6 +738,14 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
691738
async for event in node_stream:
692739
yield event
693740

741+
stop_event = cast(MultiAgentNodeStopEvent, event)
742+
node_result = stop_event["node_result"]
743+
if node_result.status == Status.INTERRUPTED:
744+
self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in node_result.interrupts})
745+
# yield self._activate_interrupt(current_node, node_result.interrupts)
746+
self._activate_interrupt(current_node, node_result.interrupts)
747+
break
748+
694749
self.state.node_history.append(current_node)
695750
await self.hooks.invoke_callbacks_async(
696751
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
@@ -741,7 +796,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
741796
)
742797

743798
async def _execute_node(
744-
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
799+
self,
800+
node: SwarmNode,
801+
task: str | list[ContentBlock] | list[InterruptResponseContent],
802+
invocation_state: dict[str, Any],
745803
) -> AsyncIterator[Any]:
746804
"""Execute swarm node and yield TypedEvent objects."""
747805
start_time = time.time()
@@ -763,8 +821,16 @@ async def _execute_node(
763821
# Include additional ContentBlocks in node input
764822
node_input = node_input + task
765823

824+
if self._interrupt_state.activated:
825+
node_input = task
826+
766827
# Execute node with streaming
767828
node.reset_executor_state()
829+
if self._interrupt_state.activated:
830+
context = self._interrupt_state.context[node.node_id]
831+
node.executor.messages = context["messages"]
832+
node.executor.state = AgentState(context["state"])
833+
node.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"])
768834

769835
# Stream agent events with node context and capture final result
770836
result = None
@@ -779,13 +845,8 @@ async def _execute_node(
779845
if result is None:
780846
raise ValueError(f"Node '{node_name}' did not produce a result event")
781847

782-
if result.stop_reason == "interrupt":
783-
node.executor.messages.pop() # remove interrupted tool use message
784-
node.executor._interrupt_state.deactivate()
785-
786-
raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in swarms")
787-
788848
execution_time = round((time.time() - start_time) * 1000)
849+
status = Status.INTERRUPTED if result.stop_reason == "interrupt" else Status.COMPLETED
789850

790851
# Create NodeResult with extracted metrics
791852
result_metrics = getattr(result, "metrics", None)
@@ -795,10 +856,11 @@ async def _execute_node(
795856
node_result = NodeResult(
796857
result=result,
797858
execution_time=execution_time,
798-
status=Status.COMPLETED,
859+
status=status,
799860
accumulated_usage=usage,
800861
accumulated_metrics=metrics,
801862
execution_count=1,
863+
interrupts=result.interrupts,
802864
)
803865

804866
# Store result in state
@@ -849,6 +911,15 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
849911

850912
def _build_result(self) -> SwarmResult:
851913
"""Build swarm result from current state."""
914+
# Get interrupts from current node (latest iteration only)
915+
interrupts = []
916+
if (self.state.completion_status == Status.INTERRUPTED and
917+
self.state.current_node and
918+
self.state.current_node.node_id in self.state.results):
919+
920+
node_result = self.state.results[self.state.current_node.node_id]
921+
interrupts = node_result.interrupts
922+
852923
return SwarmResult(
853924
status=self.state.completion_status,
854925
results=self.state.results,
@@ -857,6 +928,7 @@ def _build_result(self) -> SwarmResult:
857928
execution_count=len(self.state.node_history),
858929
execution_time=self.state.execution_time,
859930
node_history=self.state.node_history,
931+
interrupts=interrupts,
860932
)
861933

862934
def serialize_state(self) -> dict[str, Any]:
@@ -881,6 +953,9 @@ def serialize_state(self) -> dict[str, Any]:
881953
"shared_context": getattr(self.state.shared_context, "context", {}) or {},
882954
"handoff_message": self.state.handoff_message,
883955
},
956+
"_internal_state": {
957+
"interrupt_state": self._interrupt_state.to_dict(),
958+
},
884959
}
885960

886961
def deserialize_state(self, payload: dict[str, Any]) -> None:
@@ -896,6 +971,9 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
896971
payload: Dictionary containing persisted state data including status,
897972
completed nodes, results, and next nodes to execute.
898973
"""
974+
if "_internal_state" in payload:
975+
self._interrupt_state = _InterruptState.from_dict(payload["_internal_state"]["interrupt_state"])
976+
899977
if not payload.get("next_nodes_to_execute"):
900978
for node in self.nodes.values():
901979
node.reset_executor_state()

src/strands/session/session_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ..experimental.hooks.multiagent.events import (
88
AfterMultiAgentInvocationEvent,
9-
AfterNodeCallEvent,
9+
BeforeNodeCallEvent,
1010
MultiAgentInitializedEvent,
1111
)
1212
from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
@@ -44,7 +44,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
4444
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))
4545

4646
registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source))
47-
registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source))
47+
registry.add_callback(BeforeNodeCallEvent, lambda event: self.sync_multi_agent(event.source))
4848
registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source))
4949

5050
@abstractmethod

0 commit comments

Comments
 (0)