124 changes: 123 additions & 1 deletion src/strands/models/litellm.py
@@ -108,6 +108,128 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]

return super().format_request_message_content(content)

def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]:
"""Format LiteLLM compatible message contents.

Some LiteLLM providers (for example Cerebras and Groq) expect content to be a plain string for simple text messages rather than a list of content blocks.
This method flattens the content structure so requests remain compatible with those providers.
Member: Hi, thanks for this, but can you help explain where the abstraction is broken?

Are you indicating that LiteLLM has a broken abstraction for Cerebras and Groq? Or do you believe that Strands has always improperly implemented the spec, and this was just exposed when Cerebras and Groq were added?

Member: bump @aditya270520, thanks

Contributor Author: LiteLLM has a broken abstraction for the Cerebras and Groq providers. Strands has always implemented the OpenAI specification correctly, but the inconsistency was exposed when Cerebras and Groq were added to LiteLLM's supported providers.

The abstraction should be fixed in LiteLLM itself so that all providers receive content in the same format; until that happens, Strands needs this workaround to maintain compatibility.

This is a common issue when building a unified interface across providers with different underlying APIs: the abstraction layer (LiteLLM) is supposed to handle the normalization, but it is not doing so consistently.
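For illustration, the shape difference the workaround addresses (an editor's sketch, not taken from the pull request):

```python
# OpenAI-spec shape that Strands emits by default: content is a list of blocks.
{"role": "user", "content": [{"type": "text", "text": "Hello"}]}

# Shape that providers such as Cerebras and Groq accept for simple text, and
# which _format_request_message_contents now produces: content is a plain string.
{"role": "user", "content": "Hello"}
```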


Args:
role: The role of the message (e.g., "user", "assistant").
content: Content block to format.

Returns:
LiteLLM formatted message contents.

Raises:
TypeError: If the content block type cannot be converted to a LiteLLM-compatible format.
"""
if "text" in content:
return [{"role": role, "content": content["text"]}]

if "image" in content:
return [
{
"role": role,
"content": [{"type": "image_url", "image_url": {"url": content["image"]["source"]["bytes"]}}],
}
]

if "toolUse" in content:
return [
{
"role": role,
"tool_calls": [
{
"id": content["toolUse"]["toolUseId"],
"type": "function",
"function": {
"name": content["toolUse"]["name"],
"arguments": json.dumps(content["toolUse"]["input"]),
},
}
],
}
]

if "toolResult" in content:
return [
formatted_tool_result_content
for tool_result_content in content["toolResult"]["content"]
for formatted_tool_result_content in self._format_request_message_contents(
"tool",
(
{"text": json.dumps(tool_result_content["json"])}
if "json" in tool_result_content
else cast(ContentBlock, tool_result_content)
),
)
]

# For other content types, use the parent class method
formatted_content = self.format_request_message_content(content)
return [{"role": role, "content": [formatted_content]}]

def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
"""Format a LiteLLM compatible messages array.

This method overrides the parent OpenAIModel's format_request_messages to ensure
compatibility with LiteLLM providers like Cerebras and Groq that expect content
to be a string for simple text messages.

Args:
messages: List of message objects to be processed by the model.
system_prompt: System prompt to provide context to the model.

Returns:
A LiteLLM compatible messages array.
"""
system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []

return system_message + [
formatted_message
for message in messages
for content in message["content"]
for formatted_message in self._format_request_message_contents(message["role"], content)
]
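# Editor's note (illustrative sketch, not part of this diff): because each content
# block is formatted independently, a single Strands message with several blocks is
# expanded into several LiteLLM messages that share the same role, e.g.
#   [{"role": "user", "content": [{"text": "Summarize"}, {"text": "this file"}]}]
# becomes
#   [{"role": "user", "content": "Summarize"}, {"role": "user", "content": "this file"}]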

def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, tool_choice: ToolChoice | None = None
) -> dict[str, Any]:
"""Format a LiteLLM compatible chat streaming request.

Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.

Returns:
A LiteLLM compatible chat streaming request.

Raises:
TypeError: If a message contains a content block type that cannot be converted to a LiteLLM-compatible
format.
"""
return {
"messages": self._format_request_messages(messages, system_prompt),
"model": self.config["model_id"],
"stream": True,
"stream_options": {"include_usage": True},
"tools": [
{
"type": "function",
"function": {
"name": tool_spec["name"],
"description": tool_spec["description"],
"parameters": tool_spec["inputSchema"]["json"],
},
}
for tool_spec in tool_specs or []
],
**cast(dict[str, Any], self.config.get("params", {})),
}

@override
async def stream(
self,
@@ -211,7 +333,7 @@ async def structured_output(
response = await litellm.acompletion(
**self.client_args,
model=self.get_config()["model_id"],
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
messages=self._format_request_messages(prompt, system_prompt),
response_format=output_model,
)

101 changes: 101 additions & 0 deletions src/strands/multiagent/base.py
@@ -4,6 +4,8 @@
"""

import asyncio
import copy
import json
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
@@ -24,6 +26,105 @@ class Status(Enum):
FAILED = "failed"


@dataclass
class MultiAgentNode:
"""Base class for nodes in multi-agent systems."""

node_id: str

def __hash__(self) -> int:
"""Return hash for MultiAgentNode based on node_id."""
return hash(self.node_id)

def __eq__(self, other: Any) -> bool:
"""Return equality for MultiAgentNode based on node_id."""
if not isinstance(other, MultiAgentNode):
return False
return self.node_id == other.node_id


@dataclass
class SharedContext:
"""Shared context between multi-agent nodes.

This class provides a key-value store for sharing information across nodes
in multi-agent systems like Graph and Swarm. It validates that all values
are JSON serializable to ensure compatibility.
"""

context: dict[str, dict[str, Any]] = field(default_factory=dict)

def add_context(self, node: MultiAgentNode, key: str, value: Any) -> None:
"""Add context for a specific node.

Args:
node: The node object to add context for
key: The key to store the value under
value: The value to store (must be JSON serializable)

Raises:
ValueError: If key is invalid or value is not JSON serializable
"""
self._validate_key(key)
self._validate_json_serializable(value)

if node.node_id not in self.context:
self.context[node.node_id] = {}
self.context[node.node_id][key] = value

def get_context(self, node: MultiAgentNode, key: str | None = None) -> Any:
"""Get context for a specific node.

Args:
node: The node object to get context for
key: The specific key to retrieve (if None, returns all context for the node)

Returns:
The stored value, entire context dict for the node, or None if not found
"""
if node.node_id not in self.context:
return None if key else {}

if key is None:
return copy.deepcopy(self.context[node.node_id])
else:
value = self.context[node.node_id].get(key)
return copy.deepcopy(value) if value is not None else None

def _validate_key(self, key: str) -> None:
"""Validate that a key is valid.

Args:
key: The key to validate

Raises:
ValueError: If key is invalid
"""
if key is None:
raise ValueError("Key cannot be None")
if not isinstance(key, str):
raise ValueError("Key must be a string")
if not key.strip():
raise ValueError("Key cannot be empty")

def _validate_json_serializable(self, value: Any) -> None:
"""Validate that a value is JSON serializable.

Args:
value: The value to validate

Raises:
ValueError: If value is not JSON serializable
"""
try:
json.dumps(value)
except (TypeError, ValueError) as e:
raise ValueError(
f"Value is not JSON serializable: {type(value).__name__}. "
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
) from e


@dataclass
class NodeResult:
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.
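For orientation, a minimal sketch of how the new SharedContext is used (node IDs and values here are hypothetical; the import path follows the file location in this diff):

```python
from strands.multiagent.base import MultiAgentNode, SharedContext

node_a = MultiAgentNode(node_id="extract")
node_b = MultiAgentNode(node_id="summarize")

ctx = SharedContext()
ctx.add_context(node_a, "file_reference", "/tmp/report.txt")

print(ctx.get_context(node_a, "file_reference"))  # "/tmp/report.txt"
print(ctx.get_context(node_a))                     # deep copy: {"file_reference": "/tmp/report.txt"}
print(ctx.get_context(node_b))                     # {} -- nothing stored for this node yet

try:
    ctx.add_context(node_a, "handle", object())    # not JSON serializable
except ValueError as exc:
    print(exc)
```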
28 changes: 25 additions & 3 deletions src/strands/multiagent/graph.py
@@ -29,7 +29,7 @@
from ..telemetry import get_tracer
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
from .base import MultiAgentBase, MultiAgentNode, MultiAgentResult, NodeResult, SharedContext, Status

logger = logging.getLogger(__name__)

@@ -46,6 +46,7 @@ class GraphState:
task: The original input prompt/query provided to the graph execution.
This represents the actual work to be performed by the graph as a whole.
Entry point nodes receive this task as their input if they have no dependencies.
shared_context: Context shared between graph nodes for storing user-defined state.
"""

# Task (with default empty string)
@@ -61,6 +62,9 @@ class GraphState:
# Results
results: dict[str, NodeResult] = field(default_factory=dict)

# User-defined state shared across nodes
shared_context: "SharedContext" = field(default_factory=lambda: SharedContext())

# Accumulated metrics
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
@@ -126,7 +130,7 @@ def should_traverse(self, state: GraphState) -> bool:


@dataclass
class GraphNode:
class GraphNode(MultiAgentNode):
"""Represents a node in the graph.

The execution_status tracks the node's lifecycle within graph orchestration:
Expand All @@ -135,7 +139,6 @@ class GraphNode:
- COMPLETED/FAILED: Node finished executing (regardless of result quality)
"""

node_id: str
executor: Agent | MultiAgentBase
dependencies: set["GraphNode"] = field(default_factory=set)
execution_status: Status = Status.PENDING
@@ -385,6 +388,25 @@ def __init__(
self.state = GraphState()
self.tracer = get_tracer()

@property
def shared_context(self) -> SharedContext:
"""Access to the shared context for storing user-defined state across graph nodes.

Returns:
The SharedContext instance that can be used to store and retrieve
information that should be accessible to all nodes in the graph.

Example:
```python
graph = Graph(...)
node1 = graph.nodes["node1"]
node2 = graph.nodes["node2"]
graph.shared_context.add_context(node1, "file_reference", "/path/to/file")
graph.shared_context.get_context(node2, "file_reference")
```
"""
return self.state.shared_context

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> GraphResult: