diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ba1c77193..8c9716a4f 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -18,13 +18,13 @@ from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec -from ..types.content import ContentBlock, Message, Messages +from ..types.content import ContentBlock, Messages from ..types.exceptions import ( ContextWindowOverflowException, ModelThrottledException, ) from ..types.streaming import CitationsDelta, StreamEvent -from ..types.tools import ToolChoice, ToolResult, ToolSpec +from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys from .model import Model @@ -185,17 +185,6 @@ def get_config(self) -> BedrockConfig: """ return self.config - def _should_include_tool_result_status(self) -> bool: - """Determine whether to include tool result status based on current config.""" - include_status = self.config.get("include_tool_result_status", "auto") - - if include_status is True: - return True - elif include_status is False: - return False - else: # "auto" - return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS) - def format_request( self, messages: Messages, @@ -281,14 +270,12 @@ def format_request( ), } - def _format_bedrock_messages(self, messages: Messages) -> Messages: + def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format messages for Bedrock API compatibility. This function ensures messages conform to Bedrock's expected format by: - Filtering out SDK_UNKNOWN_MEMBER content blocks - - Cleaning tool result content blocks by removing additional fields that may be - useful for retaining information in hooks but would cause Bedrock validation - exceptions when presented with unexpected fields + - Eagerly filtering content blocks to only include Bedrock-supported fields - Ensuring all message content blocks are properly formatted for the Bedrock API Args: @@ -298,17 +285,19 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: Messages formatted for Bedrock API compatibility Note: - Bedrock will throw validation exceptions when presented with additional - unexpected fields in tool result blocks. - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html + Unlike other APIs that ignore unknown fields, Bedrock only accepts a strict + subset of fields for each content block type and throws validation exceptions + when presented with unexpected fields. Therefore, we must eagerly filter all + content blocks to remove any additional fields before sending to Bedrock. + https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html """ - cleaned_messages = [] + cleaned_messages: list[dict[str, Any]] = [] filtered_unknown_members = False dropped_deepseek_reasoning_content = False for message in messages: - cleaned_content: list[ContentBlock] = [] + cleaned_content: list[dict[str, Any]] = [] for content_block in message["content"]: # Filter out SDK_UNKNOWN_MEMBER content blocks @@ -322,33 +311,13 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: dropped_deepseek_reasoning_content = True continue - if "toolResult" in content_block: - # Create a new content block with only the cleaned toolResult - tool_result: ToolResult = content_block["toolResult"] + # Format content blocks for Bedrock API compatibility + formatted_content = self._format_request_message_content(content_block) + cleaned_content.append(formatted_content) - if self._should_include_tool_result_status(): - # Include status field - cleaned_tool_result = ToolResult( - content=tool_result["content"], - toolUseId=tool_result["toolUseId"], - status=tool_result["status"], - ) - else: - # Remove status field - cleaned_tool_result = ToolResult( # type: ignore[typeddict-item] - toolUseId=tool_result["toolUseId"], content=tool_result["content"] - ) - - cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result} - cleaned_content.append(cleaned_block) - else: - # Keep other content blocks as-is - cleaned_content.append(content_block) - - # Create new message with cleaned content (skip if empty for DeepSeek) + # Create new message with cleaned content (skip if empty) if cleaned_content: - cleaned_message: Message = Message(content=cleaned_content, role=message["role"]) - cleaned_messages.append(cleaned_message) + cleaned_messages.append({"content": cleaned_content, "role": message["role"]}) if filtered_unknown_members: logger.warning( @@ -361,6 +330,184 @@ def _format_bedrock_messages(self, messages: Messages) -> Messages: return cleaned_messages + def _should_include_tool_result_status(self) -> bool: + """Determine whether to include tool result status based on current config.""" + include_status = self.config.get("include_tool_result_status", "auto") + + if include_status is True: + return True + elif include_status is False: + return False + else: # "auto" + return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS) + + def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + """Format a Bedrock content block. + + Bedrock strictly validates content blocks and throws exceptions for unknown fields. + This function extracts only the fields that Bedrock supports for each content type. + + Args: + content: Content block to format. + + Returns: + Bedrock formatted content block. + + Raises: + TypeError: If the content block type is not supported by Bedrock. + """ + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html + if "cachePoint" in content: + return {"cachePoint": {"type": content["cachePoint"]["type"]}} + + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html + if "document" in content: + document = content["document"] + result: dict[str, Any] = {} + + # Handle required fields (all optional due to total=False) + if "name" in document: + result["name"] = document["name"] + if "format" in document: + result["format"] = document["format"] + + # Handle source + if "source" in document: + result["source"] = {"bytes": document["source"]["bytes"]} + + # Handle optional fields + if "citations" in document and document["citations"] is not None: + result["citations"] = {"enabled": document["citations"]["enabled"]} + if "context" in document: + result["context"] = document["context"] + + return {"document": result} + + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_GuardrailConverseContentBlock.html + if "guardContent" in content: + guard = content["guardContent"] + guard_text = guard["text"] + result = {"text": {"text": guard_text["text"], "qualifiers": guard_text["qualifiers"]}} + return {"guardContent": result} + + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html + if "image" in content: + image = content["image"] + source = image["source"] + formatted_source = {} + if "bytes" in source: + formatted_source = {"bytes": source["bytes"]} + result = {"format": image["format"], "source": formatted_source} + return {"image": result} + + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html + if "reasoningContent" in content: + reasoning = content["reasoningContent"] + result = {} + + if "reasoningText" in reasoning: + reasoning_text = reasoning["reasoningText"] + result["reasoningText"] = {} + if "text" in reasoning_text: + result["reasoningText"]["text"] = reasoning_text["text"] + # Only include signature if truthy (avoid empty strings) + if reasoning_text.get("signature"): + result["reasoningText"]["signature"] = reasoning_text["signature"] + + if "redactedContent" in reasoning: + result["redactedContent"] = reasoning["redactedContent"] + + return {"reasoningContent": result} + + # Pass through text and other simple content types + if "text" in content: + return {"text": content["text"]} + + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html + if "toolResult" in content: + tool_result = content["toolResult"] + formatted_content: list[dict[str, Any]] = [] + for tool_result_content in tool_result["content"]: + if "json" in tool_result_content: + # Handle json field since not in ContentBlock but valid in ToolResultContent + formatted_content.append({"json": tool_result_content["json"]}) + else: + formatted_content.append( + self._format_request_message_content(cast(ContentBlock, tool_result_content)) + ) + + result = { + "content": formatted_content, + "toolUseId": tool_result["toolUseId"], + } + if "status" in tool_result and self._should_include_tool_result_status(): + result["status"] = tool_result["status"] + return {"toolResult": result} + + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolUseBlock.html + if "toolUse" in content: + tool_use = content["toolUse"] + return { + "toolUse": { + "input": tool_use["input"], + "name": tool_use["name"], + "toolUseId": tool_use["toolUseId"], + } + } + + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_VideoBlock.html + if "video" in content: + video = content["video"] + source = video["source"] + formatted_source = {} + if "bytes" in source: + formatted_source = {"bytes": source["bytes"]} + result = {"format": video["format"], "source": formatted_source} + return {"video": result} + + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html + if "citationsContent" in content: + citations = content["citationsContent"] + result = {} + + if "citations" in citations: + result["citations"] = [] + for citation in citations["citations"]: + filtered_citation: dict[str, Any] = {} + if "location" in citation: + location = citation["location"] + filtered_location = {} + # Filter location fields to only include Bedrock-supported ones + if "documentIndex" in location: + filtered_location["documentIndex"] = location["documentIndex"] + if "start" in location: + filtered_location["start"] = location["start"] + if "end" in location: + filtered_location["end"] = location["end"] + filtered_citation["location"] = filtered_location + if "sourceContent" in citation: + filtered_source_content: list[dict[str, Any]] = [] + for source_content in citation["sourceContent"]: + if "text" in source_content: + filtered_source_content.append({"text": source_content["text"]}) + if filtered_source_content: + filtered_citation["sourceContent"] = filtered_source_content + if "title" in citation: + filtered_citation["title"] = citation["title"] + result["citations"].append(filtered_citation) + + if "content" in citations: + filtered_content: list[dict[str, Any]] = [] + for generated_content in citations["content"]: + if "text" in generated_content: + filtered_content.append({"text": generated_content["text"]}) + if filtered_content: + result["content"] = filtered_content + + return {"citationsContent": result} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: """Check if guardrail data contains any blocked policies. diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index e9bea2686..a443c9b66 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1519,6 +1519,159 @@ async def test_stream_deepseek_skips_empty_messages(bedrock_client, alist): assert sent_messages[1]["content"] == [{"text": "Follow up"}] +def test_format_request_filters_image_content_blocks(model, model_id): + """Test that format_request filters extra fields from image content blocks.""" + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "png", + "source": {"bytes": b"image_data"}, + "filename": "test.png", # Extra field that should be filtered + "metadata": {"size": 1024}, # Extra field that should be filtered + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + image_block = formatted_request["messages"][0]["content"][0]["image"] + expected = {"format": "png", "source": {"bytes": b"image_data"}} + assert image_block == expected + assert "filename" not in image_block + assert "metadata" not in image_block + + +def test_format_request_filters_nested_image_s3_fields(model, model_id): + """Test that s3Location is filtered out and only bytes source is preserved.""" + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "png", + "source": { + "bytes": b"image_data", + "s3Location": {"bucket": "my-bucket", "key": "image.png", "extraField": "filtered"}, + }, + } + } + ], + } + ] + + formatted_request = model.format_request(messages) + image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] + + assert image_source == {"bytes": b"image_data"} + assert "s3Location" not in image_source + + +def test_format_request_filters_document_content_blocks(model, model_id): + """Test that format_request filters extra fields from document content blocks.""" + messages = [ + { + "role": "user", + "content": [ + { + "document": { + "name": "test.pdf", + "source": {"bytes": b"pdf_data"}, + "format": "pdf", + "extraField": "should be removed", + "metadata": {"pages": 10}, + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + document_block = formatted_request["messages"][0]["content"][0]["document"] + expected = {"name": "test.pdf", "source": {"bytes": b"pdf_data"}, "format": "pdf"} + assert document_block == expected + assert "extraField" not in document_block + assert "metadata" not in document_block + + +def test_format_request_filters_nested_reasoning_content(model, model_id): + """Test deep filtering of nested reasoningText fields.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "thinking...", "signature": "abc123", "extraField": "filtered"} + } + } + ], + } + ] + + formatted_request = model.format_request(messages) + reasoning_text = formatted_request["messages"][0]["content"][0]["reasoningContent"]["reasoningText"] + + assert reasoning_text == {"text": "thinking...", "signature": "abc123"} + + +def test_format_request_filters_video_content_blocks(model, model_id): + """Test that format_request filters extra fields from video content blocks.""" + messages = [ + { + "role": "user", + "content": [ + { + "video": { + "format": "mp4", + "source": {"bytes": b"video_data"}, + "duration": 120, # Extra field that should be filtered + "resolution": "1080p", # Extra field that should be filtered + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + video_block = formatted_request["messages"][0]["content"][0]["video"] + expected = {"format": "mp4", "source": {"bytes": b"video_data"}} + assert video_block == expected + assert "duration" not in video_block + assert "resolution" not in video_block + + +def test_format_request_filters_cache_point_content_blocks(model, model_id): + """Test that format_request filters extra fields from cachePoint content blocks.""" + messages = [ + { + "role": "user", + "content": [ + { + "cachePoint": { + "type": "default", + "extraField": "should be removed", + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + cache_point_block = formatted_request["messages"][0]["content"][0]["cachePoint"] + expected = {"type": "default"} + assert cache_point_block == expected + assert "extraField" not in cache_point_block + + def test_config_validation_warns_on_unknown_keys(bedrock_client, captured_warnings): """Test that unknown config keys emit a warning.""" BedrockModel(model_id="test-model", invalid_param="test")