Skip to content

Commit 50f616e

Browse files
committed
interrupts - swarm - from agent
1 parent b4efc9d commit 50f616e

File tree

7 files changed

+192
-67
lines changed

7 files changed

+192
-67
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,8 @@ async def _handle_tool_execution(
483483

484484
if interrupts:
485485
# Session state stored on AfterInvocationEvent.
486-
agent._interrupt_state.activate(context={"tool_use_message": message, "tool_results": tool_results})
486+
agent._interrupt_state.context = {"tool_use_message": message, "tool_results": tool_results}
487+
agent._interrupt_state.activate()
487488

488489
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
489490
yield EventLoopStopEvent(

src/strands/interrupt.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,8 @@ class _InterruptState:
5353
context: dict[str, Any] = field(default_factory=dict)
5454
activated: bool = False
5555

56-
def activate(self, context: dict[str, Any] | None = None) -> None:
57-
"""Activate the interrupt state.
58-
59-
Args:
60-
context: Context associated with the interrupt event.
61-
"""
62-
self.context = context or {}
56+
def activate(self) -> None:
57+
"""Activate the interrupt state."""
6358
self.activated = True
6459

6560
def deactivate(self) -> None:

src/strands/multiagent/base.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,34 @@
1212

1313
from .._async import run_async
1414
from ..agent import AgentResult
15-
from ..types.content import ContentBlock
15+
from ..interrupt import Interrupt
1616
from ..types.event_loop import Metrics, Usage
17+
from ..types.multiagent import MultiAgentInput
1718

1819
logger = logging.getLogger(__name__)
1920

2021

2122
class Status(Enum):
22-
"""Execution status for both graphs and nodes."""
23+
"""Execution status for both graphs and nodes.
24+
25+
Attributes:
26+
PENDING: Task has not started execution yet.
27+
EXECUTING: Task is currently running.
28+
COMPLETED: Task finished successfully.
29+
FAILED: Task encountered an error and could not complete.
30+
INTERRUPTED: Task was interrupted by user.
31+
"""
2332

2433
PENDING = "pending"
2534
EXECUTING = "executing"
2635
COMPLETED = "completed"
2736
FAILED = "failed"
37+
INTERRUPTED = "interrupted"
2838

2939

3040
@dataclass
3141
class NodeResult:
32-
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.
33-
34-
The status field represents the semantic outcome of the node's work:
35-
- COMPLETED: The node's task was successfully accomplished
36-
- FAILED: The node's task failed or produced an error
37-
"""
42+
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results."""
3843

3944
# Core result data - single AgentResult, nested MultiAgentResult, or Exception
4045
result: Union[AgentResult, "MultiAgentResult", Exception]
@@ -47,6 +52,7 @@ class NodeResult:
4752
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
4853
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
4954
execution_count: int = 0
55+
interrupts: list[Interrupt] = field(default_factory=list)
5056

5157
def get_agent_results(self) -> list[AgentResult]:
5258
"""Get all AgentResult objects from this node, flattened if nested."""
@@ -78,6 +84,7 @@ def to_dict(self) -> dict[str, Any]:
7884
"accumulated_usage": self.accumulated_usage,
7985
"accumulated_metrics": self.accumulated_metrics,
8086
"execution_count": self.execution_count,
87+
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
8188
}
8289

8390
@classmethod
@@ -99,6 +106,10 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
99106

100107
usage = _parse_usage(data.get("accumulated_usage", {}))
101108
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
109+
110+
interrupts = []
111+
for interrupt_data in data.get("interrupts", []):
112+
interrupts.append(Interrupt.from_dict(interrupt_data))
102113

103114
return cls(
104115
result=result,
@@ -107,24 +118,21 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
107118
accumulated_usage=usage,
108119
accumulated_metrics=metrics,
109120
execution_count=int(data.get("execution_count", 0)),
121+
interrupts=interrupts,
110122
)
111123

112124

113125
@dataclass
114126
class MultiAgentResult:
115-
"""Result from multi-agent execution with accumulated metrics.
116-
117-
The status field represents the outcome of the MultiAgentBase execution:
118-
- COMPLETED: The execution was successfully accomplished
119-
- FAILED: The execution failed or produced an error
120-
"""
127+
"""Result from multi-agent execution with accumulated metrics."""
121128

122129
status: Status = Status.PENDING
123130
results: dict[str, NodeResult] = field(default_factory=lambda: {})
124131
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
125132
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
126133
execution_count: int = 0
127134
execution_time: int = 0
135+
interrupts: list[Interrupt] = field(default_factory=list)
128136

129137
@classmethod
130138
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
@@ -136,13 +144,18 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
136144
usage = _parse_usage(data.get("accumulated_usage", {}))
137145
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
138146

147+
interrupts = []
148+
for interrupt_data in data.get("interrupts", []):
149+
interrupts.append(Interrupt.from_dict(interrupt_data))
150+
139151
multiagent_result = cls(
140152
status=Status(data["status"]),
141153
results=results,
142154
accumulated_usage=usage,
143155
accumulated_metrics=metrics,
144156
execution_count=int(data.get("execution_count", 0)),
145157
execution_time=int(data.get("execution_time", 0)),
158+
interrupts=interrupts,
146159
)
147160
return multiagent_result
148161

@@ -156,6 +169,7 @@ def to_dict(self) -> dict[str, Any]:
156169
"accumulated_metrics": self.accumulated_metrics,
157170
"execution_count": self.execution_count,
158171
"execution_time": self.execution_time,
172+
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
159173
}
160174

161175

@@ -173,7 +187,7 @@ class MultiAgentBase(ABC):
173187

174188
@abstractmethod
175189
async def invoke_async(
176-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
190+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
177191
) -> MultiAgentResult:
178192
"""Invoke asynchronously.
179193
@@ -186,7 +200,7 @@ async def invoke_async(
186200
raise NotImplementedError("invoke_async not implemented")
187201

188202
async def stream_async(
189-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
203+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
190204
) -> AsyncIterator[dict[str, Any]]:
191205
"""Stream events during multi-agent execution.
192206
@@ -211,7 +225,7 @@ async def stream_async(
211225
yield {"result": result}
212226

213227
def __call__(
214-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
228+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
215229
) -> MultiAgentResult:
216230
"""Invoke synchronously.
217231

src/strands/multiagent/graph.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
)
4646
from ..types.content import ContentBlock, Messages
4747
from ..types.event_loop import Metrics, Usage
48+
from ..types.multiagent import MultiAgentInput
4849
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
4950

5051
logger = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ class GraphState:
6768
"""
6869

6970
# Task (with default empty string)
70-
task: str | list[ContentBlock] = ""
71+
task: MultiAgentInput = ""
7172

7273
# Execution state
7374
status: Status = Status.PENDING
@@ -456,7 +457,7 @@ def __init__(
456457
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
457458

458459
def __call__(
459-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
460+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
460461
) -> GraphResult:
461462
"""Invoke the graph synchronously.
462463
@@ -472,7 +473,7 @@ def __call__(
472473
return run_async(lambda: self.invoke_async(task, invocation_state))
473474

474475
async def invoke_async(
475-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
476+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
476477
) -> GraphResult:
477478
"""Invoke the graph asynchronously.
478479
@@ -496,7 +497,7 @@ async def invoke_async(
496497
return cast(GraphResult, final_event["result"])
497498

498499
async def stream_async(
499-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
500+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
500501
) -> AsyncIterator[dict[str, Any]]:
501502
"""Stream events during graph execution.
502503

0 commit comments

Comments
 (0)