Skip to content

Commit f31b551

Browse files
dbschmigelskiUnshure
authored andcommitted
feat: decouple Strands ContentBlock and BedrockModel (strands-agents#836)
1 parent 3259310 commit f31b551

File tree

2 files changed

+347
-47
lines changed

2 files changed

+347
-47
lines changed

src/strands/models/bedrock.py

Lines changed: 194 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818

1919
from ..event_loop import streaming
2020
from ..tools import convert_pydantic_to_tool_spec
21-
from ..types.content import ContentBlock, Message, Messages
21+
from ..types.content import ContentBlock, Messages
2222
from ..types.exceptions import (
2323
ContextWindowOverflowException,
2424
ModelThrottledException,
2525
)
2626
from ..types.streaming import CitationsDelta, StreamEvent
27-
from ..types.tools import ToolChoice, ToolResult, ToolSpec
27+
from ..types.tools import ToolChoice, ToolSpec
2828
from ._validation import validate_config_keys
2929
from .model import Model
3030

@@ -185,17 +185,6 @@ def get_config(self) -> BedrockConfig:
185185
"""
186186
return self.config
187187

188-
def _should_include_tool_result_status(self) -> bool:
189-
"""Determine whether to include tool result status based on current config."""
190-
include_status = self.config.get("include_tool_result_status", "auto")
191-
192-
if include_status is True:
193-
return True
194-
elif include_status is False:
195-
return False
196-
else: # "auto"
197-
return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS)
198-
199188
def format_request(
200189
self,
201190
messages: Messages,
@@ -281,14 +270,12 @@ def format_request(
281270
),
282271
}
283272

284-
def _format_bedrock_messages(self, messages: Messages) -> Messages:
273+
def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
285274
"""Format messages for Bedrock API compatibility.
286275
287276
This function ensures messages conform to Bedrock's expected format by:
288277
- Filtering out SDK_UNKNOWN_MEMBER content blocks
289-
- Cleaning tool result content blocks by removing additional fields that may be
290-
useful for retaining information in hooks but would cause Bedrock validation
291-
exceptions when presented with unexpected fields
278+
- Eagerly filtering content blocks to only include Bedrock-supported fields
292279
- Ensuring all message content blocks are properly formatted for the Bedrock API
293280
294281
Args:
@@ -298,17 +285,19 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
298285
Messages formatted for Bedrock API compatibility
299286
300287
Note:
301-
Bedrock will throw validation exceptions when presented with additional
302-
unexpected fields in tool result blocks.
303-
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
288+
Unlike other APIs that ignore unknown fields, Bedrock only accepts a strict
289+
subset of fields for each content block type and throws validation exceptions
290+
when presented with unexpected fields. Therefore, we must eagerly filter all
291+
content blocks to remove any additional fields before sending to Bedrock.
292+
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html
304293
"""
305-
cleaned_messages = []
294+
cleaned_messages: list[dict[str, Any]] = []
306295

307296
filtered_unknown_members = False
308297
dropped_deepseek_reasoning_content = False
309298

310299
for message in messages:
311-
cleaned_content: list[ContentBlock] = []
300+
cleaned_content: list[dict[str, Any]] = []
312301

313302
for content_block in message["content"]:
314303
# Filter out SDK_UNKNOWN_MEMBER content blocks
@@ -322,33 +311,13 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
322311
dropped_deepseek_reasoning_content = True
323312
continue
324313

325-
if "toolResult" in content_block:
326-
# Create a new content block with only the cleaned toolResult
327-
tool_result: ToolResult = content_block["toolResult"]
314+
# Format content blocks for Bedrock API compatibility
315+
formatted_content = self._format_request_message_content(content_block)
316+
cleaned_content.append(formatted_content)
328317

329-
if self._should_include_tool_result_status():
330-
# Include status field
331-
cleaned_tool_result = ToolResult(
332-
content=tool_result["content"],
333-
toolUseId=tool_result["toolUseId"],
334-
status=tool_result["status"],
335-
)
336-
else:
337-
# Remove status field
338-
cleaned_tool_result = ToolResult( # type: ignore[typeddict-item]
339-
toolUseId=tool_result["toolUseId"], content=tool_result["content"]
340-
)
341-
342-
cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
343-
cleaned_content.append(cleaned_block)
344-
else:
345-
# Keep other content blocks as-is
346-
cleaned_content.append(content_block)
347-
348-
# Create new message with cleaned content (skip if empty for DeepSeek)
318+
# Create new message with cleaned content (skip if empty)
349319
if cleaned_content:
350-
cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
351-
cleaned_messages.append(cleaned_message)
320+
cleaned_messages.append({"content": cleaned_content, "role": message["role"]})
352321

353322
if filtered_unknown_members:
354323
logger.warning(
@@ -361,6 +330,184 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages:
361330

362331
return cleaned_messages
363332

333+
def _should_include_tool_result_status(self) -> bool:
334+
"""Determine whether to include tool result status based on current config."""
335+
include_status = self.config.get("include_tool_result_status", "auto")
336+
337+
if include_status is True:
338+
return True
339+
elif include_status is False:
340+
return False
341+
else: # "auto"
342+
return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS)
343+
344+
def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
345+
"""Format a Bedrock content block.
346+
347+
Bedrock strictly validates content blocks and throws exceptions for unknown fields.
348+
This function extracts only the fields that Bedrock supports for each content type.
349+
350+
Args:
351+
content: Content block to format.
352+
353+
Returns:
354+
Bedrock formatted content block.
355+
356+
Raises:
357+
TypeError: If the content block type is not supported by Bedrock.
358+
"""
359+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html
360+
if "cachePoint" in content:
361+
return {"cachePoint": {"type": content["cachePoint"]["type"]}}
362+
363+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html
364+
if "document" in content:
365+
document = content["document"]
366+
result: dict[str, Any] = {}
367+
368+
# Handle required fields (all optional due to total=False)
369+
if "name" in document:
370+
result["name"] = document["name"]
371+
if "format" in document:
372+
result["format"] = document["format"]
373+
374+
# Handle source
375+
if "source" in document:
376+
result["source"] = {"bytes": document["source"]["bytes"]}
377+
378+
# Handle optional fields
379+
if "citations" in document and document["citations"] is not None:
380+
result["citations"] = {"enabled": document["citations"]["enabled"]}
381+
if "context" in document:
382+
result["context"] = document["context"]
383+
384+
return {"document": result}
385+
386+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_GuardrailConverseContentBlock.html
387+
if "guardContent" in content:
388+
guard = content["guardContent"]
389+
guard_text = guard["text"]
390+
result = {"text": {"text": guard_text["text"], "qualifiers": guard_text["qualifiers"]}}
391+
return {"guardContent": result}
392+
393+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html
394+
if "image" in content:
395+
image = content["image"]
396+
source = image["source"]
397+
formatted_source = {}
398+
if "bytes" in source:
399+
formatted_source = {"bytes": source["bytes"]}
400+
result = {"format": image["format"], "source": formatted_source}
401+
return {"image": result}
402+
403+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html
404+
if "reasoningContent" in content:
405+
reasoning = content["reasoningContent"]
406+
result = {}
407+
408+
if "reasoningText" in reasoning:
409+
reasoning_text = reasoning["reasoningText"]
410+
result["reasoningText"] = {}
411+
if "text" in reasoning_text:
412+
result["reasoningText"]["text"] = reasoning_text["text"]
413+
# Only include signature if truthy (avoid empty strings)
414+
if reasoning_text.get("signature"):
415+
result["reasoningText"]["signature"] = reasoning_text["signature"]
416+
417+
if "redactedContent" in reasoning:
418+
result["redactedContent"] = reasoning["redactedContent"]
419+
420+
return {"reasoningContent": result}
421+
422+
# Pass through text and other simple content types
423+
if "text" in content:
424+
return {"text": content["text"]}
425+
426+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
427+
if "toolResult" in content:
428+
tool_result = content["toolResult"]
429+
formatted_content: list[dict[str, Any]] = []
430+
for tool_result_content in tool_result["content"]:
431+
if "json" in tool_result_content:
432+
# Handle json field since not in ContentBlock but valid in ToolResultContent
433+
formatted_content.append({"json": tool_result_content["json"]})
434+
else:
435+
formatted_content.append(
436+
self._format_request_message_content(cast(ContentBlock, tool_result_content))
437+
)
438+
439+
result = {
440+
"content": formatted_content,
441+
"toolUseId": tool_result["toolUseId"],
442+
}
443+
if "status" in tool_result and self._should_include_tool_result_status():
444+
result["status"] = tool_result["status"]
445+
return {"toolResult": result}
446+
447+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolUseBlock.html
448+
if "toolUse" in content:
449+
tool_use = content["toolUse"]
450+
return {
451+
"toolUse": {
452+
"input": tool_use["input"],
453+
"name": tool_use["name"],
454+
"toolUseId": tool_use["toolUseId"],
455+
}
456+
}
457+
458+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_VideoBlock.html
459+
if "video" in content:
460+
video = content["video"]
461+
source = video["source"]
462+
formatted_source = {}
463+
if "bytes" in source:
464+
formatted_source = {"bytes": source["bytes"]}
465+
result = {"format": video["format"], "source": formatted_source}
466+
return {"video": result}
467+
468+
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
469+
if "citationsContent" in content:
470+
citations = content["citationsContent"]
471+
result = {}
472+
473+
if "citations" in citations:
474+
result["citations"] = []
475+
for citation in citations["citations"]:
476+
filtered_citation: dict[str, Any] = {}
477+
if "location" in citation:
478+
location = citation["location"]
479+
filtered_location = {}
480+
# Filter location fields to only include Bedrock-supported ones
481+
if "documentIndex" in location:
482+
filtered_location["documentIndex"] = location["documentIndex"]
483+
if "start" in location:
484+
filtered_location["start"] = location["start"]
485+
if "end" in location:
486+
filtered_location["end"] = location["end"]
487+
filtered_citation["location"] = filtered_location
488+
if "sourceContent" in citation:
489+
filtered_source_content: list[dict[str, Any]] = []
490+
for source_content in citation["sourceContent"]:
491+
if "text" in source_content:
492+
filtered_source_content.append({"text": source_content["text"]})
493+
if filtered_source_content:
494+
filtered_citation["sourceContent"] = filtered_source_content
495+
if "title" in citation:
496+
filtered_citation["title"] = citation["title"]
497+
result["citations"].append(filtered_citation)
498+
499+
if "content" in citations:
500+
filtered_content: list[dict[str, Any]] = []
501+
for generated_content in citations["content"]:
502+
if "text" in generated_content:
503+
filtered_content.append({"text": generated_content["text"]})
504+
if filtered_content:
505+
result["content"] = filtered_content
506+
507+
return {"citationsContent": result}
508+
509+
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
510+
364511
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
365512
"""Check if guardrail data contains any blocked policies.
366513

0 commit comments

Comments
 (0)