Skip to content

Commit ce9efae

Browse files
committed
interrupts - swarm - from agent
1 parent 432d269 commit ce9efae

File tree

13 files changed

+553
-63
lines changed

13 files changed

+553
-63
lines changed

src/strands/experimental/hooks/multiagent/events.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
is used—hooks read from the orchestrator directly.
66
"""
77

8+
import uuid
89
from dataclasses import dataclass
910
from typing import TYPE_CHECKING, Any
1011

12+
from typing_extensions import override
13+
1114
from ....hooks import BaseHookEvent
15+
from ....types.interrupt import _Interruptible
1216

1317
if TYPE_CHECKING:
1418
from ....multiagent.base import MultiAgentBase
@@ -28,18 +32,37 @@ class MultiAgentInitializedEvent(BaseHookEvent):
2832

2933

3034
@dataclass
31-
class BeforeNodeCallEvent(BaseHookEvent):
35+
class BeforeNodeCallEvent(BaseHookEvent, _Interruptible):
3236
"""Event triggered before individual node execution starts.
3337
3438
Attributes:
3539
source: The multi-agent orchestrator instance
3640
node_id: ID of the node about to execute
3741
invocation_state: Configuration that user passes in
42+
cancel_node: A user defined message that when set, will cancel the node execution.
43+
The message will be placed into the node result with an error status. If set to `True`, Strands will cancel
44+
the node and use a default cancel message.
3845
"""
3946

4047
source: "MultiAgentBase"
4148
node_id: str
4249
invocation_state: dict[str, Any] | None = None
50+
cancel_node: bool | str = False
51+
52+
def _can_write(self, name: str) -> bool:
53+
return name in ["cancel_node"]
54+
55+
@override
56+
def _interrupt_id(self, name: str) -> str:
57+
"""Unique id for the interrupt.
58+
59+
Args:
60+
name: User defined name for the interrupt.
61+
62+
Returns:
63+
Interrupt id.
64+
"""
65+
return f"v1:before_node_call:{self.node_id}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}"
4366

4467

4568
@dataclass

src/strands/interrupt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def resume(self, prompt: "AgentInput") -> None:
9999

100100
self.interrupts[interrupt_id].response = interrupt_response
101101

102+
self.context["responses"] = contents
103+
102104
def to_dict(self) -> dict[str, Any]:
103105
"""Serialize to dict for session management."""
104106
return asdict(self)

src/strands/multiagent/base.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,34 @@
1212

1313
from .._async import run_async
1414
from ..agent import AgentResult
15+
from ..interrupt import Interrupt
1516
from ..types.event_loop import Metrics, Usage
1617
from ..types.multiagent import MultiAgentInput
1718

1819
logger = logging.getLogger(__name__)
1920

2021

2122
class Status(Enum):
22-
"""Execution status for both graphs and nodes."""
23+
"""Execution status for both graphs and nodes.
24+
25+
Attributes:
26+
PENDING: Task has not started execution yet.
27+
EXECUTING: Task is currently running.
28+
COMPLETED: Task finished successfully.
29+
FAILED: Task encountered an error and could not complete.
30+
INTERRUPTED: Task was interrupted by user.
31+
"""
2332

2433
PENDING = "pending"
2534
EXECUTING = "executing"
2635
COMPLETED = "completed"
2736
FAILED = "failed"
37+
INTERRUPTED = "interrupted"
2838

2939

3040
@dataclass
3141
class NodeResult:
32-
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.
33-
34-
The status field represents the semantic outcome of the node's work:
35-
- COMPLETED: The node's task was successfully accomplished
36-
- FAILED: The node's task failed or produced an error
37-
"""
42+
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results."""
3843

3944
# Core result data - single AgentResult, nested MultiAgentResult, or Exception
4045
result: Union[AgentResult, "MultiAgentResult", Exception]
@@ -47,6 +52,7 @@ class NodeResult:
4752
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
4853
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
4954
execution_count: int = 0
55+
interrupts: list[Interrupt] = field(default_factory=list)
5056

5157
def get_agent_results(self) -> list[AgentResult]:
5258
"""Get all AgentResult objects from this node, flattened if nested."""
@@ -78,6 +84,7 @@ def to_dict(self) -> dict[str, Any]:
7884
"accumulated_usage": self.accumulated_usage,
7985
"accumulated_metrics": self.accumulated_metrics,
8086
"execution_count": self.execution_count,
87+
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
8188
}
8289

8390
@classmethod
@@ -100,31 +107,32 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
100107
usage = _parse_usage(data.get("accumulated_usage", {}))
101108
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
102109

110+
interrupts = []
111+
for interrupt_data in data.get("interrupts", []):
112+
interrupts.append(Interrupt(**interrupt_data))
113+
103114
return cls(
104115
result=result,
105116
execution_time=int(data.get("execution_time", 0)),
106117
status=Status(data.get("status", "pending")),
107118
accumulated_usage=usage,
108119
accumulated_metrics=metrics,
109120
execution_count=int(data.get("execution_count", 0)),
121+
interrupts=interrupts,
110122
)
111123

112124

113125
@dataclass
114126
class MultiAgentResult:
115-
"""Result from multi-agent execution with accumulated metrics.
116-
117-
The status field represents the outcome of the MultiAgentBase execution:
118-
- COMPLETED: The execution was successfully accomplished
119-
- FAILED: The execution failed or produced an error
120-
"""
127+
"""Result from multi-agent execution with accumulated metrics."""
121128

122129
status: Status = Status.PENDING
123130
results: dict[str, NodeResult] = field(default_factory=lambda: {})
124131
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
125132
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
126133
execution_count: int = 0
127134
execution_time: int = 0
135+
interrupts: list[Interrupt] = field(default_factory=list)
128136

129137
@classmethod
130138
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
@@ -136,13 +144,18 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
136144
usage = _parse_usage(data.get("accumulated_usage", {}))
137145
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
138146

147+
interrupts = []
148+
for interrupt_data in data.get("interrupts", []):
149+
interrupts.append(Interrupt(**interrupt_data))
150+
139151
multiagent_result = cls(
140152
status=Status(data["status"]),
141153
results=results,
142154
accumulated_usage=usage,
143155
accumulated_metrics=metrics,
144156
execution_count=int(data.get("execution_count", 0)),
145157
execution_time=int(data.get("execution_time", 0)),
158+
interrupts=interrupts,
146159
)
147160
return multiagent_result
148161

@@ -156,6 +169,7 @@ def to_dict(self) -> dict[str, Any]:
156169
"accumulated_metrics": self.accumulated_metrics,
157170
"execution_count": self.execution_count,
158171
"execution_time": self.execution_time,
172+
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
159173
}
160174

161175

src/strands/multiagent/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
964964
if isinstance(self.state.task, str):
965965
return [ContentBlock(text=self.state.task)]
966966
else:
967-
return self.state.task
967+
return cast(list[ContentBlock], self.state.task)
968968

969969
# Combine task with dependency outputs
970970
node_input = []
@@ -975,7 +975,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
975975
else:
976976
# Add task content blocks with a prefix
977977
node_input.append(ContentBlock(text="Original Task:"))
978-
node_input.extend(self.state.task)
978+
node_input.extend(cast(list[ContentBlock], self.state.task))
979979

980980
# Add dependency outputs
981981
node_input.append(ContentBlock(text="\nInputs from previous nodes:"))

0 commit comments

Comments
 (0)