Skip to content

Commit 8a89d91

Browse files
feat: add multiagent hooks, add serialize & deserialize function to multiagent base & agent result (#1070)
* feat: add multiagent hooks, add serialize & deserialize function to multiagent base & agent result * Delete __init__.py
1 parent 78c59b9 commit 8a89d91

File tree

9 files changed

+517
-1
lines changed

9 files changed

+517
-1
lines changed

src/strands/agent/agent_result.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from dataclasses import dataclass
7-
from typing import Any, Sequence
7+
from typing import Any, Sequence, cast
88

99
from ..interrupt import Interrupt
1010
from ..telemetry.metrics import EventLoopMetrics
@@ -46,3 +46,34 @@ def __str__(self) -> str:
4646
if isinstance(item, dict) and "text" in item:
4747
result += item.get("text", "") + "\n"
4848
return result
49+
50+
@classmethod
51+
def from_dict(cls, data: dict[str, Any]) -> "AgentResult":
52+
"""Rehydrate an AgentResult from persisted JSON.
53+
54+
Args:
55+
data: Dictionary containing the serialized AgentResult data
56+
Returns:
57+
AgentResult instance
58+
Raises:
59+
TypeError: If the data format is invalid@
60+
"""
61+
if data.get("type") != "agent_result":
62+
raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}")
63+
64+
message = cast(Message, data.get("message"))
65+
stop_reason = cast(StopReason, data.get("stop_reason"))
66+
67+
return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={})
68+
69+
def to_dict(self) -> dict[str, Any]:
70+
"""Convert this AgentResult to JSON-serializable dictionary.
71+
72+
Returns:
73+
Dictionary containing serialized AgentResult data
74+
"""
75+
return {
76+
"type": "agent_result",
77+
"message": self.message,
78+
"stop_reason": self.stop_reason,
79+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Multi-agent hook events and utilities.
2+
3+
Provides event classes for hooking into multi-agent orchestrator lifecycle.
4+
"""
5+
6+
from .events import (
7+
AfterMultiAgentInvocationEvent,
8+
AfterNodeCallEvent,
9+
BeforeMultiAgentInvocationEvent,
10+
BeforeNodeCallEvent,
11+
MultiAgentInitializedEvent,
12+
)
13+
14+
__all__ = [
15+
"AfterMultiAgentInvocationEvent",
16+
"AfterNodeCallEvent",
17+
"BeforeMultiAgentInvocationEvent",
18+
"BeforeNodeCallEvent",
19+
"MultiAgentInitializedEvent",
20+
]
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Multi-agent execution lifecycle events for hook system integration.
2+
3+
These events are fired by orchestrators (Graph/Swarm) at key points so
4+
hooks can persist, monitor, or debug execution. No intermediate state model
5+
is used—hooks read from the orchestrator directly.
6+
"""
7+
8+
from dataclasses import dataclass
9+
from typing import TYPE_CHECKING, Any
10+
11+
from ....hooks import BaseHookEvent
12+
13+
if TYPE_CHECKING:
14+
from ....multiagent.base import MultiAgentBase
15+
16+
17+
@dataclass
18+
class MultiAgentInitializedEvent(BaseHookEvent):
19+
"""Event triggered when multi-agent orchestrator initialized.
20+
21+
Attributes:
22+
source: The multi-agent orchestrator instance
23+
invocation_state: Configuration that user passes in
24+
"""
25+
26+
source: "MultiAgentBase"
27+
invocation_state: dict[str, Any] | None = None
28+
29+
30+
@dataclass
31+
class BeforeNodeCallEvent(BaseHookEvent):
32+
"""Event triggered before individual node execution starts.
33+
34+
Attributes:
35+
source: The multi-agent orchestrator instance
36+
node_id: ID of the node about to execute
37+
invocation_state: Configuration that user passes in
38+
"""
39+
40+
source: "MultiAgentBase"
41+
node_id: str
42+
invocation_state: dict[str, Any] | None = None
43+
44+
45+
@dataclass
46+
class AfterNodeCallEvent(BaseHookEvent):
47+
"""Event triggered after individual node execution completes.
48+
49+
Attributes:
50+
source: The multi-agent orchestrator instance
51+
node_id: ID of the node that just completed execution
52+
invocation_state: Configuration that user passes in
53+
"""
54+
55+
source: "MultiAgentBase"
56+
node_id: str
57+
invocation_state: dict[str, Any] | None = None
58+
59+
@property
60+
def should_reverse_callbacks(self) -> bool:
61+
"""True to invoke callbacks in reverse order."""
62+
return True
63+
64+
65+
@dataclass
66+
class BeforeMultiAgentInvocationEvent(BaseHookEvent):
67+
"""Event triggered before orchestrator execution starts.
68+
69+
Attributes:
70+
source: The multi-agent orchestrator instance
71+
invocation_state: Configuration that user passes in
72+
"""
73+
74+
source: "MultiAgentBase"
75+
invocation_state: dict[str, Any] | None = None
76+
77+
78+
@dataclass
79+
class AfterMultiAgentInvocationEvent(BaseHookEvent):
80+
"""Event triggered after orchestrator execution completes.
81+
82+
Attributes:
83+
source: The multi-agent orchestrator instance
84+
invocation_state: Configuration that user passes in
85+
"""
86+
87+
source: "MultiAgentBase"
88+
invocation_state: dict[str, Any] | None = None
89+
90+
@property
91+
def should_reverse_callbacks(self) -> bool:
92+
"""True to invoke callbacks in reverse order."""
93+
return True

src/strands/multiagent/base.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import asyncio
7+
import logging
78
import warnings
89
from abc import ABC, abstractmethod
910
from concurrent.futures import ThreadPoolExecutor
@@ -15,6 +16,8 @@
1516
from ..types.content import ContentBlock
1617
from ..types.event_loop import Metrics, Usage
1718

19+
logger = logging.getLogger(__name__)
20+
1821

1922
class Status(Enum):
2023
"""Execution status for both graphs and nodes."""
@@ -59,6 +62,54 @@ def get_agent_results(self) -> list[AgentResult]:
5962
flattened.extend(nested_node_result.get_agent_results())
6063
return flattened
6164

65+
def to_dict(self) -> dict[str, Any]:
66+
"""Convert NodeResult to JSON-serializable dict, ignoring state field."""
67+
if isinstance(self.result, Exception):
68+
result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)}
69+
elif isinstance(self.result, AgentResult):
70+
result_data = self.result.to_dict()
71+
else:
72+
# MultiAgentResult case
73+
result_data = self.result.to_dict()
74+
75+
return {
76+
"result": result_data,
77+
"execution_time": self.execution_time,
78+
"status": self.status.value,
79+
"accumulated_usage": self.accumulated_usage,
80+
"accumulated_metrics": self.accumulated_metrics,
81+
"execution_count": self.execution_count,
82+
}
83+
84+
@classmethod
85+
def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
86+
"""Rehydrate a NodeResult from persisted JSON."""
87+
if "result" not in data:
88+
raise TypeError("NodeResult.from_dict: missing 'result'")
89+
raw = data["result"]
90+
91+
result: Union[AgentResult, "MultiAgentResult", Exception]
92+
if isinstance(raw, dict) and raw.get("type") == "agent_result":
93+
result = AgentResult.from_dict(raw)
94+
elif isinstance(raw, dict) and raw.get("type") == "exception":
95+
result = Exception(str(raw.get("message", "node failed")))
96+
elif isinstance(raw, dict) and raw.get("type") == "multiagent_result":
97+
result = MultiAgentResult.from_dict(raw)
98+
else:
99+
raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}")
100+
101+
usage = _parse_usage(data.get("accumulated_usage", {}))
102+
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
103+
104+
return cls(
105+
result=result,
106+
execution_time=int(data.get("execution_time", 0)),
107+
status=Status(data.get("status", "pending")),
108+
accumulated_usage=usage,
109+
accumulated_metrics=metrics,
110+
execution_count=int(data.get("execution_count", 0)),
111+
)
112+
62113

