diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 824e08819..80f238f81 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -21,8 +21,11 @@ from dataclasses import dataclass, field from typing import Any, Callable, Tuple, cast +from opentelemetry import trace as trace_api + from ..agent import Agent, AgentResult from ..agent.state import AgentState +from ..telemetry import get_tracer from ..tools.decorator import tool from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage @@ -229,6 +232,7 @@ def __init__( task="", completion_status=Status.PENDING, ) + self.tracer = get_tracer() self._setup_swarm(nodes) self._inject_swarm_tools() @@ -257,24 +261,26 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S ) start_time = time.time() - try: - logger.debug("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id) - logger.debug( - "max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | swarm execution config", - self.max_handoffs, - self.max_iterations, - self.execution_timeout, - ) + span = self.tracer.start_multiagent_span(task, "swarm") + with trace_api.use_span(span, end_on_exit=True): + try: + logger.debug("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id) + logger.debug( + "max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | swarm execution config", + self.max_handoffs, + self.max_iterations, + self.execution_timeout, + ) - await self._execute_swarm() - except Exception: - logger.exception("swarm execution failed") - self.state.completion_status = Status.FAILED - raise - finally: - self.state.execution_time = round((time.time() - start_time) * 1000) + await self._execute_swarm() + except Exception: + logger.exception("swarm execution failed") + self.state.completion_status = Status.FAILED + raise + finally: + self.state.execution_time = round((time.time() - start_time) * 1000) - return self._build_result() + return self._build_result() def _setup_swarm(self, nodes: list[Agent]) -> None: """Initialize swarm configuration.""" diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index ffb0343b2..fb4c10467 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1,6 +1,6 @@ import math import time -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock, Mock, patch import pytest @@ -94,6 +94,22 @@ def mock_swarm(mock_agents): return swarm +@pytest.fixture +def mock_strands_tracer(): + with patch("strands.multiagent.swarm.get_tracer") as mock_get_tracer: + mock_tracer_instance = MagicMock() + mock_span = MagicMock() + mock_tracer_instance.start_multiagent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer_instance + yield mock_tracer_instance + + +@pytest.fixture +def mock_use_span(): + with patch("strands.multiagent.swarm.trace_api.use_span") as mock_use_span: + yield mock_use_span + + def test_swarm_structure_and_nodes(mock_swarm, mock_agents): """Test swarm structure and SwarmNode properties.""" # Test swarm structure @@ -215,7 +231,7 @@ def test_swarm_state_should_continue(mock_swarm): @pytest.mark.asyncio -async def test_swarm_execution_async(mock_swarm, mock_agents): +async def test_swarm_execution_async(mock_strands_tracer, mock_use_span, mock_swarm, mock_agents): """Test asynchronous swarm execution.""" # Execute swarm task = [ContentBlock(text="Analyze this task"), ContentBlock(text="Additional context")] @@ -238,8 +254,11 @@ async def test_swarm_execution_async(mock_swarm, mock_agents): assert hasattr(result, "node_history") assert len(result.node_history) == 1 + mock_strands_tracer.start_multiagent_span.assert_called() + mock_use_span.assert_called_once() -def test_swarm_synchronous_execution(mock_agents): + +def test_swarm_synchronous_execution(mock_strands_tracer, mock_use_span, mock_agents): """Test synchronous swarm execution using __call__ method.""" agents = list(mock_agents.values()) swarm = Swarm( @@ -280,6 +299,9 @@ def test_swarm_synchronous_execution(mock_agents): for node in swarm.nodes.values(): node.executor.tool_registry.process_tools.assert_called() + mock_strands_tracer.start_multiagent_span.assert_called() + mock_use_span.assert_called_once() + def test_swarm_builder_validation(mock_agents): """Test swarm builder validation and error handling.""" @@ -407,7 +429,7 @@ def test_swarm_tool_creation_and_execution(): assert completion_result["status"] == "success" -def test_swarm_failure_handling(): +def test_swarm_failure_handling(mock_strands_tracer, mock_use_span): """Test swarm execution with agent failures.""" # Test execution with agent failures failing_agent = create_mock_agent("failing_agent") @@ -418,6 +440,8 @@ def test_swarm_failure_handling(): # The swarm catches exceptions internally and sets status to FAILED result = failing_swarm("Test failure handling") assert result.status == Status.FAILED + mock_strands_tracer.start_multiagent_span.assert_called() + mock_use_span.assert_called_once() def test_swarm_metrics_handling():