
Commit 68daceb

fix: actually read env vars
1 parent b74952e commit 68daceb


3 files changed: +99 -67 lines changed

src/gradient/_client.py

Lines changed: 22 additions & 29 deletions
@@ -80,9 +80,9 @@ class Gradient(SyncAPIClient):
     def __init__(
         self,
         *,
-        api_key: str | None = None,  # deprecated, use `access_token` instead
-        inference_key: str | None = None,  # deprecated, use `model_access_key` instead
-        agent_key: str | None = None,  # deprecated, use `agent_access_key` instead
+        api_key: str | None = None,  # deprecated, use `access_token` instead
+        inference_key: str | None = None,  # deprecated, use `model_access_key` instead
+        agent_key: str | None = None,  # deprecated, use `agent_access_key` instead
         access_token: str | None = None,
         model_access_key: str | None = None,
         agent_access_key: str | None = None,
@@ -124,7 +124,6 @@ def __init__(
             access_token = os.environ.get("GRADIENT_API_KEY")
         self.access_token = access_token
 
-
         if model_access_key is None:
             if inference_key is not None:
                 model_access_key = inference_key
@@ -145,8 +144,15 @@ def __init__(
             agent_access_key = os.environ.get("GRADIENT_AGENT_KEY")
         self.agent_access_key = agent_access_key
 
+        if agent_endpoint is None:
+            agent_endpoint = os.environ.get("GRADIENT_AGENT_ENDPOINT")
         self._agent_endpoint = agent_endpoint
 
+        if inference_endpoint is None:
+            inference_endpoint = os.environ.get("GRADIENT_INFERENCE_ENDPOINT")
+        if inference_endpoint is None:
+            inference_endpoint = "https://inference.do-ai.run"
+
         self.inference_endpoint = inference_endpoint
 
         if base_url is None:
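
With these fallbacks in place, the endpoints can be configured purely through the environment. A minimal usage sketch, assuming the client class is importable as "from gradient import Gradient" (inferred from src/gradient/_client.py, not confirmed by this diff):

    import os
    from gradient import Gradient  # assumed import path

    # The constructor now falls back to env vars, then to the default host.
    os.environ["GRADIENT_API_KEY"] = "example-token"
    os.environ["GRADIENT_INFERENCE_ENDPOINT"] = "https://inference.example.test"

    client = Gradient()  # access_token and inference_endpoint resolved from the env
    assert client.inference_endpoint == "https://inference.example.test"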
@@ -275,9 +281,9 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
     def copy(
         self,
         *,
-        api_key: str | None = None,  # deprecated, use `access_token` instead
-        inference_key: str | None = None,  # deprecated, use `model_access_key` instead
-        agent_key: str | None = None,  # deprecated, use `agent_access_key` instead
+        api_key: str | None = None,  # deprecated, use `access_token` instead
+        inference_key: str | None = None,  # deprecated, use `model_access_key` instead
+        agent_key: str | None = None,  # deprecated, use `agent_access_key` instead
         access_token: str | None = None,
         model_access_key: str | None = None,
         agent_access_key: str | None = None,
@@ -393,9 +399,9 @@ class AsyncGradient(AsyncAPIClient):
     def __init__(
         self,
         *,
-        api_key: str | None = None,  # deprecated, use `access_token` instead
-        inference_key: str | None = None,  # deprecated, use `model_access_key` instead
-        agent_key: str | None = None,  # deprecated, use `agent_access_key` instead
+        api_key: str | None = None,  # deprecated, use `access_token` instead
+        inference_key: str | None = None,  # deprecated, use `model_access_key` instead
+        agent_key: str | None = None,  # deprecated, use `agent_access_key` instead
         access_token: str | None = None,
         model_access_key: str | None = None,
         agent_access_key: str | None = None,
@@ -437,7 +443,6 @@ def __init__(
             access_token = os.environ.get("GRADIENT_API_KEY")
         self.access_token = access_token
 
-
         if model_access_key is None:
             if inference_key is not None:
                 model_access_key = inference_key
@@ -588,9 +593,9 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
     def copy(
         self,
         *,
-        api_key: str | None = None,  # deprecated, use `access_token` instead
-        inference_key: str | None = None,  # deprecated, use `model_access_key` instead
-        agent_key: str | None = None,  # deprecated, use `agent_access_key` instead
+        api_key: str | None = None,  # deprecated, use `access_token` instead
+        inference_key: str | None = None,  # deprecated, use `model_access_key` instead
+        agent_key: str | None = None,  # deprecated, use `agent_access_key` instead
         agent_endpoint: str | None = None,
         access_token: str | None = None,
         model_access_key: str | None = None,
@@ -633,23 +638,11 @@ def copy(
 
         http_client = http_client or self._client
         client = self.__class__(
-<<<<<<< HEAD
-            api_key=api_key or self.api_key,
-            inference_key=inference_key or self.inference_key,
-            agent_key=agent_key or self.agent_key,
-            agent_endpoint=agent_endpoint or self._agent_endpoint,
-||||||| eb1dcf7
-            api_key=api_key or self.api_key,
-            inference_key=inference_key or self.inference_key,
-            agent_key=agent_key or self.agent_key,
-            agent_domain=agent_domain or self.agent_domain,
-=======
-            access_token=access_token or self.access_token,
-            model_access_key=model_access_key or self.model_access_key,
-            agent_access_key=agent_access_key or self.agent_access_key,
+            access_token=access_token or api_key or self.access_token,
+            model_access_key=model_access_key or inference_key or self.model_access_key,
+            agent_access_key=agent_access_key or agent_key or self.agent_access_key,
             agent_endpoint=agent_endpoint or self.agent_endpoint,
             inference_endpoint=inference_endpoint or self.inference_endpoint,
->>>>>>> origin/generated--merge-conflict
             base_url=base_url or self.base_url,
             timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
             http_client=http_client,
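
The resolved call keeps the deprecated keyword arguments working by folding them into their replacements. Roughly, continuing the sketch above (fallback order: explicit new-style argument, then deprecated alias, then the value already stored on the client):

    from gradient import Gradient  # assumed import path

    base = Gradient(access_token="token-a", model_access_key="model-a")
    clone = base.copy(inference_key="model-b")  # deprecated alias still honoured

    # clone.access_token     -> "token-a"  (carried over from base)
    # clone.model_access_key -> "model-b"  (taken from the deprecated inference_key)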

src/gradient/resources/agents/chat/completions.py

Lines changed: 8 additions & 8 deletions
@@ -463,13 +463,13 @@ def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> CompletionCreateResponse | Stream[ChatCompletionChunk]:
-        # This method requires an agent_key to be set via client argument or environment variable
-        if not self._client.agent_key:
+        # This method requires an agent_access_key to be set via client argument or environment variable
+        if not self._client.agent_access_key:
             raise TypeError(
-                "Could not resolve authentication method. Expected agent_key to be set for chat completions."
+                "Could not resolve authentication method. Expected agent_access_key to be set for chat completions."
             )
         headers = extra_headers or {}
-        headers = {"Authorization": f"Bearer {self._client.agent_key}", **headers}
+        headers = {"Authorization": f"Bearer {self._client.agent_access_key}", **headers}
 
         return self._post(
             (
@@ -951,13 +951,13 @@ async def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> CompletionCreateResponse | AsyncStream[ChatCompletionChunk]:
-        # This method requires an agent_key to be set via client argument or environment variable
-        if not self._client.agent_key:
+        # This method requires an agent_access_key to be set via client argument or environment variable
+        if not self._client.agent_access_key:
             raise TypeError(
-                "Could not resolve authentication method. Expected agent_key to be set for chat completions."
+                "Could not resolve authentication method. Expected agent_access_key to be set for chat completions."
             )
         headers = extra_headers or {}
-        headers = {"Authorization": f"Bearer {self._client.agent_key}", **headers}
+        headers = {"Authorization": f"Bearer {self._client.agent_access_key}", **headers}
 
         return await self._post(
             (
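
Callers of this resource therefore now need agent_access_key (or GRADIENT_AGENT_KEY in the environment) rather than the deprecated agent_key. A hedged sketch; the attribute chain agents.chat.completions and the parameter values are inferred from the file path and signatures, not confirmed by this diff:

    from gradient import Gradient  # assumed import path

    client = Gradient(agent_access_key="agent-token")
    response = client.agents.chat.completions.create(
        model="example-model",  # illustrative values
        messages=[{"role": "user", "content": "Hello"}],
    )
    # Without agent_access_key the method raises TypeError:
    # "Could not resolve authentication method. Expected agent_access_key
    #  to be set for chat completions."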

src/gradient/resources/chat/completions.py

Lines changed: 69 additions & 30 deletions
@@ -62,7 +62,9 @@ def create(
         presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
         stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
-        stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN,
+        stream_options: (
+            Optional[completion_create_params.StreamOptions] | NotGiven
+        ) = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
         tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
         tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -191,7 +193,9 @@ def create(
         n: Optional[int] | NotGiven = NOT_GIVEN,
         presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
-        stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN,
+        stream_options: (
+            Optional[completion_create_params.StreamOptions] | NotGiven
+        ) = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
         tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
         tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -319,7 +323,9 @@ def create(
         n: Optional[int] | NotGiven = NOT_GIVEN,
         presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
-        stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN,
+        stream_options: (
+            Optional[completion_create_params.StreamOptions] | NotGiven
+        ) = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
         tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
         tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -447,7 +453,9 @@ def create(
         presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
         stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
-        stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN,
+        stream_options: (
+            Optional[completion_create_params.StreamOptions] | NotGiven
+        ) = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
         tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
         tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -461,18 +469,23 @@ def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> CompletionCreateResponse | Stream[ChatCompletionChunk]:
-        # This method requires an inference_key to be set via client argument or environment variable
-        if not self._client.inference_key:
+        # This method requires an model_access_key to be set via client argument or environment variable
+        if not self._client.model_access_key:
             raise TypeError(
-                "Could not resolve authentication method. Expected inference_key to be set for chat completions."
+                "Could not resolve authentication method. Expected model_access_key to be set for chat completions."
             )
         headers = extra_headers or {}
-        headers = {"Authorization": f"Bearer {self._client.inference_key}", **headers}
+        headers = {
+            "Authorization": f"Bearer {self._client.model_access_key}",
+            **headers,
+        }
 
         return self._post(
-            "/chat/completions"
-            if self._client._base_url_overridden
-            else "https://inference.do-ai.run/v1/chat/completions",
+            (
+                "/chat/completions"
+                if self._client._base_url_overridden
+                else f"{self._client.inference_endpoint}/v1/chat/completions"
+            ),
             body=maybe_transform(
                 {
                     "messages": messages,
@@ -495,12 +508,17 @@ def create(
                     "top_p": top_p,
                     "user": user,
                 },
-                completion_create_params.CompletionCreateParamsStreaming
-                if stream
-                else completion_create_params.CompletionCreateParamsNonStreaming,
+                (
+                    completion_create_params.CompletionCreateParamsStreaming
+                    if stream
+                    else completion_create_params.CompletionCreateParamsNonStreaming
+                ),
             ),
             options=make_request_options(
-                extra_headers=headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+                extra_headers=headers,
+                extra_query=extra_query,
+                extra_body=extra_body,
+                timeout=timeout,
             ),
             cast_to=CompletionCreateResponse,
             stream=stream or False,
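
With the hard-coded host removed, the sync request URL now follows the client's inference_endpoint unless base_url overrides it. A sketch under the same import assumption as above:

    from gradient import Gradient  # assumed import path

    client = Gradient(
        model_access_key="model-token",
        inference_endpoint="https://inference.example.test",  # or GRADIENT_INFERENCE_ENDPOINT
    )
    # Request goes to https://inference.example.test/v1/chat/completions
    completion = client.chat.completions.create(
        model="example-model",  # illustrative values
        messages=[{"role": "user", "content": "Ping"}],
    )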
@@ -544,7 +562,9 @@ async def create(
         presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
         stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
-        stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN,
+        stream_options: (
+            Optional[completion_create_params.StreamOptions] | NotGiven
+        ) = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
         tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
         tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -673,7 +693,9 @@ async def create(
         n: Optional[int] | NotGiven = NOT_GIVEN,
         presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
-        stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN,
+        stream_options: (
+            Optional[completion_create_params.StreamOptions] | NotGiven
+        ) = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
         tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
         tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -801,7 +823,9 @@ async def create(
         n: Optional[int] | NotGiven = NOT_GIVEN,
         presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
-        stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN,
+        stream_options: (
+            Optional[completion_create_params.StreamOptions] | NotGiven
+        ) = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
         tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
         tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -929,7 +953,9 @@ async def create(
         presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
         stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
-        stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN,
+        stream_options: (
+            Optional[completion_create_params.StreamOptions] | NotGiven
+        ) = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
         tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
         tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -943,18 +969,26 @@ async def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> CompletionCreateResponse | AsyncStream[ChatCompletionChunk]:
-        # This method requires an inference_key to be set via client argument or environment variable
-        if not hasattr(self._client, "inference_key") or not self._client.inference_key:
+        # This method requires an model_access_key to be set via client argument or environment variable
+        if (
+            not hasattr(self._client, "model_access_key")
+            or not self._client.model_access_key
+        ):
             raise TypeError(
-                "Could not resolve authentication method. Expected inference_key to be set for chat completions."
+                "Could not resolve authentication method. Expected model_access_key to be set for chat completions."
             )
         headers = extra_headers or {}
-        headers = {"Authorization": f"Bearer {self._client.inference_key}", **headers}
+        headers = {
+            "Authorization": f"Bearer {self._client.model_access_key}",
+            **headers,
+        }
 
         return await self._post(
-            "/chat/completions"
-            if self._client._base_url_overridden
-            else "https://inference.do-ai.run/v1/chat/completions",
+            (
+                "/chat/completions"
+                if self._client._base_url_overridden
+                else f"{self._client.inference_endpoint}/chat/completions"
+            ),
             body=await async_maybe_transform(
                 {
                     "messages": messages,
@@ -977,12 +1011,17 @@ async def create(
                     "top_p": top_p,
                     "user": user,
                 },
-                completion_create_params.CompletionCreateParamsStreaming
-                if stream
-                else completion_create_params.CompletionCreateParamsNonStreaming,
+                (
+                    completion_create_params.CompletionCreateParamsStreaming
+                    if stream
+                    else completion_create_params.CompletionCreateParamsNonStreaming
+                ),
             ),
             options=make_request_options(
-                extra_headers=headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+                extra_headers=headers,
+                extra_query=extra_query,
+                extra_body=extra_body,
+                timeout=timeout,
             ),
             cast_to=CompletionCreateResponse,
             stream=stream or False,
