Skip to content

Update to include the generative service. #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/workflows/test_pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ jobs:
- name: Run tests
run: |
python --version
pip install -q -e .[dev]
python -m unittest discover --pattern '*test*.py'
pip install .[dev]
python -m unittest
test3_10:
name: Test Py3.10
runs-on: ubuntu-latest
Expand All @@ -36,8 +36,8 @@ jobs:
- name: Run tests
run: |
python --version
pip install -q -e .[dev]
python -m unittest discover --pattern '*test*.py'
pip install -q .[dev]
python -m unittest
test3_9:
name: Test Py3.9
runs-on: ubuntu-latest
Expand All @@ -49,8 +49,8 @@ jobs:
- name: Run tests
run: |
python --version
pip install -q -e .[dev]
python -m unittest discover --pattern '*test*.py'
pip install .[dev]
python -m unittest
pytype3_10:
name: pytype 3.10
runs-on: ubuntu-latest
Expand All @@ -62,7 +62,7 @@ jobs:
- name: Run pytype
run: |
python --version
pip install -q -e .[dev]
pip install .[dev]
pip install -q gspread ipython
pytype
format:
Expand All @@ -76,7 +76,7 @@ jobs:
- name: Check format
run: |
python --version
pip install -q -e .
pip install -q .
pip install -q black
black . --check

16 changes: 12 additions & 4 deletions google/generativeai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@
Use the `palm.chat` function to have a discussion with a model:

```
response = palm.chat(messages=["Hello."])
print(response.last) # 'Hello! What can I help you with?'
response.reply("Can you tell me a joke?")
chat = palm.chat(messages=["Hello."])
print(chat.last) # 'Hello! What can I help you with?'
chat = chat.reply("Can you tell me a joke?")
print(chat.last) # 'Why did the chicken cross the road?'
```

