From ed1768dfc9c985a53d77a0ecfaf9e9b01c93b027 Mon Sep 17 00:00:00 2001 From: Ryan Coleman Date: Tue, 3 Jun 2025 16:18:36 -0700 Subject: [PATCH] Prototype: A2A support through a protocols abstraction --- examples/iss_agent.py | 52 ++++ examples/iss_client.py | 72 +++++ pyproject.toml | 10 +- src/strands/__init__.py | 4 +- src/strands/agent/agent.py | 220 +++++++++++++- src/strands/protocols/__init__.py | 38 +++ src/strands/protocols/a2a/__init__.py | 20 ++ src/strands/protocols/a2a/client.py | 223 ++++++++++++++ src/strands/protocols/a2a/server.py | 381 ++++++++++++++++++++++++ src/strands/protocols/a2a/tools.py | 352 ++++++++++++++++++++++ src/strands/types/protocols/__init__.py | 5 + src/strands/types/protocols/protocol.py | 88 ++++++ 12 files changed, 1459 insertions(+), 6 deletions(-) create mode 100644 examples/iss_agent.py create mode 100644 examples/iss_client.py create mode 100644 src/strands/protocols/__init__.py create mode 100644 src/strands/protocols/a2a/__init__.py create mode 100644 src/strands/protocols/a2a/client.py create mode 100644 src/strands/protocols/a2a/server.py create mode 100644 src/strands/protocols/a2a/tools.py create mode 100644 src/strands/types/protocols/__init__.py create mode 100644 src/strands/types/protocols/protocol.py diff --git a/examples/iss_agent.py b/examples/iss_agent.py new file mode 100644 index 000000000..e72a5be80 --- /dev/null +++ b/examples/iss_agent.py @@ -0,0 +1,52 @@ +"""ISS Location Agent Server + +This agent can answer questions about the International Space Station's location +and calculate distances to various cities. + +Run with: uv run examples/iss.py +Then test with: uv run examples/iss_client.py +""" + +from strands import Agent +from strands_tools import http_request, python_repl +from strands.protocols import A2AProtocolServer + +# Create the ISS agent with tools for web requests and calculations +agent = Agent( + tools=[http_request, python_repl], + system_prompt="You are a helpful assistant that can answer questions about the International Space Station's location and calculate distances to various cities.", + name="ISS Location Agent", + description="An intelligent agent that tracks the International Space Station's real-time position and calculates distances to cities worldwide. Provides accurate geospatial analysis and space-related information.", + # Uncomment to use a specific model: + # model="us.amazon.nova-premier-v1:0", + # model="us.anthropic.claude-sonnet-4-20250514-v1:0", +) + +# Configure the A2A server +server_config = A2AProtocolServer( + port=8000, + host="0.0.0.0", + version="1.2.3" +) + +print(f"Starting ISS Location Agent...") +print(f"Model: {agent.model.config}") + +# Serve the agent - it's now ready to handle requests! +server = agent.serve(server_config) + +print("\n" + "="*50) +print("ISS Agent is now running!") +print(f"- Agent card: http://localhost:8000/.well-known/agent.json") +print(f"- Send requests to: http://localhost:8000/") +print("- Test with: uv run examples/iss_client.py") +print("="*50) + +# Keep the server running +try: + import time + while True: + time.sleep(1) +except KeyboardInterrupt: + print("\n\nShutting down ISS agent...") + server.stop() diff --git a/examples/iss_client.py b/examples/iss_client.py new file mode 100644 index 000000000..837423662 --- /dev/null +++ b/examples/iss_client.py @@ -0,0 +1,72 @@ +"""Example client for the ISS agent. + +This shows how to interact with the ISS agent once it's running. +First run: uv run examples/iss_agent.py +Then run: uv run examples/iss_client.py +""" + +import time +from strands.protocols import A2AProtocolClient +import asyncio + +# The URL where your ISS agent is running +AGENT_URL = "http://localhost:8000" + +async def test_agent(): + """Test the ISS agent with better error handling.""" + + # Check if server is running first + print("Checking if agent server is running...") + + async with A2AProtocolClient(AGENT_URL) as client: + try: + # Try to fetch agent card with shorter timeout first + print("Fetching agent card...") + agent_card = await client.fetch_agent_card() + print(f"\nāœ… Connected to agent!") + print(f"Agent: {agent_card.name}") + print(f"Description: {agent_card.description}") + print(f"Available skills: {[skill.name for skill in agent_card.skills]}") + + except Exception as e: + print(f"āŒ Failed to connect to agent at {AGENT_URL}") + print(f"Error: {e}") + print("\nMake sure the agent server is running:") + print(" uv run examples/iss_agent.py") + return + + # Now send your ISS question with longer timeout + print("\n" + "="*50) + print("Sending ISS question to agent...") + print("This may take a while as the agent needs to:") + print("- Look up real-time ISS position") + print("- Calculate distances to multiple cities") + print("- Perform complex calculations") + print("="*50 + "\n") + + try: + # Use longer timeout for complex calculation + response = await client.send_task_and_wait( + message="Who is the closest to the ISS? People in: " + "Portland, Vancouver, Seattle, or New York? " + "First, lookup realtime information about the position of the ISS. " + "Give me the altitude of the ISS, and the distance and vector from the closest city to the ISS. " + "After you give me the answer, explain your reasoning and show me any code you used", + timeout=120.0 # 2 minutes for complex calculation + ) + + print("šŸš€ ISS Agent Response:") + print("="*50) + print(response) + + except TimeoutError as e: + print(f"ā±ļø Request timed out: {e}") + print("The agent may be taking longer than expected.") + print("Try again or increase the timeout.") + + except Exception as e: + print(f"āŒ Error during request: {e}") + +# Run the async client +if __name__ == "__main__": + asyncio.run(test_agent()) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index bd3097327..fe7e15caf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "strands-agents" dynamic = ["version"] description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" license = {text = "Apache-2.0"} authors = [ {name = "AWS", email = "opensource@amazon.com"}, @@ -59,7 +59,7 @@ dev = [ "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", "ruff>=0.4.4,<0.5.0", - "swagger-parser>=1.0.2,<2.0.0", + "swagger-parser>=1.0.1,<2.0.0", ] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -78,6 +78,12 @@ ollama = [ openai = [ "openai>=1.68.0,<2.0.0", ] +a2a = [ + "a2a-sdk>=0.2.5", + "httpx>=0.27.0", + "fastapi>=0.68.0", + "uvicorn>=0.15.0" +] [tool.hatch.version] # Tells Hatch to use your version control system (git) to determine the version. diff --git a/src/strands/__init__.py b/src/strands/__init__.py index f4b1228d2..4c4658e22 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -1,8 +1,8 @@ """A framework for building, deploying, and managing AI agents.""" -from . import agent, event_loop, models, telemetry, types +from . import agent, event_loop, models, telemetry, types, protocols from .agent.agent import Agent from .tools.decorator import tool from .tools.thread_pool_executor import ThreadPoolExecutorWrapper -__all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "event_loop", "models", "tool", "types", "telemetry"] +__all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "event_loop", "models", "tool", "types", "telemetry", "protocols"] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index bfa83fe20..a0936e188 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -16,7 +16,7 @@ import random from concurrent.futures import ThreadPoolExecutor from threading import Thread -from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Union, TYPE_CHECKING from uuid import uuid4 from opentelemetry import trace @@ -183,6 +183,8 @@ def __init__( record_direct_tool_call: bool = True, load_tools_from_directory: bool = True, trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + name: Optional[str] = None, + description: Optional[str] = None, ): """Initialize the Agent with the specified configuration. @@ -214,6 +216,10 @@ def __init__( load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. Defaults to True. trace_attributes: Custom trace attributes to apply to the agent's trace span. + name: Optional name for the agent. Used in agent cards and logging. + If None, defaults to the class name when generating agent cards. + description: Optional description of the agent's purpose and capabilities. + If None, defaults to a generic description when generating agent cards. Raises: ValueError: If max_parallel_tools is less than 1. @@ -223,6 +229,10 @@ def __init__( self.system_prompt = system_prompt self.callback_handler = callback_handler or null_callback_handler + + # Agent metadata for agent cards + self.name = name + self.description = description self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager() @@ -301,15 +311,194 @@ def tool_config(self) -> ToolConfig: """ return self.tool_registry.initialize_tool_config() + def register_tool(self, tool: Any) -> None: + """Register a tool with this agent. + + Args: + tool: The tool to register. Can be a function decorated with @strands.tool, + a Tool instance, or a dictionary with tool configuration. + + Example: + ```python + # Register a function tool + @strands.tool + def calculate(expression: str) -> str: + return str(eval(expression)) + + agent.register_tool(calculate) + + # Register a Tool instance + from strands.tools import WebSearchTool + agent.register_tool(WebSearchTool()) + ``` + """ + self.tool_registry.register_tool(tool) + + def register_remote_agent( + self, + agent_url: str, + name: Optional[str] = None, + skills: Optional[List[str]] = None, + auth_token: Optional[str] = None + ) -> None: + """Register a remote A2A agent as tool(s). + + This method enables the "Agent as Tool" pattern, allowing you to use + remote agents as if they were local tools. The remote agent's capabilities + are automatically discovered via the A2A protocol. + + Args: + agent_url: URL of the remote A2A agent (e.g., "http://agent:8000") + name: Optional custom name for the tool. If not provided, uses the + remote agent's name. + skills: Optional list of specific skills to register. If None, + registers the entire agent as one tool. If provided, + registers each skill as a separate tool. + auth_token: Optional authentication token for the remote agent. + + Examples: + ```python + # Register entire remote agent as one tool + agent.register_remote_agent("http://research-agent:8000") + + # Register with custom name + agent.register_remote_agent( + "http://research-agent:8000", + name="research" + ) + + # Register specific skills as separate tools + agent.register_remote_agent( + "http://multi-skill-agent:8000", + skills=["web_search", "summarize"] + ) + + # With authentication + agent.register_remote_agent( + "http://secure-agent:8000", + auth_token="secret-token" + ) + ``` + + Raises: + ConnectionError: If unable to connect to the remote agent + ValueError: If requested skills are not found in the remote agent + """ + from ..protocols.a2a.tools import A2ARemoteTool, create_agent_tools_from_skills + + if skills: + # Register each skill as a separate tool + tools = create_agent_tools_from_skills(agent_url, skills, auth_token) + for tool in tools: + self.register_tool(tool) + logger.info(f"Registered remote skill '{tool.skill_id}' from {agent_url}") + else: + # Register entire agent as one tool + tool = A2ARemoteTool(agent_url, name=name, auth_token=auth_token) + self.register_tool(tool) + logger.info(f"Registered remote agent from {agent_url} as '{tool.tool_name}'") + def __del__(self) -> None: """Clean up resources when Agent is garbage collected. - Ensures proper shutdown of the thread pool executor if one exists. + Ensures proper shutdown of the thread pool executor and any protocol servers. """ + # Stop any running protocol servers + self.stop_servers() + + # Shutdown thread pool if self.thread_pool_wrapper and hasattr(self.thread_pool_wrapper, "shutdown"): self.thread_pool_wrapper.shutdown(wait=False) logger.debug("thread pool executor shutdown complete") + def serve( + self, + protocol: Union[str, "ProtocolServer", None] = None, + **server_config: Any + ) -> "ProtocolServer": + """Make this agent network accessible via a protocol server. + + This method enables the one-liner pattern for exposing agents over the network, + supporting multiple protocols and configuration options. + + Args: + protocol: The protocol to use. Can be: + - None: Use default protocol (A2A) + - str: Protocol name ('a2a', 'mcp', 'graphql', etc.) + - ProtocolServer: A pre-configured server instance + **server_config: Configuration parameters passed to the protocol server + (e.g., port=8080, host='0.0.0.0', enable_auth=True) + + Returns: + The started protocol server instance + + Examples: + ```python + # Default A2A server + agent.serve() + + # Specify protocol by name + agent.serve(protocol='mcp') + + # Configure server parameters + agent.serve(port=8080, enable_auth=True, auth_token='secret') + + # Use pre-configured server + from strands.protocols import A2AProtocolServer + server = A2AProtocolServer(port=9000, tls_cert='cert.pem') + agent.serve(server) + ``` + """ + from ..types.protocols import ProtocolServer + from ..protocols import PROTOCOL_REGISTRY, A2AProtocolServer + + # Determine which server to use + if protocol is None: + # Default to A2A + server = A2AProtocolServer(**server_config) + elif isinstance(protocol, str): + # Look up protocol by name + protocol_lower = protocol.lower() + if protocol_lower not in PROTOCOL_REGISTRY: + available = ", ".join(PROTOCOL_REGISTRY.keys()) + raise ValueError( + f"Unknown protocol '{protocol}'. Available protocols: {available}" + ) + server_class = PROTOCOL_REGISTRY[protocol_lower] + server = server_class(**server_config) + elif isinstance(protocol, ProtocolServer): + # Use provided server instance + server = protocol + # Apply any additional config + if server_config: + server.update_config(**server_config) + else: + raise TypeError( + f"protocol must be None, str, or ProtocolServer instance, got {type(protocol)}" + ) + + # Start the server with this agent + server.start(self) + + # Store reference for cleanup + if not hasattr(self, '_protocol_servers'): + self._protocol_servers = [] + self._protocol_servers.append(server) + + logger.info( + f"Agent now accessible via {server.protocol_name} at {server.get_endpoint()}" + ) + + return server + + def stop_servers(self) -> None: + """Stop all protocol servers associated with this agent.""" + if hasattr(self, '_protocol_servers'): + for server in self._protocol_servers: + if server.is_running: + server.stop() + self._protocol_servers.clear() + def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -580,3 +769,30 @@ def _end_agent_trace_span( trace_attributes["error"] = error self.tracer.end_agent_span(**trace_attributes) + + async def handle_a2a_request(self, user_input: str) -> str: + """Handle an A2A request by processing it through the agent. + + Args: + user_input: The text input from the A2A request + + Returns: + The agent's response as a string + """ + try: + # Process the request through the agent's main workflow + result = self(user_input) + + # Extract the text response from the AgentResult + if hasattr(result, 'message') and result.message: + content = result.message.get('content', []) + for content_block in content: + if isinstance(content_block, dict) and 'text' in content_block: + return content_block['text'] + + # Fallback to string representation + return str(result) + + except Exception as e: + logger.error(f"Error handling A2A request: {e}") + return f"Error processing request: {str(e)}" diff --git a/src/strands/protocols/__init__.py b/src/strands/protocols/__init__.py new file mode 100644 index 000000000..6a527e34c --- /dev/null +++ b/src/strands/protocols/__init__.py @@ -0,0 +1,38 @@ +"""Protocol implementations for Strands agents. + +This module provides protocol server implementations for exposing agents +through various communication protocols. +""" + +from .a2a import ( + A2AProtocolServer, + A2AProtocolClient, + A2ARemoteTool, + create_agent_tool, + create_agent_tools_from_skills, + fetch_agent_card_sync, + send_a2a_request_sync +) + +# Protocol registry for dynamic server creation +PROTOCOL_REGISTRY = { + "a2a": A2AProtocolServer, + # Future protocols can be added here: + # "mcp": MCPProtocolServer, + # "graphql": GraphQLProtocolServer, + # "grpc": GRPCProtocolServer, +} + +__all__ = [ + # A2A Protocol + "A2AProtocolServer", + "A2AProtocolClient", + "A2ARemoteTool", + "create_agent_tool", + "create_agent_tools_from_skills", + "fetch_agent_card_sync", + "send_a2a_request_sync", + + # Registry + "PROTOCOL_REGISTRY" +] \ No newline at end of file diff --git a/src/strands/protocols/a2a/__init__.py b/src/strands/protocols/a2a/__init__.py new file mode 100644 index 000000000..6fefc925b --- /dev/null +++ b/src/strands/protocols/a2a/__init__.py @@ -0,0 +1,20 @@ +"""A2A (Agent-to-Agent) protocol implementation.""" + +from .server import A2AProtocolServer +from .client import A2AProtocolClient, fetch_agent_card_sync, send_a2a_request_sync +from .tools import A2ARemoteTool, create_agent_tool, create_agent_tools_from_skills + +__all__ = [ + # Server + "A2AProtocolServer", + + # Client + "A2AProtocolClient", + "fetch_agent_card_sync", + "send_a2a_request_sync", + + # Tools + "A2ARemoteTool", + "create_agent_tool", + "create_agent_tools_from_skills" +] \ No newline at end of file diff --git a/src/strands/protocols/a2a/client.py b/src/strands/protocols/a2a/client.py new file mode 100644 index 000000000..3d9fbaabe --- /dev/null +++ b/src/strands/protocols/a2a/client.py @@ -0,0 +1,223 @@ +"""A2A Protocol Client implementation. + +This module provides client functionality for interacting with remote A2A agents. +""" + +import asyncio +import json +import logging +import uuid +from typing import Any, Dict, Optional, List +from urllib.parse import urljoin + +import httpx +from a2a.types import AgentCard, Message + +logger = logging.getLogger(__name__) + + +class A2AProtocolClient: + """Client for interacting with A2A protocol servers.""" + + def __init__(self, base_url: str, auth_token: Optional[str] = None): + """Initialize A2A client. + + Args: + base_url: Base URL of the remote A2A agent + auth_token: Optional bearer token for authentication + """ + self.base_url = base_url.rstrip('/') + self.auth_token = auth_token + self._agent_card: Optional[AgentCard] = None + self._session: Optional[httpx.AsyncClient] = None + + @property + def agent_card(self) -> Optional[AgentCard]: + """Get cached agent card.""" + return self._agent_card + + async def _get_session(self) -> httpx.AsyncClient: + """Get or create HTTP session.""" + if self._session is None: + headers = {} + if self.auth_token: + headers["Authorization"] = f"Bearer {self.auth_token}" + self._session = httpx.AsyncClient(headers=headers) + return self._session + + async def fetch_agent_card(self, force_refresh: bool = False) -> AgentCard: + """Fetch the agent card from a remote A2A server. + + Args: + force_refresh: Force fetching even if cached + + Returns: + The agent's card with capabilities and skills + """ + if self._agent_card and not force_refresh: + return self._agent_card + + try: + session = await self._get_session() + url = urljoin(self.base_url, "/.well-known/agent.json") + response = await session.get(url) + response.raise_for_status() + + card_data = response.json() + self._agent_card = AgentCard(**card_data) + return self._agent_card + + except httpx.HTTPError as e: + raise ConnectionError(f"Failed to fetch agent card from {self.base_url}: {e}") + except Exception as e: + raise RuntimeError(f"Error parsing agent card: {e}") + + async def send_task( + self, + message: str, + skill_id: Optional[str] = None, + session_id: Optional[str] = None, + timeout: Optional[float] = 30.0 + ) -> Dict[str, Any]: + """Send a task to the remote agent. + + Args: + message: The message/prompt to send + skill_id: Optional specific skill to use + session_id: Optional session ID for conversation continuity + timeout: Request timeout in seconds + + Returns: + The agent's response + """ + task_id = str(uuid.uuid4()) + session_id = session_id or str(uuid.uuid4()) + + # Build JSON-RPC 2.0 request + request_data = { + "jsonrpc": "2.0", + "method": "task.execute", + "id": task_id, + "params": { + "message": message, + "sessionId": session_id + } + } + + # Add skill hint if specified + if skill_id: + request_data["params"]["skillId"] = skill_id + + try: + session = await self._get_session() + response = await session.post( + self.base_url, + json=request_data, + timeout=timeout + ) + response.raise_for_status() + + result = response.json() + + # Check for JSON-RPC error + if "error" in result: + error = result["error"] + raise RuntimeError(f"A2A Error {error.get('code', 'unknown')}: {error.get('message', 'Unknown error')}") + + return result + + except httpx.TimeoutException: + raise TimeoutError(f"Request to {self.base_url} timed out after {timeout}s") + except httpx.HTTPError as e: + raise ConnectionError(f"HTTP error communicating with {self.base_url}: {e}") + + async def send_task_and_wait( + self, + message: str, + skill_id: Optional[str] = None, + session_id: Optional[str] = None, + timeout: Optional[float] = 30.0, + poll_interval: float = 0.5 + ) -> str: + """Send a task and wait for completion, returning just the text result. + + Args: + message: The message/prompt to send + skill_id: Optional specific skill to use + session_id: Optional session ID for conversation continuity + timeout: Total timeout for task completion + poll_interval: Interval between status checks + + Returns: + The text content of the agent's response + """ + response = await self.send_task(message, skill_id, session_id, timeout) + + # Extract text content from response + result = response.get("result", {}) + content = result.get("content", "") + + # Handle different content formats + if isinstance(content, str): + return content + elif isinstance(content, dict): + # Try to extract text from structured content + return content.get("text", str(content)) + else: + return str(content) + + async def close(self): + """Close the client session.""" + if self._session: + await self._session.aclose() + self._session = None + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() + + +# Synchronous convenience functions +def fetch_agent_card_sync(agent_url: str, auth_token: Optional[str] = None) -> AgentCard: + """Synchronously fetch an agent card. + + Args: + agent_url: URL of the A2A agent + auth_token: Optional authentication token + + Returns: + The agent's card + """ + async def _fetch(): + async with A2AProtocolClient(agent_url, auth_token) as client: + return await client.fetch_agent_card() + + return asyncio.run(_fetch()) + + +def send_a2a_request_sync( + agent_url: str, + message: str, + auth_token: Optional[str] = None, + skill_id: Optional[str] = None +) -> str: + """Synchronously send a request to an A2A agent. + + Args: + agent_url: URL of the A2A agent + message: Message to send + auth_token: Optional authentication token + skill_id: Optional specific skill to use + + Returns: + The agent's text response + """ + async def _send(): + async with A2AProtocolClient(agent_url, auth_token) as client: + return await client.send_task_and_wait(message, skill_id) + + return asyncio.run(_send()) \ No newline at end of file diff --git a/src/strands/protocols/a2a/server.py b/src/strands/protocols/a2a/server.py new file mode 100644 index 000000000..77f3813aa --- /dev/null +++ b/src/strands/protocols/a2a/server.py @@ -0,0 +1,381 @@ +"""A2A Protocol Server implementation. + +This module implements the A2A (Agent-to-Agent) protocol server following +the ProtocolServer abstraction pattern. +""" + +import json +import logging +import threading +from typing import Any, Dict, Optional, TypedDict, Union +from typing_extensions import Unpack, override + +import uvicorn +from fastapi import FastAPI, Request, Response +from a2a.types import AgentCard, AgentSkill + +from ...types.protocols import ProtocolServer + +logger = logging.getLogger(__name__) + + +class A2AProtocolServer(ProtocolServer): + """A2A protocol server implementation with auto-discovery.""" + + class A2AConfig(TypedDict, total=False): + """Configuration options for A2A servers. + + Attributes: + host: Host to bind the server to. Defaults to "0.0.0.0". + port: Port to bind the server to. Defaults to 8000. + version: A2A protocol version. Defaults to "1.0.0". + enable_auth: Whether to enable authentication. Defaults to False. + auth_token: Bearer token for authentication if enabled. + max_concurrent_tasks: Maximum concurrent tasks. Defaults to 1. + tls_cert: Path to TLS certificate file. + tls_key: Path to TLS private key file. + cors_origins: List of allowed CORS origins. Defaults to ["*"]. + """ + + host: str + port: int + version: str + enable_auth: bool + auth_token: Optional[str] + max_concurrent_tasks: int + tls_cert: Optional[str] + tls_key: Optional[str] + cors_origins: list[str] + + def __init__( + self, + **server_config: Unpack[A2AConfig] + ) -> None: + """Initialize A2A server instance. + + Args: + **server_config: Configuration options for the A2A server. + """ + self.config = A2AProtocolServer.A2AConfig( + host="0.0.0.0", + port=8000, + version="1.0.0", + enable_auth=False, + max_concurrent_tasks=1, + cors_origins=["*"] + ) + self.update_config(**server_config) + + logger.debug("config=<%s> | initializing", self.config) + + self.app = FastAPI() + self.agent: Optional[Any] = None + self.agent_card: Optional[AgentCard] = None + self._is_running = False + self._server_thread: Optional[threading.Thread] = None + self._uvicorn_server: Optional[uvicorn.Server] = None + + @override + def update_config(self, **server_config: Unpack[A2AConfig]) -> None: # type: ignore[override] + """Update the A2A server configuration. + + Args: + **server_config: Configuration overrides. + """ + self.config.update(server_config) + + @override + def get_config(self) -> A2AConfig: + """Get the A2A server configuration. + + Returns: + The A2A server configuration. + """ + return self.config + + @property + @override + def protocol_name(self) -> str: + """Return the protocol name.""" + return "a2a" + + @property + @override + def is_running(self) -> bool: + """Check if server is running.""" + return self._is_running + + @override + def get_endpoint(self) -> str: + """Get the server's endpoint URL. + + Returns: + The URL where the A2A server is accessible. + """ + host = self.config.get("host", "0.0.0.0") + port = self.config.get("port", 8000) + + # Convert 0.0.0.0 to localhost for easier access + if host == "0.0.0.0": + host = "localhost" + + protocol = "https" if self.config.get("tls_cert") else "http" + return f"{protocol}://{host}:{port}" + + def _create_agent_card(self, agent: Any) -> AgentCard: + """Create an A2A agent card from the agent's tools. + + Args: + agent: The Strands agent instance. + + Returns: + An A2A AgentCard with auto-discovered capabilities. + """ + # Get all tool configurations + tool_configs = agent.tool_registry.get_all_tools_config() + + # Generate skills from tools + skills = [] + for tool_name, tool_spec in tool_configs.items(): + # Generate intelligent tags from description + description = tool_spec.get("description", "") + tags = self._generate_tags(description) + + # Generate examples from input schema + examples = self._generate_examples(tool_spec) + + skill = { + "id": tool_name, # Use tool name as skill ID + "name": tool_spec.get("name", tool_name), + "description": description, + "tags": tags, + "examples": examples, + "inputMode": "text/plain", + "outputMode": "text/plain" + } + skills.append(skill) + + # Create agent card + name = getattr(agent, "name", agent.__class__.__name__) + description = getattr(agent, "description", f"A Strands agent with {len(skills)} tools") + + return AgentCard( + name=name, + description=description, + url=self.get_endpoint(), + version=self.config.get("version", "1.0.0"), + skills=skills, + capabilities={ + "streaming": False, + "pushNotifications": False, + "stateTransitionHistory": True + }, + authentication={ + "schemes": ["Bearer"] if self.config.get("enable_auth") else [] + }, + defaultInputModes=["text", "text/plain"], + defaultOutputModes=["text", "text/plain"] + ) + + def _generate_tags(self, description: str) -> list[str]: + """Generate intelligent tags from a description.""" + tags = [] + desc_lower = description.lower() + + # Category mappings + tag_mappings = { + "calculation": ["calculat", "math", "comput", "sum", "average"], + "search": ["search", "find", "lookup", "query", "fetch"], + "file": ["file", "read", "write", "save", "load"], + "web": ["http", "url", "web", "api", "request"], + "data": ["data", "process", "transform", "analyz", "parse"], + "communication": ["email", "message", "notif", "send", "chat"], + "time": ["time", "date", "schedule", "calendar", "timezone"] + } + + for tag, keywords in tag_mappings.items(): + if any(keyword in desc_lower for keyword in keywords): + tags.append(tag) + + # Default tag if none found + if not tags: + tags.append("general") + + return tags + + def _generate_examples(self, tool_spec: dict) -> list[str]: + """Generate example usage from tool input schema.""" + examples = [] + + # Get custom examples if provided + if "examples" in tool_spec: + return tool_spec["examples"] + + # Otherwise generate from schema + input_schema = tool_spec.get("input_schema", {}) + properties = input_schema.get("properties", {}) + + if properties: + # Create a simple example with required fields + required = input_schema.get("required", []) + example_parts = [] + + for prop, schema in properties.items(): + if prop in required: + prop_type = schema.get("type", "string") + if prop_type == "string": + example_parts.append(f"{prop}: 'example'") + elif prop_type == "number": + example_parts.append(f"{prop}: 123") + elif prop_type == "boolean": + example_parts.append(f"{prop}: true") + + if example_parts: + examples.append(f"Use with {', '.join(example_parts)}") + + # Add a generic example based on description + if not examples and "description" in tool_spec: + examples.append(f"Help me {tool_spec['description'].lower()}") + + return examples + + def _setup_routes(self) -> None: + """Set up the FastAPI routes for A2A.""" + + @self.app.get("/.well-known/agent.json") + async def get_agent_card() -> Dict[str, Any]: + """Serve the agent card.""" + if not self.agent_card: + return {"error": "Agent card not initialized"} + return self.agent_card.dict() + + @self.app.post("/") + async def handle_task(request: Request) -> Dict[str, Any]: + """Handle A2A task requests.""" + # Check authentication if enabled + if self.config.get("enable_auth"): + auth_header = request.headers.get("Authorization") + expected_token = f"Bearer {self.config.get('auth_token')}" + + if auth_header != expected_token: + return { + "jsonrpc": "2.0", + "error": { + "code": -32001, + "message": "Unauthorized" + } + } + + # Parse JSON-RPC request + try: + body = await request.json() + method = body.get("method") + params = body.get("params", {}) + request_id = body.get("id") + + # Extract the user message + user_message = params.get("message", params.get("prompt", "")) + + # Call the agent + result = self.agent(user_message) + + # Format response + response_text = str(result) + + return { + "jsonrpc": "2.0", + "result": { + "content": response_text, + "status": "completed" + }, + "id": request_id + } + + except Exception as e: + logger.error(f"Error processing A2A request: {e}") + return { + "jsonrpc": "2.0", + "error": { + "code": -32603, + "message": str(e) + }, + "id": body.get("id") if "body" in locals() else None + } + + @override + def start(self, agent: Any) -> None: + """Start the A2A server for the given agent. + + Args: + agent: The agent to expose via A2A. + """ + if self._is_running: + logger.warning("A2A server already running") + return + + self.agent = agent + self.agent_card = self._create_agent_card(agent) + self._setup_routes() + + # Configure uvicorn + config = uvicorn.Config( + app=self.app, + host=self.config.get("host", "0.0.0.0"), + port=self.config.get("port", 8000), + log_level="info" + ) + + # Add TLS if configured + if self.config.get("tls_cert") and self.config.get("tls_key"): + config.ssl_certfile = self.config["tls_cert"] + config.ssl_keyfile = self.config["tls_key"] + + self._uvicorn_server = uvicorn.Server(config) + + # Start in background thread + def run_server(): + try: + self._is_running = True + logger.info(f"Starting A2A server on {self.get_endpoint()}") + logger.info(f"Agent Card: {self.get_endpoint()}/.well-known/agent.json") + + # Print discovered skills + logger.info(f"Discovered {len(self.agent_card.skills)} skills:") + for skill in self.agent_card.skills: + skill_info = skill if isinstance(skill, dict) else { + "name": getattr(skill, "name", "Unknown"), + "description": getattr(skill, "description", ""), + "tags": getattr(skill, "tags", []) + } + logger.info(f" • {skill_info['name']}: {skill_info['description']}") + if skill_info.get("tags"): + logger.info(f" Tags: {', '.join(skill_info['tags'])}") + + self._uvicorn_server.run() + finally: + self._is_running = False + + self._server_thread = threading.Thread(target=run_server, daemon=True) + self._server_thread.start() + + # Give server time to start + import time + time.sleep(0.5) + + @override + def stop(self) -> None: + """Stop the A2A server.""" + if not self._is_running: + return + + logger.info("Stopping A2A server") + + if self._uvicorn_server: + self._uvicorn_server.should_exit = True + + if self._server_thread: + self._server_thread.join(timeout=5) + + self._is_running = False + self.agent = None + self.agent_card = None \ No newline at end of file diff --git a/src/strands/protocols/a2a/tools.py b/src/strands/protocols/a2a/tools.py new file mode 100644 index 000000000..ec8a42d79 --- /dev/null +++ b/src/strands/protocols/a2a/tools.py @@ -0,0 +1,352 @@ +"""A2A Remote Tool implementation. + +This module provides tools that wrap remote A2A agents, enabling the +"Agent as Tool" pattern for seamless agent composition. +""" + +import asyncio +import logging +from typing import Any, Dict, Optional, List, Union + +from .client import A2AProtocolClient +from ...types.tools import AgentTool, ToolSpec, ToolUse, ToolResult, ToolResultContent + +logger = logging.getLogger(__name__) + + +class A2ARemoteTool(AgentTool): + """Wraps a remote A2A agent as a Strands tool.""" + + def __init__( + self, + agent_url: str, + skill_id: Optional[str] = None, + name: Optional[str] = None, + auth_token: Optional[str] = None + ): + """Initialize A2A remote tool. + + Args: + agent_url: URL of the remote A2A agent + skill_id: Optional specific skill to expose as tool + name: Optional custom name for the tool + auth_token: Optional authentication token + """ + super().__init__() + self.agent_url = agent_url + self.skill_id = skill_id + self._custom_name = name + self.client = A2AProtocolClient(agent_url, auth_token) + self.agent_card = None + self._discovered = False + self._tool_spec: Optional[ToolSpec] = None + + def _ensure_discovered(self): + """Ensure agent capabilities are discovered.""" + if not self._discovered: + # Use sync discovery for simplicity + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + self.agent_card = loop.run_until_complete(self.client.fetch_agent_card()) + self._discovered = True + # Generate and cache tool spec + self._tool_spec = self._generate_tool_spec() + finally: + loop.close() + + @property + def tool_name(self) -> str: + """The unique name of the tool used for identification and invocation.""" + if self._custom_name: + return self._custom_name + + self._ensure_discovered() + + if self.skill_id and self.agent_card: + # Find the specific skill + for skill in self.agent_card.skills: + skill_dict = skill if isinstance(skill, dict) else skill.dict() + if skill_dict.get("id") == self.skill_id: + return f"remote_{skill_dict.get('name', self.skill_id).lower().replace(' ', '_')}" + + # Use agent name + if self.agent_card: + return f"remote_{self.agent_card.name.lower().replace(' ', '_')}" + + return "remote_agent" + + @property + def tool_spec(self) -> ToolSpec: + """Tool specification that describes its functionality and parameters.""" + self._ensure_discovered() + if self._tool_spec: + return self._tool_spec + # Generate on demand if needed + return self._generate_tool_spec() + + @property + def tool_type(self) -> str: + """The type of the tool implementation.""" + return "a2a_remote" + + @property + def supports_hot_reload(self) -> bool: + """Remote tools don't support hot reload.""" + return False + + def _generate_tool_spec(self) -> ToolSpec: + """Generate tool spec from A2A agent card.""" + if not self.agent_card: + raise RuntimeError("Failed to discover remote agent") + + # If targeting a specific skill + if self.skill_id: + for skill in self.agent_card.skills: + skill_dict = skill if isinstance(skill, dict) else skill.dict() + if skill_dict.get("id") == self.skill_id: + return self._skill_to_tool_spec(skill_dict) + raise ValueError(f"Skill {self.skill_id} not found in agent") + + # Otherwise, create a general tool spec for the entire agent + skill_names = [] + all_examples = [] + all_tags = set() + + for skill in self.agent_card.skills: + skill_dict = skill if isinstance(skill, dict) else skill.dict() + skill_names.append(skill_dict.get("name", skill_dict.get("id", "unknown"))) + all_examples.extend(skill_dict.get("examples", [])) + all_tags.update(skill_dict.get("tags", [])) + + # Build inputSchema in the correct format + json_schema = { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Message or task to send to the remote agent" + }, + "skill": { + "type": "string", + "description": f"Optional specific skill to use", + "enum": [s.get("id") if isinstance(s, dict) else s.id for s in self.agent_card.skills] + } + }, + "required": ["message"] + } + + return ToolSpec( + name=self.tool_name, + description=f"{self.agent_card.name}: {self.agent_card.description}. Available skills: {', '.join(skill_names)}", + inputSchema={"json": json_schema} + ) + + def _skill_to_tool_spec(self, skill: Dict[str, Any]) -> ToolSpec: + """Convert an A2A skill to a Strands tool spec.""" + # Build inputSchema in the correct format + json_schema = { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": f"Input for {skill.get('name', 'the skill')}" + } + }, + "required": ["message"] + } + + return ToolSpec( + name=self.tool_name, + description=skill.get("description", f"Use the {skill.get('name', 'remote')} skill"), + inputSchema={"json": json_schema} + ) + + def get_tool_spec(self) -> Dict[str, Any]: + """Legacy method - returns the tool spec as a dict.""" + return dict(self.tool_spec) + + async def _execute_async(self, **kwargs) -> str: + """Execute the remote agent call asynchronously.""" + message = kwargs.get("message", "") + skill = kwargs.get("skill", self.skill_id) + + if not message: + raise ValueError("Message is required") + + try: + result = await self.client.send_task_and_wait( + message=message, + skill_id=skill + ) + return result + except Exception as e: + logger.error(f"Error calling remote agent {self.agent_url}: {e}") + raise + + def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult: + """Execute the tool's functionality with the given tool use request. + + Args: + tool: The tool use request containing tool ID and parameters. + *args: Positional arguments to pass to the tool. + **kwargs: Keyword arguments to pass to the tool. + + Returns: + The result of the tool execution. + """ + # Extract parameters from tool use + tool_input = tool.get("input", {}) + message = tool_input.get("message", "") + skill = tool_input.get("skill", self.skill_id) + + # Run async method in sync context + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result_text = loop.run_until_complete( + self._execute_async(message=message, skill=skill) + ) + + # Return properly formatted ToolResult + return ToolResult( + toolUseId=tool["toolUseId"], + content=[ToolResultContent(text=result_text)], + status="success" + ) + except Exception as e: + logger.error(f"Tool execution failed: {e}") + return ToolResult( + toolUseId=tool["toolUseId"], + content=[ToolResultContent(text=f"Error: {str(e)}")], + status="error" + ) + finally: + loop.close() + + def __call__(self, **kwargs) -> str: + """Execute remote agent call (legacy interface). + + Args: + message: Message to send to the remote agent + skill: Optional specific skill to use + **kwargs: Additional arguments (ignored) + + Returns: + The agent's response as a string + """ + # Create a synthetic tool use for the legacy interface + tool_use = ToolUse( + toolUseId=f"legacy_{id(self)}", + name=self.tool_name, + input=kwargs + ) + + result = self.invoke(tool_use) + + # Extract text from result + if result["status"] == "success" and result["content"]: + for content in result["content"]: + if "text" in content: + return content["text"] + + return "No response" + + +def create_agent_tool( + agent_url: str, + name: Optional[str] = None, + skill_id: Optional[str] = None, + auth_token: Optional[str] = None +) -> A2ARemoteTool: + """Factory to create a tool from a remote A2A agent. + + Args: + agent_url: URL of the remote A2A agent + name: Optional custom name for the tool + skill_id: Optional specific skill to expose + auth_token: Optional authentication token + + Returns: + A2ARemoteTool instance ready to be registered + + Example: + ```python + # Add a remote agent as a tool + agent.register_tool(create_agent_tool("http://research-agent:8000")) + + # Or with a custom name + agent.register_tool(create_agent_tool( + "http://research-agent:8000", + name="research" + )) + ``` + """ + tool = A2ARemoteTool(agent_url, skill_id, name, auth_token) + + # Pre-discover to validate the agent exists + tool._ensure_discovered() + + return tool + + +def create_agent_tools_from_skills( + agent_url: str, + skills: Optional[List[str]] = None, + auth_token: Optional[str] = None +) -> List[A2ARemoteTool]: + """Create multiple tools from an agent's skills. + + Args: + agent_url: URL of the remote A2A agent + skills: List of skill IDs to create tools for (None = all skills) + auth_token: Optional authentication token + + Returns: + List of A2ARemoteTool instances, one per skill + + Example: + ```python + # Register specific skills as separate tools + tools = create_agent_tools_from_skills( + "http://utility-agent:8000", + skills=["calculate", "translate"] + ) + for tool in tools: + agent.register_tool(tool) + ``` + """ + # First, discover the agent + temp_tool = A2ARemoteTool(agent_url, auth_token=auth_token) + temp_tool._ensure_discovered() + + if not temp_tool.agent_card: + raise RuntimeError(f"Failed to discover agent at {agent_url}") + + # Get all available skill IDs + available_skills = [] + for skill in temp_tool.agent_card.skills: + skill_dict = skill if isinstance(skill, dict) else skill.dict() + available_skills.append(skill_dict.get("id")) + + # Filter to requested skills + if skills: + skill_ids = [s for s in skills if s in available_skills] + if not skill_ids: + raise ValueError(f"No requested skills found. Available: {available_skills}") + else: + skill_ids = available_skills + + # Create a tool for each skill + tools = [] + for skill_id in skill_ids: + tool = A2ARemoteTool( + agent_url, + skill_id=skill_id, + auth_token=auth_token + ) + tool.agent_card = temp_tool.agent_card # Reuse discovered card + tool._discovered = True + tools.append(tool) + + return tools \ No newline at end of file diff --git a/src/strands/types/protocols/__init__.py b/src/strands/types/protocols/__init__.py new file mode 100644 index 000000000..52103c023 --- /dev/null +++ b/src/strands/types/protocols/__init__.py @@ -0,0 +1,5 @@ +"""Protocol server type definitions for the SDK.""" + +from .protocol import ProtocolServer + +__all__ = ["ProtocolServer"] \ No newline at end of file diff --git a/src/strands/types/protocols/protocol.py b/src/strands/types/protocols/protocol.py new file mode 100644 index 000000000..87dae488c --- /dev/null +++ b/src/strands/types/protocols/protocol.py @@ -0,0 +1,88 @@ +"""Protocol server abstraction for agent networking. + +This module provides the abstract base class for protocol server implementations, +following the same pattern as the Model abstraction. +""" + +import abc +import logging +from typing import Any, Dict, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from ...agent.agent import Agent + +logger = logging.getLogger(__name__) + + +class ProtocolServer(abc.ABC): + """Abstract base class for agent protocol server implementations. + + This class defines the interface for all protocol server implementations in the + Strands Agents SDK. It provides a standardized way to configure, start, and manage + different protocol servers (A2A, MCP, GraphQL, gRPC, etc.). + """ + + @abc.abstractmethod + def update_config(self, **server_config: Any) -> None: + """Update the server configuration with the provided arguments. + + Args: + **server_config: Configuration overrides. + """ + pass + + @abc.abstractmethod + def get_config(self) -> Any: + """Return the server configuration. + + Returns: + The server's configuration. + """ + pass + + @abc.abstractmethod + def start(self, agent: "Agent") -> None: + """Start the protocol server for the given agent. + + Args: + agent: The agent to expose via this protocol. + """ + pass + + @abc.abstractmethod + def stop(self) -> None: + """Stop the protocol server.""" + pass + + @abc.abstractmethod + def get_endpoint(self) -> str: + """Get the server's endpoint URL. + + Returns: + The URL where the server is accessible. + """ + pass + + @property + @abc.abstractmethod + def protocol_name(self) -> str: + """The name of the protocol (e.g., 'a2a', 'mcp', 'graphql').""" + pass + + @property + def is_running(self) -> bool: + """Whether the server is currently running. + + Returns: + True if the server is running, False otherwise. + """ + return False + + def __enter__(self) -> "ProtocolServer": + """Context manager entry.""" + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Context manager exit - ensure server is stopped.""" + if self.is_running: + self.stop() \ No newline at end of file