diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index fd75ea175..b80cdddab 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -64,12 +64,10 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
         """
         validate_config_keys(model_config, self.OpenAIConfig)
         self.config = dict(model_config)
+        self.client_args = client_args or {}
 
         logger.debug("config=<%s> | initializing", self.config)
 
-        client_args = client_args or {}
-        self.client = openai.AsyncOpenAI(**client_args)
-
     @override
     def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None:  # type: ignore[override]
         """Update the OpenAI model configuration with the provided arguments.
@@ -379,58 +377,60 @@ async def stream(
         logger.debug("formatted request=<%s>", request)
 
         logger.debug("invoking model")
-        response = await self.client.chat.completions.create(**request)
-
-        logger.debug("got response from model")
-        yield self.format_chunk({"chunk_type": "message_start"})
-        yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
-
-        tool_calls: dict[int, list[Any]] = {}
-
-        async for event in response:
-            # Defensive: skip events with empty or missing choices
-            if not getattr(event, "choices", None):
-                continue
-            choice = event.choices[0]
-
-            if choice.delta.content:
-                yield self.format_chunk(
-                    {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
-                )
-
-            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
-                yield self.format_chunk(
-                    {
-                        "chunk_type": "content_delta",
-                        "data_type": "reasoning_content",
-                        "data": choice.delta.reasoning_content,
-                    }
-                )
-            for tool_call in choice.delta.tool_calls or []:
-                tool_calls.setdefault(tool_call.index, []).append(tool_call)
+        async with openai.AsyncOpenAI(**self.client_args) as client:
+            response = await client.chat.completions.create(**request)
 
-            if choice.finish_reason:
-                break
+            logger.debug("got response from model")
+            yield self.format_chunk({"chunk_type": "message_start"})
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
 
-        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+            tool_calls: dict[int, list[Any]] = {}
 
-        for tool_deltas in tool_calls.values():
-            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
+            async for event in response:
+                # Defensive: skip events with empty or missing choices
+                if not getattr(event, "choices", None):
+                    continue
+                choice = event.choices[0]
+
+                if choice.delta.content:
+                    yield self.format_chunk(
+                        {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+                    )
+
+                if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
+                    yield self.format_chunk(
+                        {
+                            "chunk_type": "content_delta",
+                            "data_type": "reasoning_content",
+                            "data": choice.delta.reasoning_content,
+                        }
+                    )
 
-            for tool_delta in tool_deltas:
-                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+                for tool_call in choice.delta.tool_calls or []:
+                    tool_calls.setdefault(tool_call.index, []).append(tool_call)
 
-            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+                if choice.finish_reason:
+                    break
 
-        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
 
-        # Skip remaining events as we don't have use for anything except the final usage payload
-        async for event in response:
-            _ = event
+            for tool_deltas in tool_calls.values():
+                yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
 
-        if event.usage:
-            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+                for tool_delta in tool_deltas:
+                    yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+
+                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+            yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+
+            # Skip remaining events as we don't have use for anything except the final usage payload
+            async for event in response:
+                _ = event
+
+            if event.usage:
+                yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
 
         logger.debug("finished streaming response from model")
@@ -449,11 +449,12 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        response: ParsedChatCompletion = await self.client.beta.chat.completions.parse(  # type: ignore
-            model=self.get_config()["model_id"],
-            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-            response_format=output_model,
-        )
+        async with openai.AsyncOpenAI(**self.client_args) as client:
+            response: ParsedChatCompletion = await client.beta.chat.completions.parse(
+                model=self.get_config()["model_id"],
+                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+                response_format=output_model,
+            )
 
         parsed: T | None = None
         # Find the first choice with tool_calls
diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py
index 64da3cac2..5979ec628 100644
--- a/tests/strands/models/test_openai.py
+++ b/tests/strands/models/test_openai.py
@@ -8,14 +8,11 @@
 
 
 @pytest.fixture
-def openai_client_cls():
+def openai_client():
     with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
-        yield mock_client_cls
-
-
-@pytest.fixture
-def openai_client(openai_client_cls):
-    return openai_client_cls.return_value
+        mock_client = unittest.mock.AsyncMock()
+        mock_client_cls.return_value.__aenter__.return_value = mock_client
+        yield mock_client
 
 
 @pytest.fixture
@@ -68,16 +65,14 @@ class TestOutputModel(pydantic.BaseModel):
     return TestOutputModel
 
 
-def test__init__(openai_client_cls, model_id):
-    model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})
+def test__init__(model_id):
+    model = OpenAIModel(model_id=model_id, params={"max_tokens": 1})
 
     tru_config = model.get_config()
     exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}
 
     assert tru_config == exp_config
 
-    openai_client_cls.assert_called_once_with(api_key="k1")
-
 
 def test_update_config(model, model_id):
     model.update_config(model_id=model_id)
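
For context, here is a minimal caller-side sketch of the behavior this patch introduces. It is not part of the diff: the `model_id`, `params`, and message payload are illustrative, and the `stream()` call shape is assumed from the surrounding strands API. The point it shows is that `__init__` now only stores `client_args`, and each `stream()` or `structured_output()` call opens and closes its own `AsyncOpenAI` client via `async with`.

```python
# Hypothetical usage sketch (assumptions: OpenAIModel import path, stream()
# call shape, and the example model_id/message payload).
import asyncio

from strands.models.openai import OpenAIModel


async def main() -> None:
    # With this patch, no AsyncOpenAI client is created here; client_args
    # are only stored on the instance as self.client_args.
    model = OpenAIModel({"api_key": "k1"}, model_id="m1", params={"max_tokens": 100})

    # Each call opens a fresh client inside
    # `async with openai.AsyncOpenAI(**self.client_args) as client:` and
    # closes it when the generator finishes, so no client outlives the
    # event loop it was created on.
    async for chunk in model.stream([{"role": "user", "content": [{"text": "hello"}]}]):
        print(chunk)


asyncio.run(main())
```

A plausible motivation for the per-request client is that a long-lived `AsyncOpenAI` instance created in `__init__` is bound to whichever event loop first uses it, which breaks under repeated `asyncio.run()` calls; the trade-off is re-establishing the HTTP connection pool on every request.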