Skip to content

Add metadata handling. #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 182 additions & 100 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from __future__ import annotations

import os
from typing import cast, Optional, Union
import dataclasses
import types
from typing import Any, cast
from collections.abc import Sequence

import google.ai.generativelanguage as glm

Expand All @@ -26,15 +29,163 @@

from google.generativeai import version


USER_AGENT = "genai-py"

default_client_config = {}
default_discuss_client = None
default_discuss_async_client = None
default_model_client = None
default_text_client = None
default_operations_client = None

@dataclasses.dataclass
class _ClientManager:
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
default_metadata: Sequence[tuple[str, str]] = ()
discuss_client: glm.DiscussServiceClient | None = None
discuss_async_client: glm.DiscussServiceAsyncClient | None = None
model_client: glm.ModelServiceClient | None = None
text_client: glm.TextServiceClient | None = None
operations_client = None

def configure(
self,
*,
api_key: str | None = None,
credentials: ga_credentials.Credentials | dict | None = None,
# The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
# See `_transport_registry` in `DiscussServiceClientMeta`.
# Since the transport classes align with the client classes it wouldn't make
# sense to accept a `Transport` object here even though the client classes can.
# We could accept a dict since all the `Transport` classes take the same args,
# but that seems rare. Users that need it can just switch to the low level API.
transport: str | None = None,
client_options: client_options_lib.ClientOptions | dict | None = None,
client_info: gapic_v1.client_info.ClientInfo | None = None,
default_metadata: Sequence[tuple[str, str]] = (),
) -> None:
"""Captures default client configuration.

If no API key has been provided (either directly, or on `client_options`) and the
`GOOGLE_API_KEY` environment variable is set, it will be used as the API key.

Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
`google.ai.generativelanguage` for details on the other arguments.

Args:
transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
api_key: The API-Key to use when creating the default clients (each service uses
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
used.
default_metadata: Default (key, value) metadata pairs to send with every request.
when using `transport="rest"` these are sent as HTTP headers.
"""
if isinstance(client_options, dict):
client_options = client_options_lib.from_dict(client_options)
if client_options is None:
client_options = client_options_lib.ClientOptions()
client_options = cast(client_options_lib.ClientOptions, client_options)
had_api_key_value = getattr(client_options, "api_key", None)

if had_api_key_value:
if api_key is not None:
raise ValueError("You can't set both `api_key` and `client_options['api_key']`.")
else:
if api_key is None:
# If no key is provided explicitly, attempt to load one from the
# environment.
api_key = os.getenv("GOOGLE_API_KEY")

client_options.api_key = api_key

user_agent = f"{USER_AGENT}/{version.__version__}"
if client_info:
# Be respectful of any existing agent setting.
if client_info.user_agent:
client_info.user_agent += f" {user_agent}"
else:
client_info.user_agent = user_agent
else:
client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)

client_config = {
"credentials": credentials,
"transport": transport,
"client_options": client_options,
"client_info": client_info,
}

client_config = {key: value for key, value in client_config.items() if value is not None}

self.client_config = client_config
self.default_metadata = default_metadata
self.discuss_client = None
self.text_client = None
self.model_client = None
self.operations_client = None

def make_client(self, cls):
# Attempt to configure using defaults.
if not self.client_config:
configure()

client = cls(**self.client_config)

if not self.default_metadata:
return client

def keep(name, f):
if name.startswith("_"):
return False
elif not isinstance(f, types.FunctionType):
return False
elif isinstance(f, classmethod):
return False
elif isinstance(f, staticmethod):
return False
else:
return True

def add_default_metadata_wrapper(f):
def call(*args, metadata=(), **kwargs):
metadata = list(metadata) + list(self.default_metadata)
return f(*args, **kwargs, metadata=metadata)

return call

for name, value in cls.__dict__.items():
if not keep(name, value):
continue
f = getattr(client, name)
f = add_default_metadata_wrapper(f)
setattr(client, name, f)

return client

def get_default_discuss_client(self) -> glm.DiscussServiceClient:
if self.discuss_client is None:
self.discuss_client = self.make_client(glm.DiscussServiceClient)
return self.discuss_client

def get_default_text_client(self) -> glm.TextServiceClient:
if self.text_client is None:
self.text_client = self.make_client(glm.TextServiceClient)
return self.text_client

def get_default_discuss_async_client(self) -> glm.DiscussServiceAsyncClient:
if self.discuss_async_client is None:
self.discuss_async_client = self.make_client(glm.DiscussServiceAsyncClient)
return self.discuss_async_client

