Skip to content

Commit 48a006f

Browse files
committed
add A2AStreamEvent + handle partial responses in stream_async + modify constructor parameters
1 parent a346143 commit 48a006f

File tree

4 files changed

+229
-71
lines changed

4 files changed

+229
-71
lines changed

src/strands/agent/a2a_agent.py

Lines changed: 123 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
"""A2A Agent client for Strands Agents.
22
33
This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents,
4-
allowing them to be used in graphs, swarms, and other multi-agent patterns.
4+
allowing them to be used standalone or as part of multi-agent patterns.
5+
6+
A2AAgent can be used to get the Agent Card and interact with the agent.
57
"""
68

79
import logging
810
from typing import Any, AsyncIterator
911

1012
import httpx
11-
from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
12-
from a2a.types import AgentCard
13+
from a2a.client import A2ACardResolver, Client, ClientConfig, ClientFactory
14+
from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent
1315

1416
from .._async import run_async
1517
from ..multiagent.a2a.converters import convert_input_to_message, convert_response_to_agent_result
18+
from ..types._events import AgentResultEvent
19+
from ..types.a2a import A2AResponse, A2AStreamEvent
1620
from ..types.agent import AgentInput
1721
from .agent_result import AgentResult
1822

@@ -22,59 +26,46 @@
2226

2327

2428
class A2AAgent:
25-
"""Client wrapper for remote A2A agents.
26-
27-
Implements the AgentBase protocol to enable remote A2A agents to be used
28-
in graphs, swarms, and other multi-agent patterns.
29-
"""
29+
"""Client wrapper for remote A2A agents."""
3030

3131
def __init__(
3232
self,
3333
endpoint: str,
34+
*,
35+
name: str | None = None,
36+
description: str = "",
3437
timeout: int = DEFAULT_TIMEOUT,
35-
httpx_client_args: dict[str, Any] | None = None,
38+
a2a_client_factory: ClientFactory | None = None,
3639
):
37-
"""Initialize A2A agent client.
40+
"""Initialize A2A agent.
3841
3942
Args:
40-
endpoint: The base URL of the remote A2A agent
41-
timeout: Timeout for HTTP operations in seconds (defaults to 300)
42-
httpx_client_args: Optional dictionary of arguments to pass to httpx.AsyncClient
43-
constructor. Allows custom auth, headers, proxies, etc.
44-
Example: {"headers": {"Authorization": "Bearer token"}}
43+
endpoint: The base URL of the remote A2A agent.
44+
name: Agent name. If not provided, will be populated from agent card.
45+
description: Agent description. If empty, will be populated from agent card.
46+
timeout: Timeout for HTTP operations in seconds (defaults to 300).
47+
a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided,
48+
it will be used to create the A2A client after discovering the agent card.
4549
"""
4650
self.endpoint = endpoint
51+
self.name = name
52+
self.description = description
4753
self.timeout = timeout
48-
self._httpx_client_args: dict[str, Any] = httpx_client_args or {}
49-
50-
if "timeout" not in self._httpx_client_args:
51-
self._httpx_client_args["timeout"] = self.timeout
52-
54+
self._httpx_client: httpx.AsyncClient | None = None
55+
self._owns_client = a2a_client_factory is None
5356
self._agent_card: AgentCard | None = None
57+
self._a2a_client: Client | None = None
58+
self._a2a_client_factory: ClientFactory | None = a2a_client_factory
5459

5560
def _get_httpx_client(self) -> httpx.AsyncClient:
56-
"""Get a fresh httpx client for the current operation.
61+
"""Get or create the httpx client for this agent.
5762
5863
Returns:
5964
Configured httpx.AsyncClient instance.
6065
"""
61-
return httpx.AsyncClient(**self._httpx_client_args)
62-
63-
def _get_client_factory(self, streaming: bool = False) -> ClientFactory:
64-
"""Get a ClientFactory for the current operation.
65-
66-
Args:
67-
streaming: Whether to enable streaming mode.
68-
69-
Returns:
70-
Configured ClientFactory instance.
71-
"""
72-
httpx_client = self._get_httpx_client()
73-
config = ClientConfig(
74-
httpx_client=httpx_client,
75-
streaming=streaming,
76-
)
77-
return ClientFactory(config)
66+
if self._httpx_client is None:
67+
self._httpx_client = httpx.AsyncClient(timeout=self.timeout)
68+
return self._httpx_client
7869