## Models
Expand All @@ -68,13 +69,20 @@
"""
from __future__ import annotations

from google.generativeai import types
from google.generativeai import version

from google.generativeai import types
from google.generativeai.types import GenerationConfig


from google.generativeai.discuss import chat
from google.generativeai.discuss import chat_async
from google.generativeai.discuss import count_message_tokens

from google.generativeai.embedding import embed_content

from google.generativeai.generative_models import GenerativeModel

from google.generativeai.text import generate_text
from google.generativeai.text import generate_embeddings
from google.generativeai.text import count_text_tokens
Expand Down
112 changes: 54 additions & 58 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
Expand All @@ -27,7 +13,12 @@
from google.api_core import gapic_v1
from google.api_core import operations_v1

from google.generativeai import version
try:
from google.generativeai import version

__version__ = version.__version__
except ImportError:
__version__ = "0.0.0"

USER_AGENT = "genai-py"

Expand All @@ -36,11 +27,10 @@
class _ClientManager:
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
default_metadata: Sequence[tuple[str, str]] = ()

discuss_client: glm.DiscussServiceClient | None = None
discuss_async_client: glm.DiscussServiceAsyncClient | None = None
model_client: glm.ModelServiceClient | None = None
text_client: glm.TextServiceClient | None = None
operations_client = None
clients: dict[str, Any] = dataclasses.field(default_factory=dict)

def configure(
self,
Expand All @@ -54,7 +44,7 @@ def configure(
# We could accept a dict since all the `Transport` classes take the same args,
# but that seems rare. Users that need it can just switch to the low level API.
transport: str | None = None,
client_options: client_options_lib.ClientOptions | dict | None = None,
client_options: client_options_lib.ClientOptions | dict[str, Any] | None = None,
client_info: gapic_v1.client_info.ClientInfo | None = None,
default_metadata: Sequence[tuple[str, str]] = (),
) -> None:
Expand Down Expand Up @@ -93,7 +83,7 @@ def configure(

client_options.api_key = api_key

user_agent = f"{USER_AGENT}/{version.__version__}"
user_agent = f"{USER_AGENT}/{__version__}"
if client_info:
# Be respectful of any existing agent setting.
if client_info.user_agent:
Expand All @@ -114,12 +104,16 @@ def configure(

self.client_config = client_config
self.default_metadata = default_metadata
self.discuss_client = None
self.text_client = None
self.model_client = None
self.operations_client = None

def make_client(self, cls):
self.clients = {}

def make_client(self, name):
if name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")

# Attempt to configure using defaults.
if not self.client_config:
configure()
Expand Down Expand Up @@ -157,35 +151,25 @@ def call(*args, metadata=(), **kwargs):

return client

def get_default_discuss_client(self) -> glm.DiscussServiceClient:
if self.discuss_client is None:
self.discuss_client = self.make_client(glm.DiscussServiceClient)
return self.discuss_client

def get_default_text_client(self) -> glm.TextServiceClient:
if self.text_client is None:
self.text_client = self.make_client(glm.TextServiceClient)
return self.text_client

def get_default_discuss_async_client(self) -> glm.DiscussServiceAsyncClient:
if self.discuss_async_client is None:
self.discuss_async_client = self.make_client(glm.DiscussServiceAsyncClient)
return self.discuss_async_client
def get_default_client(self, name):
name = name.lower()
if name == "operations":
return self.get_default_operations_client()

def get_default_model_client(self) -> glm.ModelServiceClient:
if self.model_client is None:
self.model_client = self.make_client(glm.ModelServiceClient)
return self.model_client
client = self.clients.get(name)
if client is None:
client = self.make_client(name)
self.clients[name] = client
return client

def get_default_operations_client(self) -> operations_v1.OperationsClient:
if self.operations_client is None:
self.model_client = get_default_model_client()
self.operations_client = self.model_client._transport.operations_client

return self.operations_client

client = self.clients.get("operations", None)
if client is None:
model_client = self.get_default_client("Model")
client = model_client._transport.operations_client
self.clients["operations"] = client

_client_manager = _ClientManager()
return client


def configure(
Expand Down Expand Up @@ -230,21 +214,33 @@ def configure(
)


_client_manager = _ClientManager()
_client_manager.configure()


def get_default_discuss_client() -> glm.DiscussServiceClient:
return _client_manager.get_default_discuss_client()
return _client_manager.get_default_client("discuss")


def get_default_text_client() -> glm.TextServiceClient:
return _client_manager.get_default_text_client()
def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
return _client_manager.get_default_client("discuss_async")


def get_default_operations_client() -> operations_v1.OperationsClient:
return _client_manager.get_default_operations_client()
def get_default_generative_client() -> glm.GenerativeServiceClient:
return _client_manager.get_default_client("generative")


def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
return _client_manager.get_default_discuss_async_client()
def get_default_generative_async_client() -> glm.GenerativeServiceAsyncClient:
return _client_manager.get_default_client("generative_async")


def get_default_text_client() -> glm.TextServiceClient:
return _client_manager.get_default_client("text")


def get_default_operations_client() -> operations_v1.OperationsClient:
return _client_manager.get_default_client("operations")


def get_default_model_client() -> glm.ModelServiceAsyncClient:
return _client_manager.get_default_model_client()
return _client_manager.get_default_client("model")
20 changes: 5 additions & 15 deletions google/generativeai/discuss.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,6 @@ def _make_generate_message_request(
)


def set_doc(doc):
"""A decorator to set the docstring of a function."""

def inner(f):
f.__doc__ = doc
return f

return inner


DEFAULT_DISCUSS_MODEL = "models/chat-bison-001"


Expand Down Expand Up @@ -411,7 +401,7 @@ def chat(
return _generate_response(client=client, request=request)


@set_doc(chat.__doc__)
@string_utils.set_doc(chat.__doc__)
async def chat_async(
*,
model: model_types.AnyModelNameOptions | None = "models/chat-bison-001",
Expand Down Expand Up @@ -447,7 +437,7 @@ async def chat_async(


@string_utils.prettyprint
@set_doc(discuss_types.ChatResponse.__doc__)
@string_utils.set_doc(discuss_types.ChatResponse.__doc__)
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
class ChatResponse(discuss_types.ChatResponse):
_client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False)
Expand All @@ -457,7 +447,7 @@ def __init__(self, **kwargs):
setattr(self, key, value)

@property
@set_doc(discuss_types.ChatResponse.last.__doc__)
@string_utils.set_doc(discuss_types.ChatResponse.last.__doc__)
def last(self) -> str | None:
if self.messages[-1]:
return self.messages[-1]["content"]
Expand All @@ -470,7 +460,7 @@ def last(self, message: discuss_types.MessageOptions):
message = type(message).to_dict(message)
self.messages[-1] = message

@set_doc(discuss_types.ChatResponse.reply.__doc__)
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResponse:
if isinstance(self._client, glm.DiscussServiceAsyncClient):
raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
Expand All @@ -489,7 +479,7 @@ def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResp
request = _make_generate_message_request(**request)
return _generate_response(request=request, client=self._client)

@set_doc(discuss_types.ChatResponse.reply.__doc__)
@string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
async def reply_async(
self, message: discuss_types.MessageOptions
) -> discuss_types.ChatResponse:
Expand Down
Loading