Skip to content

Commit 6853d05

Browse files
fix(config): align environment variables with other DO tools and console (#40) (#41)
Co-authored-by: Ben Batha <[email protected]>
1 parent ed70ab7 commit 6853d05

File tree

3 files changed

+350
-105
lines changed

3 files changed

+350
-105
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ client = Gradient(
4646
)
4747
inference_client = Gradient(
4848
inference_key=os.environ.get(
49-
"GRADIENT_INFERENCE_KEY"
49+
"GRADIENT_MODEL_ACCESS_KEY"
5050
), # This is the default and can be omitted
5151
)
5252
agent_client = Gradient(
53-
agent_key=os.environ.get("GRADIENT_AGENT_KEY"), # This is the default and can be omitted
53+
agent_key=os.environ.get("GRADIENT_AGENT_ACCESS_KEY"), # This is the default and can be omitted
5454
agent_endpoint="https://my-agent.agents.do-ai.run",
5555
)
5656

@@ -92,7 +92,7 @@ print(agent_response.choices[0].message.content)
9292

9393
While you can provide an `api_key`, `inference_key` keyword argument,
9494
we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/)
95-
to add `GRADIENT_API_KEY="My API Key"`, `GRADIENT_INFERENCE_KEY="My INFERENCE Key"` to your `.env` file
95+
to add `DIGITALOCEAN_ACCESS_TOKEN="My API Key"`, `GRADIENT_MODEL_ACCESS_KEY="My INFERENCE Key"` to your `.env` file
9696
so that your keys are not stored in source control.
9797

9898
## Async usage

src/gradient/_client.py

Lines changed: 75 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,16 @@
3232
)
3333