7970
async def _get_agent_card(self) -> AgentCard:
8071
"""Discover and cache the agent card from the remote endpoint.
@@ -88,15 +79,44 @@ async def _get_agent_card(self) -> AgentCard:
8879
httpx_client = self._get_httpx_client()
8980
resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.endpoint)
9081
self._agent_card = await resolver.get_agent_card()
91-
logger.info("endpoint=<%s> | discovered agent card", self.endpoint)
82+
83+
# Populate name from card if not set
84+
if self.name is None and self._agent_card.name:
85+
self.name = self._agent_card.name
86+
87+
# Populate description from card if not set
88+
if not self.description and self._agent_card.description:
89+
self.description = self._agent_card.description
90+
91+
logger.info("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint)
9292
return self._agent_card
9393

94-
async def _send_message(self, prompt: AgentInput, streaming: bool) -> AsyncIterator[Any]:
94+
async def _get_a2a_client(self) -> Client:
95+
"""Get or create the A2A client for this agent.
96+
97+
Returns:
98+
Configured A2A client instance.
99+
"""
100+
if self._a2a_client is None:
101+
agent_card = await self._get_agent_card()
102+
103+
if self._a2a_client_factory is not None:
104+
# Use provided factory
105+
factory = self._a2a_client_factory
106+
else:
107+
# Create default factory
108+
httpx_client = self._get_httpx_client()
109+
config = ClientConfig(httpx_client=httpx_client, streaming=False)
110+
factory = ClientFactory(config)
111+
112+
self._a2a_client = factory.create(agent_card)
113+
return self._a2a_client
114+
115+
async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]:
95116
"""Send message to A2A agent.
96117
97118
Args:
98119
prompt: Input to send to the agent.
99-
streaming: Whether to use streaming mode.
100120
101121
Returns:
102122
Async iterator of A2A events.
@@ -107,13 +127,46 @@ async def _send_message(self, prompt: AgentInput, streaming: bool) -> AsyncItera
107127
if prompt is None:
108128
raise ValueError("prompt is required for A2AAgent")
109129

110-
agent_card = await self._get_agent_card()
111-
client = self._get_client_factory(streaming=streaming).create(agent_card)
130+
client = await self._get_a2a_client()
112131
message = convert_input_to_message(prompt)
113132