63114
@dataclass
64115
class MultiAgentResult:
@@ -76,6 +127,38 @@ class MultiAgentResult:
76127
execution_count: int = 0
77128
execution_time: int = 0
78129

130+
@classmethod
131+
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
132+
"""Rehydrate a MultiAgentResult from persisted JSON."""
133+
if data.get("type") != "multiagent_result":
134+
raise TypeError(f"MultiAgentResult.from_dict: unexpected type {data.get('type')!r}")
135+
136+
results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()}
137+
usage = _parse_usage(data.get("accumulated_usage", {}))
138+
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
139+
140+
multiagent_result = cls(
141+
status=Status(data.get("status", Status.PENDING.value)),
142+
results=results,
143+
accumulated_usage=usage,
144+
accumulated_metrics=metrics,
145+
execution_count=int(data.get("execution_count", 0)),
146+
execution_time=int(data.get("execution_time", 0)),
147+
)
148+
return multiagent_result
149+
150+
def to_dict(self) -> dict[str, Any]:
151+
"""Convert MultiAgentResult to JSON-serializable dict."""
152+
return {
153+
"type": "multiagent_result",
154+
"status": self.status.value,
155+
"results": {k: v.to_dict() for k, v in self.results.items()},
156+
"accumulated_usage": self.accumulated_usage,
157+
"accumulated_metrics": self.accumulated_metrics,
158+
"execution_count": self.execution_count,
159+
"execution_time": self.execution_time,
160+
}
161+
79162

