Skip to content

Commit b74952e

Browse files
feat(api): make kwargs match the env vars
1 parent 6853d05 commit b74952e

File tree

5 files changed

+343
-287
lines changed

5 files changed

+343
-287
lines changed

.stats.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
configured_endpoints: 170
22
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/digitalocean%2Fgradient-9aca3802735e1375125412aa28ac36bf2175144b8218610a73d2e7f775694dff.yml
33
openapi_spec_hash: e29d14e3e4679fcf22b3e760e49931b1
4-
config_hash: 136e1973eb6297e6308a165594bd00a3
4+
config_hash: 99e3cd5dde0beb796f4547410869f726

README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ import os
4242
from gradient import Gradient
4343

4444
client = Gradient(
45-
api_key=os.environ.get("DIGITALOCEAN_ACCESS_TOKEN"), # This is the default and can be omitted
45+
access_token=os.environ.get(
46+
"DIGITALOCEAN_ACCESS_TOKEN"
47+
), # This is the default and can be omitted
4648
)
4749
inference_client = Gradient(
4850
inference_key=os.environ.get(
@@ -90,7 +92,7 @@ print("--- Agent Inference")
9092
print(agent_response.choices[0].message.content)
9193
```
9294

93-
While you can provide an `api_key`, `inference_key` keyword argument,
95+
While you can provide an `access_token`, `model_access_key` keyword argument,
9496
we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/)
9597
to add `DIGITALOCEAN_ACCESS_TOKEN="My API Key"`, `GRADIENT_MODEL_ACCESS_KEY="My INFERENCE Key"` to your `.env` file
9698
so that your keys are not stored in source control.
@@ -105,7 +107,9 @@ import asyncio
105107
from gradient import AsyncGradient
106108

107109
client = AsyncGradient(
108-
api_key=os.environ.get("DIGITALOCEAN_ACCESS_TOKEN"), # This is the default and can be omitted
110+
access_token=os.environ.get(
111+
"DIGITALOCEAN_ACCESS_TOKEN"
112+
), # This is the default and can be omitted
109113
)
110114

111115

@@ -148,7 +152,7 @@ from gradient import AsyncGradient
148152

149153
async def main() -> None:
150154
async with AsyncGradient(
151-
api_key="My API Key",
155+
access_token="My Access Token",
152156
http_client=DefaultAioHttpClient(),
153157
) as client:
154158
completion = await client.chat.completions.create(

src/gradient/_client.py

Lines changed: 138 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,23 @@
7171

7272
class Gradient(SyncAPIClient):
7373
# client options
74-
api_key: str | None
75-
inference_key: str | None
76-
agent_key: str | None
74+
access_token: str | None
75+
model_access_key: str | None
76+
agent_access_key: str | None
7777
_agent_endpoint: str | None
78+
inference_endpoint: str | None
7879

7980
def __init__(
8081
self,
8182
*,
82-
api_key: str | None = None,
83-
inference_key: str | None = None,
84-
agent_key: str | None = None,
83+
api_key: str | None = None, # deprecated, use `access_token` instead
84+
inference_key: str | None = None, # deprecated, use `model_access_key` instead
85+
agent_key: str | None = None, # deprecated, use `agent_access_key` instead
86+
access_token: str | None = None,
87+
model_access_key: str | None = None,
88+
agent_access_key: str | None = None,
8589
agent_endpoint: str | None = None,
90+
inference_endpoint: str | None = None,
8691
base_url: str | httpx.URL | None = None,
8792
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
8893
max_retries: int = DEFAULT_MAX_RETRIES,
@@ -105,34 +110,45 @@ def __init__(
105110
"""Construct a new synchronous Gradient client instance.
106111
107112
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
108-
- `api_key` from `DIGITALOCEAN_ACCESS_TOKEN`
109-
- `inference_key` from `GRADIENT_MODEL_ACCESS_KEY`
110-
- `agent_key` from `GRADIENT_AGENT_ACCESS_KEY`
113+
- `access_token` from `DIGITALOCEAN_ACCESS_TOKEN`
114+
- `model_access_key` from `GRADIENT_MODEL_ACCESS_KEY`
115+
- `agent_access_key` from `GRADIENT_AGENT_ACCESS_KEY`
111116
"""
112-
if api_key is None:
113-
api_key = os.environ.get("DIGITALOCEAN_ACCESS_TOKEN")
114-
# support for legacy environment variable
115-
if api_key is None:
116-
api_key = os.environ.get("GRADIENT_API_KEY")
117-
self.api_key = api_key
118-
119-
if inference_key is None:
120-
inference_key = os.environ.get("GRADIENT_MODEL_ACCESS_KEY")
121-
# support for legacy environment variable
122-
if inference_key is None:
123-
inference_key = os.environ.get("GRADIENT_INFERENCE_KEY")
124-
125-
self.inference_key = inference_key
126-
127-
if agent_key is None:
128-
agent_key = os.environ.get("GRADIENT_AGENT_ACCESS_KEY")
129-
# support for legacy environment variable
130-
if agent_key is None:
131-
agent_key = os.environ.get("GRADIENT_AGENT_KEY")
132-
self.agent_key = agent_key
117+
if access_token is None:
118+
if api_key is not None:
119+
access_token = api_key
120+
else:
121+
access_token = os.environ.get("DIGITALOCEAN_ACCESS_TOKEN")
122+
# support for legacy environment variable
123+
if access_token is None:
124+
access_token = os.environ.get("GRADIENT_API_KEY")
125+
self.access_token = access_token
126+
127+
128+
if model_access_key is None:
129+
if inference_key is not None:
130+
model_access_key = inference_key
131+
else:
132+
model_access_key = os.environ.get("GRADIENT_INFERENCE_KEY")
133+
# support for legacy environment variable
134+
if model_access_key is None:
135+
model_access_key = os.environ.get("GRADIENT_MODEL_ACCESS_KEY")
136+
self.model_access_key = model_access_key
137+
138+
if agent_access_key is None:
139+
if agent_key is not None:
140+
agent_access_key = agent_key
141+
else:
142+
agent_access_key = os.environ.get("GRADIENT_AGENT_ACCESS_KEY")
143+
# support for legacy environment variable
144+
if agent_access_key is None:
145+
agent_access_key = os.environ.get("GRADIENT_AGENT_KEY")
146+
self.agent_access_key = agent_access_key
133147

134148
self._agent_endpoint = agent_endpoint
135149

150+
self.inference_endpoint = inference_endpoint
151+
136152
if base_url is None:
137153
base_url = os.environ.get("GRADIENT_BASE_URL")
138154
self._base_url_overridden = base_url is not None
@@ -229,10 +245,10 @@ def qs(self) -> Querystring:
229245
@property
230246
@override
231247
def auth_headers(self) -> dict[str, str]:
232-
api_key = self.api_key
233-
if api_key is None:
248+
access_token = self.access_token
249+
if access_token is None:
234250
return {}
235-
return {"Authorization": f"Bearer {api_key}"}
251+
return {"Authorization": f"Bearer {access_token}"}
236252

237253
@property
238254
@override
@@ -245,24 +261,28 @@ def default_headers(self) -> dict[str, str | Omit]:
245261

246262
@override
247263
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
248-
if (self.api_key or self.agent_key or self.inference_key) and headers.get(
249-
"Authorization"
250-
):
264+
if (
265+
self.access_token or self.agent_access_key or self.model_access_key
266+
) and headers.get("Authorization"):
251267
return
252268
if isinstance(custom_headers.get("Authorization"), Omit):
253269
return
254270

255271
raise TypeError(
256-
'"Could not resolve authentication method. Expected api_key, agent_key, or inference_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
272+
'"Could not resolve authentication method. Expected access_token, agent_access_key, or model_access_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
257273
)
258274

259275
def copy(
260276
self,
261277
*,
262-
api_key: str | None = None,
263-
inference_key: str | None = None,
264-
agent_key: str | None = None,
278+
api_key: str | None = None, # deprecated, use `access_token` instead
279+
inference_key: str | None = None, # deprecated, use `model_access_key` instead
280+
agent_key: str | None = None, # deprecated, use `agent_access_key` instead
281+
access_token: str | None = None,
282+
model_access_key: str | None = None,
283+
agent_access_key: str | None = None,
265284
agent_endpoint: str | None = None,
285+
inference_endpoint: str | None = None,
266286
base_url: str | httpx.URL | None = None,
267287
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
268288
http_client: httpx.Client | None = None,
@@ -300,10 +320,11 @@ def copy(
300320

301321
http_client = http_client or self._client
302322
client = self.__class__(
303-
api_key=api_key or self.api_key,
304-
inference_key=inference_key or self.inference_key,
305-
agent_key=agent_key or self.agent_key,
306-
agent_endpoint=agent_endpoint or self._agent_endpoint,
323+
access_token=access_token or api_key or self.access_token,
324+
model_access_key=model_access_key or inference_key or self.model_access_key,
325+
agent_access_key=agent_access_key or agent_key or self.agent_access_key,
326+
agent_endpoint=agent_endpoint or self.agent_endpoint,
327+
inference_endpoint=inference_endpoint or self.inference_endpoint,
307328
base_url=base_url or self.base_url,
308329
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
309330
http_client=http_client,
@@ -363,18 +384,23 @@ def _make_status_error(
363384

364385
class AsyncGradient(AsyncAPIClient):
365386
# client options
366-
api_key: str | None
367-
inference_key: str | None
368-
agent_key: str | None
387+
access_token: str | None
388+
model_access_key: str | None
389+
agent_access_key: str | None
369390
_agent_endpoint: str | None
391+
inference_endpoint: str | None
370392

371393
def __init__(
372394
self,
373395
*,
374-
api_key: str | None = None,
375-
inference_key: str | None = None,
376-
agent_key: str | None = None,
396+
api_key: str | None = None, # deprecated, use `access_token` instead
397+
inference_key: str | None = None, # deprecated, use `model_access_key` instead
398+
agent_key: str | None = None, # deprecated, use `agent_access_key` instead
399+
access_token: str | None = None,
400+
model_access_key: str | None = None,
401+
agent_access_key: str | None = None,
377402
agent_endpoint: str | None = None,
403+
inference_endpoint: str | None = None,
378404
base_url: str | httpx.URL | None = None,
379405
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
380406
max_retries: int = DEFAULT_MAX_RETRIES,
@@ -397,34 +423,45 @@ def __init__(
397423
"""Construct a new async AsyncGradient client instance.
398424
399425
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
400-
- `api_key` from `DIGITALOCEAN_ACCESS_TOKEN`
401-
- `inference_key` from `GRADIENT_MODEL_ACCESS_KEY`
402-
- `agent_key` from `GRADIENT_AGENT_ACCESS_KEY`
426+
- `access_token` from `DIGITALOCEAN_ACCESS_TOKEN`
427+
- `model_access_key` from `GRADIENT_MODEL_ACCESS_KEY`
428+
- `agent_access_key` from `GRADIENT_AGENT_ACCESS_KEY`
403429
"""
404-
if api_key is None:
405-
api_key = os.environ.get("DIGITALOCEAN_ACCESS_TOKEN")
406-
# support for legacy environment variable
407-
if api_key is None:
408-
api_key = os.environ.get("GRADIENT_API_KEY")
409-
self.api_key = api_key
410-
411-
if inference_key is None:
412-
inference_key = os.environ.get("GRADIENT_MODEL_ACCESS_KEY")
413-
# support for legacy environment variable
414-
if inference_key is None:
415-
inference_key = os.environ.get("GRADIENT_INFERENCE_KEY")
416-
self.api_key = api_key
417-
self.inference_key = inference_key
418-
419-
if agent_key is None:
420-
agent_key = os.environ.get("GRADIENT_AGENT_ACCESS_KEY")
421-
# support for legacy environment variable
422-
if agent_key is None:
423-
agent_key = os.environ.get("GRADIENT_AGENT_KEY")
424-
self.agent_key = agent_key
430+
if access_token is None:
431+
if api_key is not None:
432+
access_token = api_key
433+
else:
434+
access_token = os.environ.get("DIGITALOCEAN_ACCESS_TOKEN")
435+
# support for legacy environment variable
436+
if access_token is None:
437+
access_token = os.environ.get("GRADIENT_API_KEY")
438+
self.access_token = access_token
439+
440+
441+
if model_access_key is None:
442+
if inference_key is not None:
443+
model_access_key = inference_key
444+
else:
445+
model_access_key = os.environ.get("GRADIENT_INFERENCE_KEY")
446+
# support for legacy environment variable
447+
if model_access_key is None:
448+
model_access_key = os.environ.get("GRADIENT_MODEL_ACCESS_KEY")
449+
self.model_access_key = model_access_key
450+
451+
if agent_access_key is None:
452+
if agent_key is not None:
453+
agent_access_key = agent_key
454+
else:
455+
agent_access_key = os.environ.get("GRADIENT_AGENT_ACCESS_KEY")
456+
# support for legacy environment variable
457+
if agent_access_key is None:
458+
agent_access_key = os.environ.get("GRADIENT_AGENT_KEY")
459+
self.agent_access_key = agent_access_key
425460

426461
self._agent_endpoint = agent_endpoint
427462

463+
self.inference_endpoint = inference_endpoint
464+
428465
if base_url is None:
429466
base_url = os.environ.get("GRADIENT_BASE_URL")
430467
self._base_url_overridden = base_url is not None
@@ -521,10 +558,10 @@ def qs(self) -> Querystring:
521558
@property
522559
@override
523560
def auth_headers(self) -> dict[str, str]:
524-
api_key = self.api_key
525-
if api_key is None:
561+
access_token = self.access_token
562+
if access_token is None:
526563
return {}
527-
return {"Authorization": f"Bearer {api_key}"}
564+
return {"Authorization": f"Bearer {access_token}"}
528565

529566
@property
530567
@override
@@ -537,24 +574,28 @@ def default_headers(self) -> dict[str, str | Omit]:
537574

538575
@override
539576
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
540-
if (self.api_key or self.agent_key or self.inference_key) and headers.get(
541-
"Authorization"
542-
):
577+
if (
578+
self.access_token or self.agent_access_key or self.model_access_key
579+
) and headers.get("Authorization"):
543580
return
544581
if isinstance(custom_headers.get("Authorization"), Omit):
545582
return
546583

547584
raise TypeError(
548-
'"Could not resolve authentication method. Expected api_key, agent_key, or inference_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
585+
'"Could not resolve authentication method. Expected access_token, agent_access_key, or model_access_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
549586
)
550587

551588
def copy(
552589
self,
553590
*,
554-
api_key: str | None = None,
555-
inference_key: str | None = None,
556-
agent_key: str | None = None,
591+
api_key: str | None = None, # deprecated, use `access_token` instead
592+
inference_key: str | None = None, # deprecated, use `model_access_key` instead
593+
agent_key: str | None = None, # deprecated, use `agent_access_key` instead
557594
agent_endpoint: str | None = None,
595+
access_token: str | None = None,
596+
model_access_key: str | None = None,
597+
agent_access_key: str | None = None,
598+
inference_endpoint: str | None = None,
558599
base_url: str | httpx.URL | None = None,
559600
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
560601
http_client: httpx.AsyncClient | None = None,
@@ -592,10 +633,23 @@ def copy(
592633

593634
http_client = http_client or self._client
594635
client = self.__class__(
636+
<<<<<<< HEAD
595637
api_key=api_key or self.api_key,
596638
inference_key=inference_key or self.inference_key,
597639
agent_key=agent_key or self.agent_key,
598640
agent_endpoint=agent_endpoint or self._agent_endpoint,
641+
||||||| eb1dcf7
642+
api_key=api_key or self.api_key,
643+
inference_key=inference_key or self.inference_key,
644+
agent_key=agent_key or self.agent_key,
645+
agent_domain=agent_domain or self.agent_domain,
646+
=======
647+
access_token=access_token or self.access_token,
648+
model_access_key=model_access_key or self.model_access_key,
649+
agent_access_key=agent_access_key or self.agent_access_key,
650+
agent_endpoint=agent_endpoint or self.agent_endpoint,
651+
inference_endpoint=inference_endpoint or self.inference_endpoint,
652+
>>>>>>> origin/generated--merge-conflict
599653
base_url=base_url or self.base_url,
600654
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
601655
http_client=http_client,

0 commit comments

Comments
 (0)