def get_default_model_client(self) -> glm.ModelServiceClient:
if self.model_client is None:
self.model_client = self.make_client(glm.ModelServiceClient)
return self.model_client

def get_default_operations_client(self) -> operations_v1.OperationsClient:
if self.operations_client is None:
self.model_client = get_default_model_client()
self.operations_client = self.model_client._transport.operations_client

return self.operations_client


_client_manager = _ClientManager()


def configure(
Expand All @@ -50,119 +201,50 @@ def configure(
transport: str | None = None,
client_options: client_options_lib.ClientOptions | dict | None = None,
client_info: gapic_v1.client_info.ClientInfo | None = None,
default_metadata: Sequence[tuple[str, str]] = (),
):
"""Captures default client configuration.

If no API key has been provided (either directly, or on `client_options`) and the
`GOOGLE_API_KEY` environment variable is set, it will be used as the API key.

Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in
`google.ai.generativelanguage` for details on the other arguments.

Args:
Refer to `glm.DiscussServiceClient`, and `glm.ModelsServiceClient` for details on additional arguments.
transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`].
api_key: The API-Key to use when creating the default clients (each service uses
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
used.
default_metadata: Default (key, value) metadata pairs to send with every request.
when using `transport="rest"` these are sent as HTTP headers.
"""
global default_client_config
global default_discuss_client
global default_model_client
global default_text_client
global default_operations_client

if isinstance(client_options, dict):
client_options = client_options_lib.from_dict(client_options)
if client_options is None:
client_options = client_options_lib.ClientOptions()
client_options = cast(client_options_lib.ClientOptions, client_options)
had_api_key_value = getattr(client_options, "api_key", None)

if had_api_key_value:
if api_key is not None:
raise ValueError("You can't set both `api_key` and `client_options['api_key']`.")
else:
if api_key is None:
# If no key is provided explicitly, attempt to load one from the
# environment.
api_key = os.getenv("GOOGLE_API_KEY")

client_options.api_key = api_key

user_agent = f"{USER_AGENT}/{version.__version__}"
if client_info:
# Be respectful of any existing agent setting.
if client_info.user_agent:
client_info.user_agent += f" {user_agent}"
else:
client_info.user_agent = user_agent
else:
client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)

new_default_client_config = {
"credentials": credentials,
"transport": transport,
"client_options": client_options,
"client_info": client_info,
}

new_default_client_config = {
key: value for key, value in new_default_client_config.items() if value is not None
}

default_client_config = new_default_client_config
default_discuss_client = None
default_text_client = None
default_model_client = None
default_operations_client = None
return _client_manager.configure(
api_key=api_key,
credentials=credentials,
transport=transport,
client_options=client_options,
client_info=client_info,
default_metadata=default_metadata,
)


def get_default_discuss_client() -> glm.DiscussServiceClient:
global default_discuss_client
if default_discuss_client is None:
# Attempt to configure using defaults.
if not default_client_config:
configure()
default_discuss_client = glm.DiscussServiceClient(**default_client_config)

return default_discuss_client
return _client_manager.get_default_discuss_client()


def get_default_text_client() -> glm.TextServiceClient:
global default_text_client
if default_text_client is None:
# Attempt to configure using defaults.
if not default_client_config:
configure()
default_text_client = glm.TextServiceClient(**default_client_config)

return default_text_client
return _client_manager.get_default_text_client()


def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
global default_discuss_async_client
if default_discuss_async_client is None:
# Attempt to configure using defaults.
if not default_client_config:
configure()
default_discuss_async_client = glm.DiscussServiceAsyncClient(**default_client_config)

return default_discuss_async_client


def get_default_model_client() -> glm.ModelServiceClient:
global default_model_client
if default_model_client is None:
# Attempt to configure using defaults.
if not default_client_config:
configure()
default_model_client = glm.ModelServiceClient(**default_client_config)
def get_default_operations_client() -> operations_v1.OperationsClient:
return _client_manager.get_default_operations_client()

return default_model_client

def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
return _client_manager.get_default_discuss_async_client()

def get_default_operations_client() -> operations_v1.OperationsClient:
global default_operations_client
if default_operations_client is None:
model_client = get_default_model_client()
default_operations_client = model_client._transport.operations_client

return default_operations_client
def get_default_model_client() -> glm.ModelServiceAsyncClient:
return _client_manager.get_default_model_client()
Loading