80163
class MultiAgentBase(ABC):
81164
"""Base class for multi-agent helpers.
@@ -122,3 +205,34 @@ def execute() -> MultiAgentResult:
122205
with ThreadPoolExecutor() as executor:
123206
future = executor.submit(execute)
124207
return future.result()
208+
209+
def serialize_state(self) -> dict[str, Any]:
210+
"""Return a JSON-serializable snapshot of the orchestrator state."""
211+
raise NotImplementedError
212+
213+
def deserialize_state(self, payload: dict[str, Any]) -> None:
214+
"""Restore orchestrator state from a session dict."""
215+
raise NotImplementedError
216+
217+
218+
# Private helper function to avoid duplicate code
219+
220+
221+
def _parse_usage(usage_data: dict[str, Any]) -> Usage:
222+
"""Parse Usage from dict data."""
223+
usage = Usage(
224+
inputTokens=usage_data.get("inputTokens", 0),
225+
outputTokens=usage_data.get("outputTokens", 0),
226+
totalTokens=usage_data.get("totalTokens", 0),
227+
)
228+
# Add optional fields if they exist
229+
if "cacheReadInputTokens" in usage_data:
230+
usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"]
231+
if "cacheWriteInputTokens" in usage_data:
232+
usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"]
233+
return usage
234+
235+
236+
def _parse_metrics(metrics_data: dict[str, Any]) -> Metrics:
237+
"""Parse Metrics from dict data."""
238+
return Metrics(latencyMs=metrics_data.get("latencyMs", 0))
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from typing import Iterator, Literal, Tuple, Type
2+
3+
from strands.experimental.hooks.multiagent.events import (
4+
AfterMultiAgentInvocationEvent,
5+
AfterNodeCallEvent,
6+
BeforeNodeCallEvent,
7+
MultiAgentInitializedEvent,
8+
)
9+
from strands.hooks import (
10+
HookEvent,
11+
HookProvider,
12+
HookRegistry,
13+
)
14+
15+
16+
class MockMultiAgentHookProvider(HookProvider):
17+
def __init__(self, event_types: list[Type] | Literal["all"]):
18+
if event_types == "all":
19+
event_types = [
20+
MultiAgentInitializedEvent,
21+
BeforeNodeCallEvent,
22+
AfterNodeCallEvent,
23+
AfterMultiAgentInvocationEvent,
24+
]
25+
26+
self.events_received = []
27+
self.events_types = event_types
28+
29+
@property
30+
def event_types_received(self):
31+
return [type(event) for event in self.events_received]
32+
33+
def get_events(self) -> Tuple[int, Iterator[HookEvent]]:
34+
return len(self.events_received), iter(self.events_received)
35+
36+
def register_hooks(self, registry: HookRegistry) -> None:
37+
for event_type in self.events_types:
38+
registry.add_callback(event_type, self.add_event)
39+
40+
def add_event(self, event: HookEvent) -> None:
41+
self.events_received.append(event)

tests/strands/agent/test_agent_result.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,48 @@ def test__str__non_dict_content(mock_metrics):
9595

9696
message_string = str(result)
9797
assert message_string == "Valid text\nMore valid text\n"
98+
99+
100+
def test_to_dict(mock_metrics, simple_message: Message):
101+
"""Test that to_dict serializes AgentResult correctly."""
102+
result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={"key": "value"})
103+
104+
data = result.to_dict()
105+
106+
assert data == {
107+
"type": "agent_result",
108+
"message": simple_message,
109+
"stop_reason": "end_turn",
110+
}
111+
112+
113+
def test_from_dict():
114+
"""Test that from_dict works with valid data."""
115+
data = {
116+
"type": "agent_result",
117+
"message": {"role": "assistant", "content": [{"text": "Test response"}]},
118+
"stop_reason": "end_turn",
119+
}
120+
121+
result = AgentResult.from_dict(data)
122+
123+
assert result.message == data["message"]
124+
assert result.stop_reason == data["stop_reason"]
125+
assert isinstance(result.metrics, EventLoopMetrics)
126+
assert result.state == {}
127+
128+
129+
def test_roundtrip_serialization(mock_metrics, complex_message: Message):
130+
"""Test that to_dict() and from_dict() work together correctly."""
131+
original = AgentResult(
132+
stop_reason="max_tokens", message=complex_message, metrics=mock_metrics, state={"test": "data"}
133+
)
134+
135+
# Serialize and deserialize
136+
data = original.to_dict()
137+
restored = AgentResult.from_dict(data)
138+
139+
assert restored.message == original.message
140+
assert restored.stop_reason == original.stop_reason
141+
assert isinstance(restored.metrics, EventLoopMetrics)
142+
assert restored.state == {} # State is not serialized

tests/strands/experimental/hooks/multiagent/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)