3434
if TYPE_CHECKING:
35-
from .resources import chat, agents, models, regions, databases, inference, gpu_droplets, knowledge_bases
35+
from .resources import (
36+
chat,
37+
agents,
38+
models,
39+
regions,
40+
databases,
41+
inference,
42+
gpu_droplets,
43+
knowledge_bases,
44+
)
3645
from .resources.regions import RegionsResource, AsyncRegionsResource
3746
from .resources.chat.chat import ChatResource, AsyncChatResource
3847
from .resources.gpu_droplets import (
@@ -102,14 +111,24 @@ def __init__(
102111
"""
103112
if api_key is None:
104113
api_key = os.environ.get("DIGITALOCEAN_ACCESS_TOKEN")
114+
# support for legacy environment variable
115+
if api_key is None:
116+
api_key = os.environ.get("GRADIENT_API_KEY")
105117
self.api_key = api_key
106118

107119
if inference_key is None:
108120
inference_key = os.environ.get("GRADIENT_MODEL_ACCESS_KEY")
121+
# support for legacy environment variable
122+
if inference_key is None:
123+
inference_key = os.environ.get("GRADIENT_INFERENCE_KEY")
124+
109125
self.inference_key = inference_key
110126

111127
if agent_key is None:
112128
agent_key = os.environ.get("GRADIENT_AGENT_ACCESS_KEY")
129+
# support for legacy environment variable
130+
if agent_key is None:
131+
agent_key = os.environ.get("GRADIENT_AGENT_KEY")
113132
self.agent_key = agent_key
114133

115134
self._agent_endpoint = agent_endpoint
@@ -226,7 +245,9 @@ def default_headers(self) -> dict[str, str | Omit]:
226245

227246
@override
228247
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
229-
if (self.api_key or self.agent_key or self.inference_key) and headers.get("Authorization"):
248+
if (self.api_key or self.agent_key or self.inference_key) and headers.get(
249+
"Authorization"
250+
):
230251
return
231252
if isinstance(custom_headers.get("Authorization"), Omit):
232253
return
@@ -256,10 +277,14 @@ def copy(
256277
Create a new client instance re-using the same options given to the current client with optional overriding.
257278
"""
258279
if default_headers is not None and set_default_headers is not None:
259-
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
280+
raise ValueError(
281+
"The `default_headers` and `set_default_headers` arguments are mutually exclusive"
282+
)
260283

261284
if default_query is not None and set_default_query is not None:
262-
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
285+
raise ValueError(
286+
"The `default_query` and `set_default_query` arguments are mutually exclusive"
287+
)
263288

264289
headers = self._custom_headers
265290
if default_headers is not None:
@@ -306,10 +331,14 @@ def _make_status_error(
306331
return _exceptions.BadRequestError(err_msg, response=response, body=body)
307332

308333
if response.status_code == 401:
309-
return _exceptions.AuthenticationError(err_msg, response=response, body=body)
334+
return _exceptions.AuthenticationError(
335+
err_msg, response=response, body=body
336+
)
310337

311338
if response.status_code == 403:
312-
return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
339+
return _exceptions.PermissionDeniedError(
340+
err_msg, response=response, body=body
341+
)
313342

314343
if response.status_code == 404:
315344
return _exceptions.NotFoundError(err_msg, response=response, body=body)
@@ -318,13 +347,17 @@ def _make_status_error(
318347
return _exceptions.ConflictError(err_msg, response=response, body=body)
319348

320349
if response.status_code == 422:
321-
return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
350+
return _exceptions.UnprocessableEntityError(
351+
err_msg, response=response, body=body
352+
)
322353

323354
if response.status_code == 429:
324355
return _exceptions.RateLimitError(err_msg, response=response, body=body)
325356

326357
if response.status_code >= 500:
327-
return _exceptions.InternalServerError(err_msg, response=response, body=body)
358+
return _exceptions.InternalServerError(
359+
err_msg, response=response, body=body
360+
)
328361
return APIStatusError(err_msg, response=response, body=body)
329362

330363

@@ -370,14 +403,24 @@ def __init__(
370403
"""
371404
if api_key is None:
372405
api_key = os.environ.get("DIGITALOCEAN_ACCESS_TOKEN")
406+
# support for legacy environment variable
407+
if api_key is None:
408+
api_key = os.environ.get("GRADIENT_API_KEY")
373409
self.api_key = api_key
374410

375411
if inference_key is None:
376412
inference_key = os.environ.get("GRADIENT_MODEL_ACCESS_KEY")
413+
# support for legacy environment variable
414+
if inference_key is None:
415+
inference_key = os.environ.get("GRADIENT_INFERENCE_KEY")
416+
self.api_key = api_key
377417
self.inference_key = inference_key
378418

379419
if agent_key is None:
380420
agent_key = os.environ.get("GRADIENT_AGENT_ACCESS_KEY")
421+
# support for legacy environment variable
422+
if agent_key is None:
423+
agent_key = os.environ.get("GRADIENT_AGENT_KEY")
381424
self.agent_key = agent_key
382425

383426
self._agent_endpoint = agent_endpoint
@@ -494,7 +537,9 @@ def default_headers(self) -> dict[str, str | Omit]:
494537

495538
@override
496539
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
497-
if (self.api_key or self.agent_key or self.inference_key) and headers.get("Authorization"):
540+
if (self.api_key or self.agent_key or self.inference_key) and headers.get(
541+
"Authorization"
542+
):
498543
return
499544
if isinstance(custom_headers.get("Authorization"), Omit):
500545
return
@@ -524,10 +569,14 @@ def copy(
524569
Create a new client instance re-using the same options given to the current client with optional overriding.
525570
"""
526571
if default_headers is not None and set_default_headers is not None:
527-
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
572+
raise ValueError(
573+
"The `default_headers` and `set_default_headers` arguments are mutually exclusive"
574+
)
528575

529576
if default_query is not None and set_default_query is not None:
530-
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
577+
raise ValueError(
578+
"The `default_query` and `set_default_query` arguments are mutually exclusive"
579+
)
531580

532581
headers = self._custom_headers
533582
if default_headers is not None:
@@ -574,10 +623,14 @@ def _make_status_error(
574623
return _exceptions.BadRequestError(err_msg, response=response, body=body)
575624

576625
if response.status_code == 401:
577-
return _exceptions.AuthenticationError(err_msg, response=response, body=body)
626+
return _exceptions.AuthenticationError(
627+
err_msg, response=response, body=body
628+
)
578629

579630
if response.status_code == 403:
580-
return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
631+
return _exceptions.PermissionDeniedError(
632+
err_msg, response=response, body=body
633+
)
581634

582635
if response.status_code == 404:
583636
return _exceptions.NotFoundError(err_msg, response=response, body=body)
@@ -586,13 +639,17 @@ def _make_status_error(
586639
return _exceptions.ConflictError(err_msg, response=response, body=body)
587640

588641
if response.status_code == 422:
589-
return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
642+
return _exceptions.UnprocessableEntityError(
643+
err_msg, response=response, body=body
644+
)
590645

591646
if response.status_code == 429:
592647
return _exceptions.RateLimitError(err_msg, response=response, body=body)
593648

594649
if response.status_code >= 500:
595-
return _exceptions.InternalServerError(err_msg, response=response, body=body)
650+
return _exceptions.InternalServerError(
651+
err_msg, response=response, body=body
652+
)
596653
return APIStatusError(err_msg, response=response, body=body)
597654

598655

@@ -811,7 +868,9 @@ def knowledge_bases(
811868
AsyncKnowledgeBasesResourceWithStreamingResponse,
812869
)
813870

814-
return AsyncKnowledgeBasesResourceWithStreamingResponse(self._client.knowledge_bases)
871+
return AsyncKnowledgeBasesResourceWithStreamingResponse(
872+
self._client.knowledge_bases
873+
)
815874

816875
@cached_property
817876
def models(self) -> models.AsyncModelsResourceWithStreamingResponse:

0 commit comments

Comments
 (0)