Skip to content

Commit 1f9ef9a

Browse files
committed
interrupts - swarm - from agent
1 parent 95ac650 commit 1f9ef9a

File tree

19 files changed

+622
-99
lines changed

19 files changed

+622
-99
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,8 @@ async def _handle_tool_execution(
483483

484484
if interrupts:
485485
# Session state stored on AfterInvocationEvent.
486-
agent._interrupt_state.activate(context={"tool_use_message": message, "tool_results": tool_results})
486+
agent._interrupt_state.context = {"tool_use_message": message, "tool_results": tool_results}
487+
agent._interrupt_state.activate()
487488

488489
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
489490
yield EventLoopStopEvent(

src/strands/experimental/hooks/multiagent/events.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
is used—hooks read from the orchestrator directly.
66
"""
77

8+
import uuid
89
from dataclasses import dataclass
910
from typing import TYPE_CHECKING, Any
1011

12+
from typing_extensions import override
13+
1114
from ....hooks import BaseHookEvent
15+
from ....types.interrupt import _Interruptible
1216

1317
if TYPE_CHECKING:
1418
from ....multiagent.base import MultiAgentBase
@@ -28,18 +32,37 @@ class MultiAgentInitializedEvent(BaseHookEvent):
2832

2933

3034
@dataclass
31-
class BeforeNodeCallEvent(BaseHookEvent):
35+
class BeforeNodeCallEvent(BaseHookEvent, _Interruptible):
3236
"""Event triggered before individual node execution starts.
3337
3438
Attributes:
3539
source: The multi-agent orchestrator instance
3640
node_id: ID of the node about to execute
3741
invocation_state: Configuration that user passes in
42+
cancel_node: A user defined message that when set, will cancel the node execution.
43+
The message will be placed into the node result with an error status. If set to `True`, Strands will cancel
44+
the node and use a default cancel message.
3845
"""
3946

4047
source: "MultiAgentBase"
4148
node_id: str
4249
invocation_state: dict[str, Any] | None = None
50+
cancel_node: bool | str = False
51+
52+
def _can_write(self, name: str) -> bool:
53+
return name in ["cancel_node"]
54+
55+
@override
56+
def _interrupt_id(self, name: str) -> str:
57+
"""Unique id for the interrupt.
58+
59+
Args:
60+
name: User defined name for the interrupt.
61+
62+
Returns:
63+
Interrupt id.
64+
"""
65+
return f"v1:before_node_call:{self.node_id}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}"
4366

4467

4568
@dataclass

src/strands/interrupt.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,8 @@ class _InterruptState:
5353
context: dict[str, Any] = field(default_factory=dict)
5454
activated: bool = False
5555

56-
def activate(self, context: dict[str, Any] | None = None) -> None:
57-
"""Activate the interrupt state.
58-
59-
Args:
60-
context: Context associated with the interrupt event.
61-
"""
62-
self.context = context or {}
56+
def activate(self) -> None:
57+
"""Activate the interrupt state."""
6358
self.activated = True
6459

6560
def deactivate(self) -> None:
@@ -104,6 +99,8 @@ def resume(self, prompt: "AgentInput") -> None:
10499

105100
self.interrupts[interrupt_id].response = interrupt_response
106101

102+
self.context["responses"] = contents
103+
107104
def to_dict(self) -> dict[str, Any]:
108105
"""Serialize to dict for session management."""
109106
return asdict(self)

src/strands/multiagent/base.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,34 @@
1212

1313
from .._async import run_async
1414
from ..agent import AgentResult
15-
from ..types.content import ContentBlock
15+
from ..interrupt import Interrupt
1616
from ..types.event_loop import Metrics, Usage
17+
from ..types.multiagent import MultiAgentInput
1718

1819
logger = logging.getLogger(__name__)
1920

2021

2122
class Status(Enum):
22-
"""Execution status for both graphs and nodes."""
23+
"""Execution status for both graphs and nodes.
24+
25+
Attributes:
26+
PENDING: Task has not started execution yet.
27+
EXECUTING: Task is currently running.
28+
COMPLETED: Task finished successfully.
29+
FAILED: Task encountered an error and could not complete.
30+
INTERRUPTED: Task was interrupted by user.
31+
"""
2332

2433
PENDING = "pending"
2534
EXECUTING = "executing"
2635
COMPLETED = "completed"
2736
FAILED = "failed"
37+
INTERRUPTED = "interrupted"
2838

2939

3040
@dataclass
3141
class NodeResult:
32-
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.
33-
34-
The status field represents the semantic outcome of the node's work:
35-
- COMPLETED: The node's task was successfully accomplished
36-
- FAILED: The node's task failed or produced an error
37-
"""
42+
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results."""
3843

3944
# Core result data - single AgentResult, nested MultiAgentResult, or Exception
4045
result: Union[AgentResult, "MultiAgentResult", Exception]
@@ -47,6 +52,7 @@ class NodeResult:
4752
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
4853
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
4954
execution_count: int = 0
55+
interrupts: list[Interrupt] = field(default_factory=list)
5056

5157
def get_agent_results(self) -> list[AgentResult]:
5258
"""Get all AgentResult objects from this node, flattened if nested."""
@@ -78,6 +84,7 @@ def to_dict(self) -> dict[str, Any]:
7884
"accumulated_usage": self.accumulated_usage,
7985
"accumulated_metrics": self.accumulated_metrics,
8086
"execution_count": self.execution_count,
87+
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
8188
}
8289

8390
@classmethod
@@ -100,31 +107,32 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
100107
usage = _parse_usage(data.get("accumulated_usage", {}))
101108
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
102109

110+
interrupts = []
111+
for interrupt_data in data.get("interrupts", []):
112+
interrupts.append(Interrupt(**interrupt_data))
113+
103114
return cls(
104115
result=result,
105116
execution_time=int(data.get("execution_time", 0)),
106117
status=Status(data.get("status", "pending")),
107118
accumulated_usage=usage,
108119
accumulated_metrics=metrics,
109120
execution_count=int(data.get("execution_count", 0)),
121+
interrupts=interrupts,
110122
)
111123

112124

113125
@dataclass
114126
class MultiAgentResult:
115-
"""Result from multi-agent execution with accumulated metrics.
116-
117-
The status field represents the outcome of the MultiAgentBase execution:
118-
- COMPLETED: The execution was successfully accomplished
119-
- FAILED: The execution failed or produced an error
120-
"""
127+
"""Result from multi-agent execution with accumulated metrics."""
121128

122129
status: Status = Status.PENDING
123130
results: dict[str, NodeResult] = field(default_factory=lambda: {})
124131
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
125132
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
126133
execution_count: int = 0
127134
execution_time: int = 0
135+
interrupts: list[Interrupt] = field(default_factory=list)
128136

129137
@classmethod
130138
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
@@ -136,13 +144,18 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
136144
usage = _parse_usage(data.get("accumulated_usage", {}))
137145
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
138146

147+
interrupts = []
148+
for interrupt_data in data.get("interrupts", []):
149+
interrupts.append(Interrupt(**interrupt_data))
150+
139151
multiagent_result = cls(
140152
status=Status(data["status"]),
141153
results=results,
142154
accumulated_usage=usage,
143155
accumulated_metrics=metrics,
144156
execution_count=int(data.get("execution_count", 0)),
145157
execution_time=int(data.get("execution_time", 0)),
158+
interrupts=interrupts,
146159
)
147160
return multiagent_result
148161

@@ -156,6 +169,7 @@ def to_dict(self) -> dict[str, Any]:
156169
"accumulated_metrics": self.accumulated_metrics,
157170
"execution_count": self.execution_count,
158171
"execution_time": self.execution_time,
172+
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
159173
}
160174

161175

@@ -173,7 +187,7 @@ class MultiAgentBase(ABC):
173187

174188
@abstractmethod
175189
async def invoke_async(
176-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
190+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
177191
) -> MultiAgentResult:
178192
"""Invoke asynchronously.
179193
@@ -186,7 +200,7 @@ async def invoke_async(
186200
raise NotImplementedError("invoke_async not implemented")
187201

188202
async def stream_async(
189-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
203+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
190204
) -> AsyncIterator[dict[str, Any]]:
191205
"""Stream events during multi-agent execution.
192206
@@ -211,7 +225,7 @@ async def stream_async(
211225
yield {"result": result}
212226

213227
def __call__(
214-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
228+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
215229
) -> MultiAgentResult:
216230
"""Invoke synchronously.
217231

src/strands/multiagent/graph.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
)
4646
from ..types.content import ContentBlock, Messages
4747
from ..types.event_loop import Metrics, Usage
48+
from ..types.multiagent import MultiAgentInput
4849
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
4950

5051
logger = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ class GraphState:
6768
"""
6869

6970
# Task (with default empty string)
70-
task: str | list[ContentBlock] = ""
71+
task: MultiAgentInput = ""
7172

7273
# Execution state
7374
status: Status = Status.PENDING
@@ -456,7 +457,7 @@ def __init__(
456457
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
457458

458459
def __call__(
459-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
460+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
460461
) -> GraphResult:
461462
"""Invoke the graph synchronously.
462463
@@ -472,7 +473,7 @@ def __call__(
472473
return run_async(lambda: self.invoke_async(task, invocation_state))
473474

474475
async def invoke_async(
475-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
476+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
476477
) -> GraphResult:
477478
"""Invoke the graph asynchronously.
478479
@@ -496,7 +497,7 @@ async def invoke_async(
496497
return cast(GraphResult, final_event["result"])
497498

498499
async def stream_async(
499-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
500+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
500501
) -> AsyncIterator[dict[str, Any]]:
501502
"""Stream events during graph execution.
502503
@@ -963,7 +964,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
963964
if isinstance(self.state.task, str):
964965
return [ContentBlock(text=self.state.task)]
965966
else:
966-
return self.state.task
967+
return cast(list[ContentBlock], self.state.task)
967968

968969
# Combine task with dependency outputs
969970
node_input = []
@@ -974,7 +975,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
974975
else:
975976
# Add task content blocks with a prefix
976977
node_input.append(ContentBlock(text="Original Task:"))
977-
node_input.extend(self.state.task)
978+
node_input.extend(cast(list[ContentBlock], self.state.task))
978979

979980
# Add dependency outputs
980981
node_input.append(ContentBlock(text="\nInputs from previous nodes:"))

0 commit comments

Comments
 (0)