3333 MultiAgentInitializedEvent ,
3434)
3535from ..hooks import HookProvider , HookRegistry
36+ from ..interrupt import Interrupt , _InterruptState
3637from ..session import SessionManager
3738from ..telemetry import get_tracer
3839from ..tools .decorator import tool
4445 MultiAgentResultEvent ,
4546)
4647from ..types .content import ContentBlock , Messages
48+ from ..types .interrupt import InterruptResponseContent
4749from ..types .event_loop import Metrics , Usage
4850from .base import MultiAgentBase , MultiAgentResult , NodeResult , Status
4951
@@ -145,7 +147,7 @@ class SwarmState:
145147 """Current state of swarm execution."""
146148
147149 current_node : SwarmNode | None # The agent currently executing
148- task : str | list [ContentBlock ] # The original task from the user that is being executed
150+ task : str | list [ContentBlock ] | list [ InterruptResponseContent ] # The original task from the user that is being executed
149151 completion_status : Status = Status .PENDING # Current swarm execution status
150152 shared_context : SharedContext = field (default_factory = SharedContext ) # Context shared between agents
151153 node_history : list [SwarmNode ] = field (default_factory = list ) # Complete history of agents that have executed
@@ -255,11 +257,14 @@ def __init__(
255257
256258 self .shared_context = SharedContext ()
257259 self .nodes : dict [str , SwarmNode ] = {}
260+
258261 self .state = SwarmState (
259262 current_node = None , # Placeholder, will be set properly
260263 task = "" ,
261264 completion_status = Status .PENDING ,
262265 )
266+ self ._interrupt_state = _InterruptState ()
267+
263268 self .tracer = get_tracer ()
264269
265270 self .session_manager = session_manager
@@ -277,7 +282,9 @@ def __init__(
277282 run_async (lambda : self .hooks .invoke_callbacks_async (MultiAgentInitializedEvent (self )))
278283
279284 def __call__ (
280- self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
285+ self ,
286+ task : str | list [ContentBlock ] | list [InterruptResponseContent ],
287+ invocation_state : dict [str , Any ] | None = None , ** kwargs : Any ,
281288 ) -> SwarmResult :
282289 """Invoke the swarm synchronously.
283290
@@ -292,7 +299,9 @@ def __call__(
292299 return run_async (lambda : self .invoke_async (task , invocation_state ))
293300
294301 async def invoke_async (
295- self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
302+ self ,
303+ task : str | list [ContentBlock ] | list [InterruptResponseContent ],
304+ invocation_state : dict [str , Any ] | None = None , ** kwargs : Any ,
296305 ) -> SwarmResult :
297306 """Invoke the swarm asynchronously.
298307
@@ -316,7 +325,9 @@ async def invoke_async(
316325 return cast (SwarmResult , final_event ["result" ])
317326
318327 async def stream_async (
319- self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
328+ self ,
329+ task : str | list [ContentBlock ] | list [InterruptResponseContent ],
330+ invocation_state : dict [str , Any ] | None = None , ** kwargs : Any ,
320331 ) -> AsyncIterator [dict [str , Any ]]:
321332 """Stream events during swarm execution.
322333
@@ -334,6 +345,8 @@ async def stream_async(
334345 - multi_agent_node_stop: When a node stops execution
335346 - result: Final swarm result
336347 """
348+ self ._interrupt_state .resume (task )
349+
337350 if invocation_state is None :
338351 invocation_state = {}
339352
@@ -644,6 +657,36 @@ def _build_node_input(self, target_node: SwarmNode) -> str:
644657
645658 return context_text
646659
660+ def _activate_interrupt (self , node : SwarmNode , interrupts : list [Interrupt ]) -> Any :
661+ """Activate the interrupt state.
662+
663+ Args:
664+ node: The interrupted node.
665+ interrupts: The interrupts raised by the user.
666+
667+ Returns:
668+ MultiAgentNodeInterruptEvent
669+ """
670+
671+ logger .debug ("node=<%s> | node interrupted" , node .node_id )
672+ self .state .completion_status = Status .INTERRUPTED
673+
674+ self ._interrupt_state .context .update (
675+ {
676+ node .node_id : {
677+ "interrupt_state" : node .executor ._interrupt_state .to_dict (),
678+ "state" : node .executor .state .get (),
679+ "messages" : node .executor .messages
680+ }
681+ }
682+ )
683+ self ._interrupt_state .activate ()
684+
685+ # return MultiAgentNodeInterruptEvent(
686+ # node_id=node.node_id,
687+ # interrupts=interrupts,
688+ # )
689+
647690 async def _execute_swarm (self , invocation_state : dict [str , Any ]) -> AsyncIterator [Any ]:
648691 """Execute swarm and yield TypedEvent objects."""
649692 try :
@@ -680,9 +723,13 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
680723
681724 # TODO: Implement cancellation token to stop _execute_node from continuing
682725 try :
683- await self .hooks .invoke_callbacks_async (
684- BeforeNodeCallEvent (self , current_node .node_id , invocation_state )
726+ _ , interrupts = await self .hooks .invoke_callbacks_async (
727+ BeforeNodeCallEvent (self , current_node .node_id , invocation_state )
685728 )
729+ if interrupts :
730+ yield self ._activate_interrupt (current_node , interrupts )
731+ break
732+
686733 node_stream = self ._stream_with_timeout (
687734 self ._execute_node (current_node , self .state .task , invocation_state ),
688735 self .node_timeout ,
@@ -691,6 +738,14 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
691738 async for event in node_stream :
692739 yield event
693740
741+ stop_event = cast (MultiAgentNodeStopEvent , event )
742+ node_result = stop_event ["node_result" ]
743+ if node_result .status == Status .INTERRUPTED :
744+ self ._interrupt_state .interrupts .update ({interrupt .id : interrupt for interrupt in node_result .interrupts })
745+ # yield self._activate_interrupt(current_node, node_result.interrupts)
746+ self ._activate_interrupt (current_node , node_result .interrupts )
747+ break
748+
694749 self .state .node_history .append (current_node )
695750 await self .hooks .invoke_callbacks_async (
696751 AfterNodeCallEvent (self , current_node .node_id , invocation_state )
@@ -741,7 +796,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
741796 )
742797
743798 async def _execute_node (
744- self , node : SwarmNode , task : str | list [ContentBlock ], invocation_state : dict [str , Any ]
799+ self ,
800+ node : SwarmNode ,
801+ task : str | list [ContentBlock ] | list [InterruptResponseContent ],
802+ invocation_state : dict [str , Any ],
745803 ) -> AsyncIterator [Any ]:
746804 """Execute swarm node and yield TypedEvent objects."""
747805 start_time = time .time ()
@@ -763,8 +821,16 @@ async def _execute_node(
763821 # Include additional ContentBlocks in node input
764822 node_input = node_input + task
765823
824+ if self ._interrupt_state .activated :
825+ node_input = task
826+
766827 # Execute node with streaming
767828 node .reset_executor_state ()
829+ if self ._interrupt_state .activated :
830+ context = self ._interrupt_state .context [node .node_id ]
831+ node .executor .messages = context ["messages" ]
832+ node .executor .state = AgentState (context ["state" ])
833+ node .executor ._interrupt_state = _InterruptState .from_dict (context ["interrupt_state" ])
768834
769835 # Stream agent events with node context and capture final result
770836 result = None
@@ -779,13 +845,8 @@ async def _execute_node(
779845 if result is None :
780846 raise ValueError (f"Node '{ node_name } ' did not produce a result event" )
781847
782- if result .stop_reason == "interrupt" :
783- node .executor .messages .pop () # remove interrupted tool use message
784- node .executor ._interrupt_state .deactivate ()
785-
786- raise RuntimeError ("user raised interrupt from agent | interrupts are not yet supported in swarms" )
787-
788848 execution_time = round ((time .time () - start_time ) * 1000 )
849+ status = Status .INTERRUPTED if result .stop_reason == "interrupt" else Status .COMPLETED
789850
790851 # Create NodeResult with extracted metrics
791852 result_metrics = getattr (result , "metrics" , None )
@@ -795,10 +856,11 @@ async def _execute_node(
795856 node_result = NodeResult (
796857 result = result ,
797858 execution_time = execution_time ,
798- status = Status . COMPLETED ,
859+ status = status ,
799860 accumulated_usage = usage ,
800861 accumulated_metrics = metrics ,
801862 execution_count = 1 ,
863+ interrupts = result .interrupts ,
802864 )
803865
804866 # Store result in state
@@ -849,6 +911,15 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
849911
850912 def _build_result (self ) -> SwarmResult :
851913 """Build swarm result from current state."""
914+ # Get interrupts from current node (latest iteration only)
915+ interrupts = []
916+ if (self .state .completion_status == Status .INTERRUPTED and
917+ self .state .current_node and
918+ self .state .current_node .node_id in self .state .results ):
919+
920+ node_result = self .state .results [self .state .current_node .node_id ]
921+ interrupts = node_result .interrupts
922+
852923 return SwarmResult (
853924 status = self .state .completion_status ,
854925 results = self .state .results ,
@@ -857,6 +928,7 @@ def _build_result(self) -> SwarmResult:
857928 execution_count = len (self .state .node_history ),
858929 execution_time = self .state .execution_time ,
859930 node_history = self .state .node_history ,
931+ interrupts = interrupts ,
860932 )
861933
862934 def serialize_state (self ) -> dict [str , Any ]:
@@ -881,6 +953,9 @@ def serialize_state(self) -> dict[str, Any]:
881953 "shared_context" : getattr (self .state .shared_context , "context" , {}) or {},
882954 "handoff_message" : self .state .handoff_message ,
883955 },
956+ "_internal_state" : {
957+ "interrupt_state" : self ._interrupt_state .to_dict (),
958+ },
884959 }
885960
886961 def deserialize_state (self , payload : dict [str , Any ]) -> None :
@@ -896,6 +971,9 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
896971 payload: Dictionary containing persisted state data including status,
897972 completed nodes, results, and next nodes to execute.
898973 """
974+ if "_internal_state" in payload :
975+ self ._interrupt_state = _InterruptState .from_dict (payload ["_internal_state" ]["interrupt_state" ])
976+
899977 if not payload .get ("next_nodes_to_execute" ):
900978 for node in self .nodes .values ():
901979 node .reset_executor_state ()
0 commit comments