25 changes: 25 additions & 0 deletions chatlas/_provider_openai.py
@@ -49,6 +49,9 @@ def ChatOpenAI(
model: "Optional[ResponsesModel | str]" = None,
api_key: Optional[str] = None,
base_url: str = "https://api.openai.com/v1",
service_tier: Optional[
Literal["auto", "default", "flex", "scale", "priority"]
] = None,
kwargs: Optional["ChatClientArgs"] = None,
) -> Chat["SubmitInputArgs", Response]:
"""
@@ -93,6 +96,13 @@ def ChatOpenAI(
variable.
base_url
The base URL to the endpoint; the default uses OpenAI.
service_tier
Request a specific service tier. Options:
- `"auto"` (default): uses the service tier configured in Project settings.
- `"default"`: standard pricing and performance.
- `"flex"`: slower and cheaper.
- `"scale"`: batch-like pricing for high-volume use.
- `"priority"`: faster and more expensive.
kwargs
Additional arguments to pass to the `openai.OpenAI()` client
constructor.
@@ -146,6 +156,10 @@ def ChatOpenAI(
if model is None:
model = log_model_default("gpt-4.1")

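# Only forward service_tier when the caller sets one, preserving the API default.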
kwargs_chat: "SubmitInputArgs" = {}
if service_tier is not None:
kwargs_chat["service_tier"] = service_tier

return Chat(
provider=OpenAIProvider(
api_key=api_key,
@@ -154,6 +168,7 @@
kwargs=kwargs,
),
system_prompt=system_prompt,
kwargs_chat=kwargs_chat,
)


@@ -260,6 +275,16 @@ def stream_text(self, chunk):
def stream_merge_chunks(self, completion, chunk):
if chunk.type == "response.completed":
return chunk.response
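# Surface failed or errored streams as exceptions rather than returning a partial response.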
elif chunk.type == "response.failed":
error = chunk.response.error
if error is None:
msg = "Request failed with an unknown error."
else:
msg = f"Request failed ({error.code}): {error.message}"
raise RuntimeError(msg)
elif chunk.type == "error":
raise RuntimeError(f"Request errored: {chunk.message}")

# Since this value won't actually be used, we can lie about the type
return cast(Response, None)

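For context, a minimal usage sketch of the new parameter (assuming the top-level `ChatOpenAI` export and an `OPENAI_API_KEY` in the environment; the model and prompt are illustrative):

from chatlas import ChatOpenAI

# Request the cheaper, slower "flex" tier for all of this chat's requests.
chat = ChatOpenAI(model="gpt-4.1", service_tier="flex")
chat.chat("Which service tier handled this request?")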
17 changes: 16 additions & 1 deletion chatlas/_provider_openai_azure.py
@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Literal, Optional

from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.types.chat import ChatCompletion
@@ -21,6 +21,9 @@ def ChatAzureOpenAI(
api_version: str,
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
service_tier: Optional[
Literal["auto", "default", "flex", "scale", "priority"]
] = None,
kwargs: Optional["ChatAzureClientArgs"] = None,
) -> Chat["SubmitInputArgs", ChatCompletion]:
"""
@@ -62,6 +65,13 @@ def ChatAzureOpenAI(
variable.
system_prompt
A system prompt to set the behavior of the assistant.
service_tier
Request a specific service tier. Options:
- `"auto"` (default): uses the service tier configured in Project settings.
- `"default"`: standard pricing and performance.
- `"flex"`: slower and cheaper.
- `"scale"`: batch-like pricing for high-volume use.
- `"priority"`: faster and more expensive.
kwargs
Additional arguments to pass to the `openai.AzureOpenAI()` client constructor.

@@ -71,6 +81,10 @@
A Chat object.
"""

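# Only forward service_tier when the caller sets one, preserving the API default.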
kwargs_chat: "SubmitInputArgs" = {}
if service_tier is not None:
kwargs_chat["service_tier"] = service_tier

return Chat(
provider=OpenAIAzureProvider(
endpoint=endpoint,
@@ -80,6 +94,7 @@
kwargs=kwargs,
),
system_prompt=system_prompt,
kwargs_chat=kwargs_chat,
)


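The same sketch for the Azure entry point; the endpoint, deployment, and API version below are hypothetical placeholders for an Azure resource:

from chatlas import ChatAzureOpenAI

chat = ChatAzureOpenAI(
    endpoint="https://my-resource.openai.azure.com",
    deployment_id="gpt-4.1",  # hypothetical deployment name
    api_version="2024-10-21",
    service_tier="flex",
)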
5 changes: 5 additions & 0 deletions tests/test_provider_openai.py
@@ -105,3 +105,8 @@ def test_openai_custom_http_client():

def test_openai_list_models():
assert_list_models(ChatOpenAI)


def test_openai_service_tier():
chat = ChatOpenAI(service_tier="flex")
assert chat.kwargs_chat.get("service_tier") == "flex"
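

A companion test for the Azure wrapper could mirror the one above (an untested sketch, assuming `ChatAzureOpenAI` is imported and that the dummy endpoint and key are not validated at construction time):

def test_azure_openai_service_tier():
    chat = ChatAzureOpenAI(
        endpoint="https://example-resource.openai.azure.com",  # hypothetical
        deployment_id="gpt-4.1",  # hypothetical deployment name
        api_version="2024-10-21",
        api_key="test-key",
        service_tier="flex",
    )
    assert chat.kwargs_chat.get("service_tier") == "flex"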