Skip to content

Commit cbdab32

Browse files
authored
feat(swarm): Make entry point configurable (#851)
1 parent f12fee8 commit cbdab32

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

src/strands/multiagent/swarm.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def __init__(
196196
self,
197197
nodes: list[Agent],
198198
*,
199+
entry_point: Agent | None = None,
199200
max_handoffs: int = 20,
200201
max_iterations: int = 20,
201202
execution_timeout: float = 900.0,
@@ -207,6 +208,7 @@ def __init__(
207208
208209
Args:
209210
nodes: List of nodes (e.g. Agent) to include in the swarm
211+
entry_point: Agent to start with. If None, uses the first agent (default: None)
210212
max_handoffs: Maximum handoffs to agents and users (default: 20)
211213
max_iterations: Maximum node executions within the swarm (default: 20)
212214
execution_timeout: Total execution timeout in seconds (default: 900.0)
@@ -218,6 +220,7 @@ def __init__(
218220
"""
219221
super().__init__()
220222

223+
self.entry_point = entry_point
221224
self.max_handoffs = max_handoffs
222225
self.max_iterations = max_iterations
223226
self.execution_timeout = execution_timeout
@@ -276,7 +279,11 @@ async def invoke_async(
276279
logger.debug("starting swarm execution")
277280

278281
# Initialize swarm state with configuration
279-
initial_node = next(iter(self.nodes.values())) # First SwarmNode
282+
if self.entry_point:
283+
initial_node = self.nodes[str(self.entry_point.name)]
284+
else:
285+
initial_node = next(iter(self.nodes.values())) # First SwarmNode
286+
280287
self.state = SwarmState(
281288
current_node=initial_node,
282289
task=task,
@@ -326,9 +333,28 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
326333

327334
self.nodes[node_id] = SwarmNode(node_id=node_id, executor=node)
328335

336+
# Validate entry point if specified
337+
if self.entry_point is not None:
338+
entry_point_node_id = str(self.entry_point.name)
339+
if (
340+
entry_point_node_id not in self.nodes
341+
or self.nodes[entry_point_node_id].executor is not self.entry_point
342+
):
343+
available_agents = [
344+
f"{node_id} ({type(node.executor).__name__})" for node_id, node in self.nodes.items()
345+
]
346+
raise ValueError(f"Entry point agent not found in swarm nodes. Available agents: {available_agents}")
347+
329348
swarm_nodes = list(self.nodes.values())
330349
logger.debug("nodes=<%s> | initialized swarm with nodes", [node.node_id for node in swarm_nodes])
331350

351+
if self.entry_point:
352+
entry_point_name = getattr(self.entry_point, "name", "unnamed_agent")
353+
logger.debug("entry_point=<%s> | configured entry point", entry_point_name)
354+
else:
355+
first_node = next(iter(self.nodes.keys()))
356+
logger.debug("entry_point=<%s> | using first node as entry point", first_node)
357+
332358
def _validate_swarm(self, nodes: list[Agent]) -> None:
333359
"""Validate swarm structure and nodes."""
334360
# Check for duplicate object instances

tests/strands/multiagent/test_swarm.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,82 @@ def test_swarm_auto_completion_without_handoff():
451451
no_handoff_agent.invoke_async.assert_called()
452452

453453

454+
def test_swarm_configurable_entry_point():
455+
"""Test swarm with configurable entry point."""
456+
# Create multiple agents
457+
agent1 = create_mock_agent("agent1", "Agent 1 response")
458+
agent2 = create_mock_agent("agent2", "Agent 2 response")
459+
agent3 = create_mock_agent("agent3", "Agent 3 response")
460+
461+
# Create swarm with agent2 as entry point
462+
swarm = Swarm([agent1, agent2, agent3], entry_point=agent2)
463+
464+
# Verify entry point is set correctly
465+
assert swarm.entry_point is agent2
466+
467+
# Execute swarm
468+
result = swarm("Test task")
469+
470+
# Verify agent2 was the first to execute
471+
assert result.status == Status.COMPLETED
472+
assert len(result.node_history) == 1
473+
assert result.node_history[0].node_id == "agent2"
474+
475+
476+
def test_swarm_invalid_entry_point():
477+
"""Test swarm with invalid entry point raises error."""
478+
agent1 = create_mock_agent("agent1", "Agent 1 response")
479+
agent2 = create_mock_agent("agent2", "Agent 2 response")
480+
agent3 = create_mock_agent("agent3", "Agent 3 response") # Not in swarm
481+
482+
# Try to create swarm with agent not in the swarm
483+
with pytest.raises(ValueError, match="Entry point agent not found in swarm nodes"):
484+
Swarm([agent1, agent2], entry_point=agent3)
485+
486+
487+
def test_swarm_default_entry_point():
488+
"""Test swarm uses first agent as default entry point."""
489+
agent1 = create_mock_agent("agent1", "Agent 1 response")
490+
agent2 = create_mock_agent("agent2", "Agent 2 response")
491+
492+
# Create swarm without specifying entry point
493+
swarm = Swarm([agent1, agent2])
494+
495+
# Verify no explicit entry point is set
496+
assert swarm.entry_point is None
497+
498+
# Execute swarm
499+
result = swarm("Test task")
500+
501+
# Verify first agent was used as entry point
502+
assert result.status == Status.COMPLETED
503+
assert len(result.node_history) == 1
504+
assert result.node_history[0].node_id == "agent1"
505+
506+
507+
def test_swarm_duplicate_agent_names():
508+
"""Test swarm rejects agents with duplicate names."""
509+
agent1 = create_mock_agent("duplicate_name", "Agent 1 response")
510+
agent2 = create_mock_agent("duplicate_name", "Agent 2 response")
511+
512+
# Try to create swarm with duplicate names
513+
with pytest.raises(ValueError, match="Node ID 'duplicate_name' is not unique"):
514+
Swarm([agent1, agent2])
515+
516+
517+
def test_swarm_entry_point_same_name_different_object():
518+
"""Test entry point validation with same name but different object."""
519+
agent1 = create_mock_agent("agent1", "Agent 1 response")
520+
agent2 = create_mock_agent("agent2", "Agent 2 response")
521+
522+
# Create a different agent with same name as agent1
523+
different_agent_same_name = create_mock_agent("agent1", "Different agent response")
524+
525+
# Try to use the different agent as entry point
526+
with pytest.raises(ValueError, match="Entry point agent not found in swarm nodes"):
527+
Swarm([agent1, agent2], entry_point=different_agent_same_name)
528+
529+
454530
def test_swarm_validate_unsupported_features():
455531
"""Test Swarm validation for session persistence and callbacks."""
456532
# Test with normal agent (should work)

0 commit comments

Comments
 (0)