1111"""
1212
1313import threading
14- from typing import Optional
1514
1615from ...hooks .registry import HookProvider , HookRegistry
1716from ...multiagent .base import MultiAgentBase
1817from ...session import SessionManager
1918from .multiagent_events import (
20- AfterGraphInvocationEvent ,
19+ AfterMultiAgentInvocationEvent ,
2120 AfterNodeInvocationEvent ,
22- BeforeGraphInvocationEvent ,
21+ BeforeMultiAgentInvocationEvent ,
2322 BeforeNodeInvocationEvent ,
2423 MultiAgentInitializationEvent ,
25- MultiAgentState ,
2624)
27- from .multiagent_state_adapter import MultiAgentAdapter
28-
29-
30- def _get_multiagent_state (
31- multiagent_state : Optional [MultiAgentState ],
32- orchestrator : MultiAgentBase ,
33- ) -> MultiAgentState :
34- if multiagent_state is not None :
35- return multiagent_state
36-
37- return MultiAgentAdapter .create_multi_agent_state (orchestrator = orchestrator )
3825
3926
4027class MultiAgentHook (HookProvider ):
@@ -67,17 +54,18 @@ def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None:
6754 **kwargs: Additional keyword arguments (unused)
6855 """
6956 registry .add_callback (MultiAgentInitializationEvent , self ._on_initialization )
70- registry .add_callback (BeforeGraphInvocationEvent , self ._on_before_graph )
57+ registry .add_callback (BeforeMultiAgentInvocationEvent , self ._on_initialization )
7158 registry .add_callback (BeforeNodeInvocationEvent , self ._on_before_node )
7259 registry .add_callback (AfterNodeInvocationEvent , self ._on_after_node )
73- registry .add_callback (AfterGraphInvocationEvent , self ._on_after_graph )
60+ registry .add_callback (AfterMultiAgentInvocationEvent , self ._on_after_Execution )
7461
7562 def _on_initialization (self , event : MultiAgentInitializationEvent ):
7663 """Persist state when multi-agent orchestrator initializes."""
77- self ._persist (_get_multiagent_state (event .state , event .orchestrator ))
64+ self ._persist (event .orchestrator )
65+ pass
7866
79- def _on_before_graph (self , event : BeforeGraphInvocationEvent ):
80- """Hook called before graph execution starts ."""
67+ def _on_before_Invocation (self , event : MultiAgentInitializationEvent ):
68+ """Persist state when multi-agent orchestrator initializes ."""
8169 pass
8270
8371 def _on_before_node (self , event : BeforeNodeInvocationEvent ):
@@ -86,21 +74,20 @@ def _on_before_node(self, event: BeforeNodeInvocationEvent):
8674
8775 def _on_after_node (self , event : AfterNodeInvocationEvent ):
8876 """Persist state after each node completes execution."""
89- multi_agent_state = _get_multiagent_state (multiagent_state = event .state , orchestrator = event .orchestrator )
90- self ._persist (multi_agent_state )
77+ self ._persist (event .orchestrator )
9178
92- def _on_after_graph (self , event : AfterGraphInvocationEvent ):
79+ def _on_after_Execution (self , event : AfterMultiAgentInvocationEvent ):
9380 """Persist final state after graph execution completes."""
94- multiagent_state = _get_multiagent_state (multiagent_state = event .state , orchestrator = event .orchestrator )
95- self ._persist (multiagent_state )
81+ self ._persist (event .orchestrator )
9682
97- def _persist (self , multiagent_state : MultiAgentState ) -> None :
83+ def _persist (self , orchestrator : MultiAgentBase ) -> None :
9884 """Persist the provided MultiAgentState using the configured SessionManager.
9985
10086 This method is synchronized across threads/tasks to avoid write races.
10187
10288 Args:
103- multiagent_state : State to persist
89+ orchestrator : State to persist
10490 """
91+ current_state = orchestrator .get_state_from_orchestrator ()
10592 with self ._lock :
106- self ._session_manager .write_multi_agent_state (multiagent_state )
93+ self ._session_manager .write_multi_agent_state (current_state )
0 commit comments