Skip to content

Commit ab5f8ee

Browse files
authored
multi agent input (#1196)
1 parent 95ac650 commit ab5f8ee

File tree

7 files changed

+72
-21
lines changed

7 files changed

+72
-21
lines changed

src/strands/multiagent/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
from .._async import run_async
1414
from ..agent import AgentResult
15-
from ..types.content import ContentBlock
1615
from ..types.event_loop import Metrics, Usage
16+
from ..types.multiagent import MultiAgentInput
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -173,7 +173,7 @@ class MultiAgentBase(ABC):
173173

174174
@abstractmethod
175175
async def invoke_async(
176-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
176+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
177177
) -> MultiAgentResult:
178178
"""Invoke asynchronously.
179179
@@ -186,7 +186,7 @@ async def invoke_async(
186186
raise NotImplementedError("invoke_async not implemented")
187187

188188
async def stream_async(
189-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
189+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
190190
) -> AsyncIterator[dict[str, Any]]:
191191
"""Stream events during multi-agent execution.
192192
@@ -211,7 +211,7 @@ async def stream_async(
211211
yield {"result": result}
212212

213213
def __call__(
214-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
214+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
215215
) -> MultiAgentResult:
216216
"""Invoke synchronously.
217217

src/strands/multiagent/graph.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
)
4646
from ..types.content import ContentBlock, Messages
4747
from ..types.event_loop import Metrics, Usage
48+
from ..types.multiagent import MultiAgentInput
4849
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
4950

5051
logger = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ class GraphState:
6768
"""
6869

6970
# Task (with default empty string)
70-
task: str | list[ContentBlock] = ""
71+
task: MultiAgentInput = ""
7172

7273
# Execution state
7374
status: Status = Status.PENDING
@@ -456,7 +457,7 @@ def __init__(
456457
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
457458

458459
def __call__(
459-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
460+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
460461
) -> GraphResult:
461462
"""Invoke the graph synchronously.
462463
@@ -472,7 +473,7 @@ def __call__(
472473
return run_async(lambda: self.invoke_async(task, invocation_state))
473474

474475
async def invoke_async(
475-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
476+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
476477
) -> GraphResult:
477478
"""Invoke the graph asynchronously.
478479
@@ -496,7 +497,7 @@ async def invoke_async(
496497
return cast(GraphResult, final_event["result"])
497498

498499
async def stream_async(
499-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
500+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
500501
) -> AsyncIterator[dict[str, Any]]:
501502
"""Stream events during graph execution.
502503

src/strands/multiagent/swarm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
)
4646
from ..types.content import ContentBlock, Messages
4747
from ..types.event_loop import Metrics, Usage
48+
from ..types.multiagent import MultiAgentInput
4849
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
4950

5051
logger = logging.getLogger(__name__)
@@ -145,7 +146,7 @@ class SwarmState:
145146
"""Current state of swarm execution."""
146147

147148
current_node: SwarmNode | None # The agent currently executing
148-
task: str | list[ContentBlock] # The original task from the user that is being executed
149+
task: MultiAgentInput # The original task from the user that is being executed
149150
completion_status: Status = Status.PENDING # Current swarm execution status
150151
shared_context: SharedContext = field(default_factory=SharedContext) # Context shared between agents
151152
node_history: list[SwarmNode] = field(default_factory=list) # Complete history of agents that have executed
@@ -277,7 +278,7 @@ def __init__(
277278
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
278279

279280
def __call__(
280-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
281+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
281282
) -> SwarmResult:
282283
"""Invoke the swarm synchronously.
283284
@@ -292,7 +293,7 @@ def __call__(
292293
return run_async(lambda: self.invoke_async(task, invocation_state))
293294

294295
async def invoke_async(
295-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
296+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
296297
) -> SwarmResult:
297298
"""Invoke the swarm asynchronously.
298299
@@ -316,7 +317,7 @@ async def invoke_async(
316317
return cast(SwarmResult, final_event["result"])
317318

318319
async def stream_async(
319-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
320+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
320321
) -> AsyncIterator[dict[str, Any]]:
321322
"""Stream events during swarm execution.
322323
@@ -741,7 +742,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
741742
)
742743

743744
async def _execute_node(
744-
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
745+
self, node: SwarmNode, task: MultiAgentInput, invocation_state: dict[str, Any]
745746
) -> AsyncIterator[Any]:
746747
"""Execute swarm node and yield TypedEvent objects."""
747748
start_time = time.time()

src/strands/telemetry/tracer.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88
import logging
99
import os
1010
from datetime import date, datetime, timezone
11-
from typing import Any, Dict, Mapping, Optional
11+
from typing import Any, Dict, Mapping, Optional, cast
1212

1313
import opentelemetry.trace as trace_api
1414
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
1515
from opentelemetry.trace import Span, StatusCode
1616

1717
from ..agent.agent_result import AgentResult
1818
from ..types.content import ContentBlock, Message, Messages
19+
from ..types.interrupt import InterruptResponseContent
20+
from ..types.multiagent import MultiAgentInput
1921
from ..types.streaming import Metrics, StopReason, Usage
2022
from ..types.tools import ToolResult, ToolUse
2123
from ..types.traces import Attributes, AttributeValue
@@ -675,7 +677,7 @@ def _construct_tool_definitions(self, tools_config: dict) -> list[dict[str, Any]
675677

676678
def start_multiagent_span(
677679
self,
678-
task: str | list[ContentBlock],
680+
task: MultiAgentInput,
679681
instance: str,
680682
) -> Span:
681683
"""Start a new span for swarm invocation."""
@@ -789,12 +791,23 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None:
789791
{"content": serialize(message["content"])},
790792
)
791793

792-
def _map_content_blocks_to_otel_parts(self, content_blocks: list[ContentBlock]) -> list[dict[str, Any]]:
793-
"""Map ContentBlock objects to OpenTelemetry parts format."""
794+
def _map_content_blocks_to_otel_parts(
795+
self, content_blocks: list[ContentBlock] | list[InterruptResponseContent]
796+
) -> list[dict[str, Any]]:
797+
"""Map content blocks to OpenTelemetry parts format."""
794798
parts: list[dict[str, Any]] = []
795799

796-
for block in content_blocks:
797-
if "text" in block:
800+
for block in cast(list[dict[str, Any]], content_blocks):
801+
if "interruptResponse" in block:
802+
interrupt_response = block["interruptResponse"]
803+
parts.append(
804+
{
805+
"type": "interrupt_response",
806+
"id": interrupt_response["interruptId"],
807+
"response": interrupt_response["response"],
808+
},
809+
)
810+
elif "text" in block:
798811
# Standard TextPart
799812
parts.append({"type": "text", "content": block["text"]})
800813
elif "toolUse" in block:

src/strands/types/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
from typing import TypeAlias
77

88
from .content import ContentBlock, Messages
9-
from .interrupt import InterruptResponse
9+
from .interrupt import InterruptResponseContent
1010

11-
AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponse] | Messages | None
11+
AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None

src/strands/types/multiagent.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Multi-agent related type definitions for the SDK."""
2+
3+
from typing import TypeAlias
4+
5+
from .content import ContentBlock
6+
7+
MultiAgentInput: TypeAlias = str | list[ContentBlock]

tests/strands/telemetry/test_tracer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize
1313
from strands.types.content import ContentBlock
14+
from strands.types.interrupt import InterruptResponseContent
1415
from strands.types.streaming import Metrics, StopReason, Usage
1516

1617

@@ -396,6 +397,34 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer):
396397
assert span is not None
397398

398399

400+
@pytest.mark.parametrize(
401+
"task, expected_parts",
402+
[
403+
([ContentBlock(text="Test message")], [{"type": "text", "content": "Test message"}]),
404+
(
405+
[InterruptResponseContent(interruptResponse={"interruptId": "test-id", "response": "approved"})],
406+
[{"type": "interrupt_response", "id": "test-id", "response": "approved"}],
407+
),
408+
],
409+
)
410+
def test_start_multiagent_span_task_part_conversion(mock_tracer, task, expected_parts, monkeypatch):
411+
monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental")
412+
413+
with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):
414+
tracer = Tracer()
415+
tracer.tracer = mock_tracer
416+
417+
mock_span = mock.MagicMock()
418+
mock_tracer.start_span.return_value = mock_span
419+
420+
tracer.start_multiagent_span(task, "swarm")
421+
422+
expected_content = json.dumps([{"role": "user", "parts": expected_parts}])
423+
mock_span.add_event.assert_any_call(
424+
"gen_ai.client.inference.operation.details", attributes={"gen_ai.input.messages": expected_content}
425+
)
426+
427+
399428
def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, monkeypatch):
400429
"""Test starting a swarm call span with task as list of contentBlock with latest semantic conventions."""
401430
with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):

0 commit comments

Comments
 (0)