Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple, cast

from opentelemetry import trace as trace_api

from ..agent import Agent, AgentResult
from ..agent.state import AgentState
from ..telemetry import get_tracer
from ..tools.decorator import tool
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
Expand Down Expand Up @@ -229,6 +232,7 @@ def __init__(
task="",
completion_status=Status.PENDING,
)
self.tracer = get_tracer()

self._setup_swarm(nodes)
self._inject_swarm_tools()
Expand Down Expand Up @@ -257,24 +261,26 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S
)

start_time = time.time()
try:
logger.debug("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id)
logger.debug(
"max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | swarm execution config",
self.max_handoffs,
self.max_iterations,
self.execution_timeout,
)
span = self.tracer.start_multiagent_span(task, "swarm")
with trace_api.use_span(span, end_on_exit=True):
try:
logger.debug("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id)
logger.debug(
"max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | swarm execution config",
self.max_handoffs,
self.max_iterations,
self.execution_timeout,
)

await self._execute_swarm()
except Exception:
logger.exception("swarm execution failed")
self.state.completion_status = Status.FAILED
raise
finally:
self.state.execution_time = round((time.time() - start_time) * 1000)
await self._execute_swarm()
except Exception:
logger.exception("swarm execution failed")
self.state.completion_status = Status.FAILED
raise
finally:
self.state.execution_time = round((time.time() - start_time) * 1000)

return self._build_result()
return self._build_result()

def _setup_swarm(self, nodes: list[Agent]) -> None:
"""Initialize swarm configuration."""
Expand Down
32 changes: 28 additions & 4 deletions tests/strands/multiagent/test_swarm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import time
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock, Mock, patch

import pytest

Expand Down Expand Up @@ -94,6 +94,22 @@ def mock_swarm(mock_agents):
return swarm


@pytest.fixture
def mock_strands_tracer():
with patch("strands.multiagent.swarm.get_tracer") as mock_get_tracer:
mock_tracer_instance = MagicMock()
mock_span = MagicMock()
mock_tracer_instance.start_multiagent_span.return_value = mock_span
mock_get_tracer.return_value = mock_tracer_instance
yield mock_tracer_instance


@pytest.fixture
def mock_use_span():
with patch("strands.multiagent.swarm.trace_api.use_span") as mock_use_span:
yield mock_use_span


def test_swarm_structure_and_nodes(mock_swarm, mock_agents):
"""Test swarm structure and SwarmNode properties."""
# Test swarm structure
Expand Down Expand Up @@ -215,7 +231,7 @@ def test_swarm_state_should_continue(mock_swarm):


@pytest.mark.asyncio
async def test_swarm_execution_async(mock_swarm, mock_agents):
async def test_swarm_execution_async(mock_strands_tracer, mock_use_span, mock_swarm, mock_agents):
"""Test asynchronous swarm execution."""
# Execute swarm
task = [ContentBlock(text="Analyze this task"), ContentBlock(text="Additional context")]
Expand All @@ -238,8 +254,11 @@ async def test_swarm_execution_async(mock_swarm, mock_agents):
assert hasattr(result, "node_history")
assert len(result.node_history) == 1

mock_strands_tracer.start_multiagent_span.assert_called()
mock_use_span.assert_called_once()

def test_swarm_synchronous_execution(mock_agents):

def test_swarm_synchronous_execution(mock_strands_tracer, mock_use_span, mock_agents):
"""Test synchronous swarm execution using __call__ method."""
agents = list(mock_agents.values())
swarm = Swarm(
Expand Down Expand Up @@ -280,6 +299,9 @@ def test_swarm_synchronous_execution(mock_agents):
for node in swarm.nodes.values():
node.executor.tool_registry.process_tools.assert_called()

mock_strands_tracer.start_multiagent_span.assert_called()
mock_use_span.assert_called_once()


def test_swarm_builder_validation(mock_agents):
"""Test swarm builder validation and error handling."""
Expand Down Expand Up @@ -407,7 +429,7 @@ def test_swarm_tool_creation_and_execution():
assert completion_result["status"] == "success"


def test_swarm_failure_handling():
def test_swarm_failure_handling(mock_strands_tracer, mock_use_span):
"""Test swarm execution with agent failures."""
# Test execution with agent failures
failing_agent = create_mock_agent("failing_agent")
Expand All @@ -418,6 +440,8 @@ def test_swarm_failure_handling():
# The swarm catches exceptions internally and sets status to FAILED
result = failing_swarm("Test failure handling")
assert result.status == Status.FAILED
mock_strands_tracer.start_multiagent_span.assert_called()
mock_use_span.assert_called_once()


def test_swarm_metrics_handling():
Expand Down
Loading