src/strands/models/openai.py (103 changes: 52 additions & 51 deletions)
@@ -64,12 +64,10 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
         """
         validate_config_keys(model_config, self.OpenAIConfig)
         self.config = dict(model_config)
+        self.client_args = client_args or {}

         logger.debug("config=<%s> | initializing", self.config)

-        client_args = client_args or {}
-        self.client = openai.AsyncOpenAI(**client_args)
-
     @override
     def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None:  # type: ignore[override]
         """Update the OpenAI model configuration with the provided arguments.
@@ -379,58 +377,60 @@ async def stream(
         logger.debug("formatted request=<%s>", request)

         logger.debug("invoking model")
-        response = await self.client.chat.completions.create(**request)
-
-        logger.debug("got response from model")
-        yield self.format_chunk({"chunk_type": "message_start"})
-        yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
-
-        tool_calls: dict[int, list[Any]] = {}
-
-        async for event in response:
-            # Defensive: skip events with empty or missing choices
-            if not getattr(event, "choices", None):
-                continue
-            choice = event.choices[0]
-
-            if choice.delta.content:
-                yield self.format_chunk(
-                    {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
-                )
-
-            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
-                yield self.format_chunk(
-                    {
-                        "chunk_type": "content_delta",
-                        "data_type": "reasoning_content",
-                        "data": choice.delta.reasoning_content,
-                    }
-                )
-
-            for tool_call in choice.delta.tool_calls or []:
-                tool_calls.setdefault(tool_call.index, []).append(tool_call)
-
-            if choice.finish_reason:
-                break
-
-        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
-
-        for tool_deltas in tool_calls.values():
-            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
-
-            for tool_delta in tool_deltas:
-                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
-
-            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
-
-        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
-
-        # Skip remaining events as we don't have use for anything except the final usage payload
-        async for event in response:
-            _ = event
-
-        if event.usage:
-            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+        async with openai.AsyncOpenAI(**self.client_args) as client:
+            response = await client.chat.completions.create(**request)
+
+            logger.debug("got response from model")
+            yield self.format_chunk({"chunk_type": "message_start"})
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+
+            tool_calls: dict[int, list[Any]] = {}
+
+            async for event in response:
+                # Defensive: skip events with empty or missing choices
+                if not getattr(event, "choices", None):
+                    continue
+                choice = event.choices[0]
+
+                if choice.delta.content:
+                    yield self.format_chunk(
+                        {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+                    )
+
+                if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
+                    yield self.format_chunk(
+                        {
+                            "chunk_type": "content_delta",
+                            "data_type": "reasoning_content",
+                            "data": choice.delta.reasoning_content,
+                        }
+                    )
+
+                for tool_call in choice.delta.tool_calls or []:
+                    tool_calls.setdefault(tool_call.index, []).append(tool_call)
+
+                if choice.finish_reason:
+                    break
+
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+
+            for tool_deltas in tool_calls.values():
+                yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
+
+                for tool_delta in tool_deltas:
+                    yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+
+                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+            yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+
+            # Skip remaining events as we don't have use for anything except the final usage payload
+            async for event in response:
+                _ = event
+
+            if event.usage:
+                yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})

         logger.debug("finished streaming response from model")

@@ -449,11 +449,12 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        response: ParsedChatCompletion = await self.client.beta.chat.completions.parse(  # type: ignore
-            model=self.get_config()["model_id"],
-            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-            response_format=output_model,
-        )
+        async with openai.AsyncOpenAI(**self.client_args) as client:
+            response: ParsedChatCompletion = await client.beta.chat.completions.parse(
+                model=self.get_config()["model_id"],
+                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+                response_format=output_model,
+            )

         parsed: T | None = None
         # Find the first choice with tool_calls
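Reviewer note: the change above moves client construction out of __init__ and into each call site. openai.AsyncOpenAI is an async context manager, so async with both builds a fresh client from the stored client_args and closes its HTTP transport when the block exits, whether the call succeeds or raises. A minimal sketch of that lifecycle follows; the make_request wrapper and its arguments are illustrative, not part of this PR.

import openai


async def make_request(client_args: dict, request: dict):
    # Entering the context opens the SDK's HTTP transport; leaving it
    # closes that transport even if the call raises, so no connections
    # outlive the request or leak across event loops.
    async with openai.AsyncOpenAI(**client_args) as client:
        return await client.chat.completions.create(**request)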
tests/strands/models/test_openai.py (17 changes: 6 additions & 11 deletions)
@@ -8,14 +8,11 @@


 @pytest.fixture
-def openai_client_cls():
+def openai_client():
     with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
-        yield mock_client_cls
-
-
-@pytest.fixture
-def openai_client(openai_client_cls):
-    return openai_client_cls.return_value
+        mock_client = unittest.mock.AsyncMock()
+        mock_client_cls.return_value.__aenter__.return_value = mock_client
+        yield mock_client


 @pytest.fixture
@@ -68,16 +65,14 @@ class TestOutputModel(pydantic.BaseModel):
     return TestOutputModel


-def test__init__(openai_client_cls, model_id):
-    model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})
+def test__init__(model_id):
+    model = OpenAIModel(model_id=model_id, params={"max_tokens": 1})

     tru_config = model.get_config()
     exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}

     assert tru_config == exp_config

-    openai_client_cls.assert_called_once_with(api_key="k1")
-

 def test_update_config(model, model_id):
     model.update_config(model_id=model_id)
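Reviewer note: the fixture rewrite mirrors the new call path. Since the code under test now obtains the client via async with openai.AsyncOpenAI(...) as client, the bound name is whatever __aenter__ returns, so the mock must be attached at return_value.__aenter__.return_value rather than at return_value. A small self-contained sketch of that wiring; the demo function is illustrative only.

import asyncio
import unittest.mock


async def demo():
    mock_client_cls = unittest.mock.MagicMock()
    mock_client = unittest.mock.AsyncMock()
    mock_client_cls.return_value.__aenter__.return_value = mock_client

    # `async with cls(...) as client` awaits cls(...).__aenter__(), so the
    # bound target is the inner AsyncMock, not the constructor's return value.
    async with mock_client_cls() as client:
        assert client is mock_client


asyncio.run(demo())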