114-
logger.info("endpoint=<%s> | %s message", self.endpoint, "streaming" if streaming else "sending")
133+
logger.info("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint)
115134
return client.send_message(message)
116135

136+
def _is_complete_event(self, event: A2AResponse) -> bool:
137+
"""Check if an A2A event represents a complete response.
138+
139+
Args:
140+
event: A2A event.
141+
142+
Returns:
143+
True if the event represents a complete response.
144+
"""
145+
# Direct Message is always complete
146+
if isinstance(event, Message):
147+
return True
148+
149+
# Handle tuple responses (Task, UpdateEvent | None)
150+
if isinstance(event, tuple) and len(event) == 2:
151+
task, update_event = event
152+
153+
# Initial task response (no update event)
154+
if update_event is None:
155+
return True
156+
157+
# Artifact update with last_chunk flag
158+
if isinstance(update_event, TaskArtifactUpdateEvent):
159+
if hasattr(update_event, "last_chunk") and update_event.last_chunk is not None:
160+
return update_event.last_chunk
161+
return False
162+
163+
# Status update with completed state
164+
if isinstance(update_event, TaskStatusUpdateEvent):
165+
if update_event.status and hasattr(update_event.status, "state"):
166+
return update_event.status.state == TaskState.completed
167+
168+
return False
169+
117170
async def invoke_async(
118171
self,
119172
prompt: AgentInput = None,
@@ -132,7 +185,7 @@ async def invoke_async(
132185
ValueError: If prompt is None.
133186
RuntimeError: If no response received from agent.
134187
"""
135-
async for event in await self._send_message(prompt, streaming=False):
188+
async for event in await self._send_message(prompt):
136189
return convert_response_to_agent_result(event)
137190

138191
raise RuntimeError("No response received from A2A agent")
@@ -169,10 +222,32 @@ async def stream_async(
169222
**kwargs: Additional arguments (ignored).
170223
171224
Yields:
172-
A2A events wrapped in dictionaries with an 'a2a_event' key.
225+
A2A events and a final AgentResult event.
173226
174227
Raises:
175228
ValueError: If prompt is None.
176229
"""
177-
async for event in await self._send_message(prompt, streaming=True):
178-
yield {"a2a_event": event}
230+
last_event = None
231+
last_complete_event = None
232+
233+
async for event in await self._send_message(prompt):
234+
last_event = event
235+
if self._is_complete_event(event):
236+
last_complete_event = event
237+
yield A2AStreamEvent(event)
238+
239+
# Use the last complete event if available, otherwise fall back to last event
240+
final_event = last_complete_event if last_complete_event is not None else last_event
241+
242+
if final_event is not None:
243+
result = convert_response_to_agent_result(final_event)
244+
yield AgentResultEvent(result)
245+
246+
def __del__(self) -> None:
247+
"""Clean up resources when agent is garbage collected."""
248+
if self._owns_client and self._httpx_client is not None:
249+
try:
250+
client = self._httpx_client
251+
run_async(lambda: client.aclose())
252+
except Exception:
253+
pass # Best effort cleanup, ignore errors in __del__

src/strands/multiagent/a2a/converters.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
"""Conversion functions between Strands and A2A types."""
22

3-
from typing import TypeAlias, cast
3+
from typing import cast
44
from uuid import uuid4
55

66
from a2a.types import Message as A2AMessage
7-
from a2a.types import Part, Role, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart
7+
from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart
88

99
from ...agent.agent_result import AgentResult
1010
from ...telemetry.metrics import EventLoopMetrics
11+
from ...types.a2a import A2AResponse
1112
from ...types.agent import AgentInput
1213
from ...types.content import ContentBlock, Message
1314

14-
A2AResponse: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | A2AMessage
15-
1615

1716
def convert_input_to_message(prompt: AgentInput) -> A2AMessage:
1817
"""Convert AgentInput to A2A Message.
@@ -89,7 +88,21 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult:
8988

9089
if isinstance(response, tuple) and len(response) == 2:
9190
task, update_event = response
92-
if update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None:
91+
92+
# Handle artifact updates
93+
if isinstance(update_event, TaskArtifactUpdateEvent):
94+
if update_event.artifact and hasattr(update_event.artifact, "parts"):
95+
for part in update_event.artifact.parts:
96+
if hasattr(part, "root") and hasattr(part.root, "text"):
97+
content.append({"text": part.root.text})
98+
# Handle status updates with messages
99+
elif isinstance(update_event, TaskStatusUpdateEvent):
100+
if update_event.status and hasattr(update_event.status, "message") and update_event.status.message:
101+
for part in update_event.status.message.parts:
102+
if hasattr(part, "root") and hasattr(part.root, "text"):
103+
content.append({"text": part.root.text})
104+
# Handle initial task or task without update event
105+
elif update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None:
93106
for artifact in task.artifacts:
94107
if hasattr(artifact, "parts"):
95108
for part in artifact.parts:

src/strands/types/a2a.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Additional A2A types."""
2+
3+
from typing import TypeAlias
4+
5+
from a2a.types import Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent
6+
7+
from ._events import TypedEvent
8+
9+
A2AResponse: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
10+
11+
12+
class A2AStreamEvent(TypedEvent):
13+
"""Event that wraps streamed A2A types."""
14+
15+
def __init__(self, a2a_event: A2AResponse) -> None:
16+
"""Initialize with A2A event.
17+
18+
Args:
19+
a2a_event: The original A2A event (Task tuple or Message)
20+
"""
21+
super().__init__(
22+
{
23+
"type": "a2a_stream",
24+
"event": a2a_event, # Nest A2A event to avoid field conflicts
25+
}
26+
)

0 commit comments

Comments
 (0)