diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index a6bd93cca..8ff37d359 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -34,6 +34,32 @@ class OpenAIModel(Model, abc.ABC): config: dict[str, Any] + @staticmethod + def b64encode(data: bytes) -> bytes: + """Base64 encode the provided data. + + If the data is already base64 encoded, we do nothing. + Note, this is a temporary method used to provide a warning to users who pass in base64 encoded data. In future + versions, images and documents will be base64 encoded on behalf of customers for consistency with the other + providers and general convenience. + + Args: + data: Data to encode. + + Returns: + Base64 encoded data. + """ + try: + base64.b64decode(data, validate=True) + logger.warning( + "issue=<%s> | base64 encoded images and documents will not be accepted in future versions", + "https://github.com/strands-agents/sdk-python/issues/252", + ) + except ValueError: + data = base64.b64encode(data) + + return data + @classmethod def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: """Format an OpenAI compatible content block. @@ -60,17 +86,8 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_bytes = content["image"]["source"]["bytes"] - try: - base64.b64decode(image_bytes, validate=True) - logger.warning( - "issue=<%s> | base64 encoded images will not be accepted in a future version", - "https://github.com/strands-agents/sdk-python/issues/252", - ) - except ValueError: - image_bytes = base64.b64encode(image_bytes) - - image_data = image_bytes.decode("utf-8") + image_data = OpenAIModel.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + return { "image_url": { "detail": "auto", diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index 2827969db..3a1a940b1 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -362,3 +362,15 @@ def test_format_chunk_unknown_type(model): with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): model.format_chunk(event) + + +@pytest.mark.parametrize( + ("data", "exp_result"), + [ + (b"image", b"aW1hZ2U="), + (b"aW1hZ2U=", b"aW1hZ2U="), + ], +) +def test_b64encode(data, exp_result): + tru_result = SAOpenAIModel.b64encode(data) + assert tru_result == exp_result