Skip to content

Add rev14 parameters and fixes. #561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions google/generativeai/types/generation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,27 @@ class GenerationConfig:
Note: The default value varies by model, see the
`Model.top_k` attribute of the `Model` returned the
`genai.get_model` function.

seed:
Optional. Seed used in decoding. If not set, the request uses a randomly generated seed.
response_mime_type:
Optional. Output response mimetype of the generated candidate text.

Supported mimetype:
`text/plain`: (default) Text output.
`text/x-enum`: for use with a string-enum in `response_schema`
`application/json`: JSON response in the candidates.

response_schema:
Optional. Specifies the format of the JSON requested if response_mime_type is
`application/json`.
presence_penalty:
Optional.
frequency_penalty:
Optional.
response_logprobs:
Optional. If true, export the `logprobs` results in response.
logprobs:
Optional. Number of candidates of log probabilities to return at each step of decoding.
"""

candidate_count: int | None = None
Expand All @@ -163,8 +173,13 @@ class GenerationConfig:
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
seed: int | None = None
response_mime_type: str | None = None
response_schema: protos.Schema | Mapping[str, Any] | type | None = None
presence_penalty: float | None = None
frequency_penalty: float | None = None
response_logprobs: bool | None = None
logprobs: int | None = None


GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
Expand Down Expand Up @@ -306,6 +321,7 @@ def _join_code_execution_result(result_1, result_2):


def _join_candidates(candidates: Iterable[protos.Candidate]):
"""Joins stream chunks of a single candidate."""
candidates = tuple(candidates)

index = candidates[0].index # These should all be the same.
Expand All @@ -321,6 +337,7 @@ def _join_candidates(candidates: Iterable[protos.Candidate]):


def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]):
"""Joins stream chunks where each chunk is a list of candidate chunks."""
# Assuming that is a candidate ends, it is no longer returned in the list of
# candidates and that's why candidates have an index
candidates = collections.defaultdict(list)
Expand All @@ -344,10 +361,15 @@ def _join_prompt_feedbacks(

def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
chunks = tuple(chunks)
if "usage_metadata" in chunks[-1]:
usage_metadata = chunks[-1].usage_metadata
else:
usage_metadata = None

return protos.GenerateContentResponse(
candidates=_join_candidate_lists(c.candidates for c in chunks),
prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
usage_metadata=chunks[-1].usage_metadata,
usage_metadata=usage_metadata,
)


Expand Down Expand Up @@ -541,7 +563,8 @@ def __str__(self) -> str:
_result = _result.replace("\n", "\n ")

if self._error:
_error = f",\nerror=<{self._error.__class__.__name__}> {self._error}"

_error = f",\nerror={repr(self._error)}"
else:
_error = ""

Expand Down
5 changes: 4 additions & 1 deletion google/generativeai/types/model_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def idecode_time(parent: dict["str", Any], name: str):

def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel:
if isinstance(tuned_model, protos.TunedModel):
tuned_model = type(tuned_model).to_dict(tuned_model) # pytype: disable=attribute-error
tuned_model = type(tuned_model).to_dict(
tuned_model, including_default_value_fields=False
) # pytype: disable=attribute-error
tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None))

base_model = tuned_model.pop("base_model", None)
Expand Down Expand Up @@ -195,6 +197,7 @@ class TunedModel:
create_time: datetime.datetime | None = None
update_time: datetime.datetime | None = None
tuning_task: TuningTask | None = None
reader_project_numbers: list[int] | None = None

@property
def permissions(self) -> permission_types.Permissions:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_version():
release_status = "Development Status :: 5 - Production/Stable"

dependencies = [
"google-ai-generativelanguage==0.6.9",
"google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz",
"google-api-core",
"google-api-python-client",
"google-auth>=2.15.0", # 2.15 adds API key auth support
Expand Down
12 changes: 7 additions & 5 deletions tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from google.generativeai.types import file_types

import collections
import datetime
import os
from typing import Iterable, Union
from typing import Iterable, Sequence
import pathlib

import google
Expand All @@ -37,12 +38,13 @@ def __init__(self, test):

def create_file(
self,
path: Union[str, pathlib.Path, os.PathLike],
path: str | pathlib.Path | os.PathLike,
*,
mime_type: Union[str, None] = None,
name: Union[str, None] = None,
display_name: Union[str, None] = None,
mime_type: str | None = None,
name: str | None = None,
display_name: str | None = None,
resumable: bool = True,
metadata: Sequence[tuple[str, str]] = (),
) -> protos.File:
self.observed_requests.append(
dict(
Expand Down
53 changes: 37 additions & 16 deletions tests/test_generation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import string
import textwrap
from typing_extensions import TypedDict
Expand All @@ -22,6 +38,8 @@ class Person(TypedDict):


class UnitTests(parameterized.TestCase):
maxDiff = None

@parameterized.named_parameters(
[
"protos.GenerationConfig",
Expand Down Expand Up @@ -416,24 +434,16 @@ def test_join_prompt_feedbacks(self):
],
"role": "assistant",
},
"citation_metadata": {"citation_sources": []},
"index": 0,
"finish_reason": 0,
"safety_ratings": [],
"token_count": 0,
"grounding_attributions": [],
"citation_metadata": {},
},
{
"content": {
"parts": [{"text": "Tell me a story about a magic backpack"}],
"role": "assistant",
},
"index": 1,
"citation_metadata": {"citation_sources": []},
"finish_reason": 0,
"safety_ratings": [],
"token_count": 0,
"grounding_attributions": [],
"citation_metadata": {},
},
{
"content": {
Expand All @@ -458,17 +468,16 @@ def test_join_prompt_feedbacks(self):
},
]
},
"finish_reason": 0,
"safety_ratings": [],
"token_count": 0,
"grounding_attributions": [],
},
]

def test_join_candidates(self):
candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS]
result = generation_types._join_candidate_lists(candidate_lists)
self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result])
self.assertEqual(
self.MERGED_CANDIDATES,
[type(r).to_dict(r, including_default_value_fields=False) for r in result],
)

def test_join_chunks(self):
chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
Expand All @@ -480,6 +489,10 @@ def test_join_chunks(self):
],
)

chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(
prompt_token_count=5
)

result = generation_types._join_chunks(chunks)

expected = protos.GenerateContentResponse(
Expand All @@ -495,10 +508,18 @@ def test_join_chunks(self):
}
],
},
"usage_metadata": {"prompt_token_count": 5},
},
)

self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected))
expected = json.dumps(
type(expected).to_dict(expected, including_default_value_fields=False), indent=4
)
result = json.dumps(
type(result).to_dict(result, including_default_value_fields=False), indent=4
)

self.assertEqual(expected, result)

def test_generate_content_response_iterator_end_to_end(self):
chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
Expand Down
53 changes: 7 additions & 46 deletions tests/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,8 +935,7 @@ def test_repr_for_streaming_start_to_finish(self):
"citation_metadata": {}
}
],
"prompt_feedback": {},
"usage_metadata": {}
"prompt_feedback": {}
}),
)"""
)
Expand Down Expand Up @@ -964,8 +963,7 @@ def test_repr_for_streaming_start_to_finish(self):
"citation_metadata": {}
}
],
"prompt_feedback": {},
"usage_metadata": {}
"prompt_feedback": {}
}),
)"""
)
Expand Down Expand Up @@ -998,10 +996,10 @@ def test_repr_error_info_for_stream_prompt_feedback_blocked(self):
}
}),
),
error=<BlockedPromptException> prompt_feedback {
error=BlockedPromptException(prompt_feedback {
block_reason: SAFETY
}
"""
)"""
)
self.assertEqual(expected, result)

Expand Down Expand Up @@ -1056,11 +1054,10 @@ def no_throw():
"citation_metadata": {}
}
],
"prompt_feedback": {},
"usage_metadata": {}
"prompt_feedback": {}
}),
),
error=<ValueError> """
error=ValueError()"""
)
self.assertEqual(expected, result)

Expand Down Expand Up @@ -1095,43 +1092,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self):
response = chat.send_message("hello2", stream=True)

result = repr(response)
expected = textwrap.dedent(
"""\
response:
GenerateContentResponse(
done=True,
iterator=None,
result=protos.GenerateContentResponse({
"candidates": [
{
"content": {
"parts": [
{
"text": "abc"
}
]
},
"finish_reason": "SAFETY",
"index": 0,
"citation_metadata": {}
}
],
"prompt_feedback": {},
"usage_metadata": {}
}),
),
error=<StopCandidateException> content {
parts {
text: "abc"
}
}
finish_reason: SAFETY
index: 0
citation_metadata {
}
"""
)
self.assertEqual(expected, result)
self.assertIn("StopCandidateException", result)

def test_repr_for_multi_turn_chat(self):
# Multi turn chat
Expand Down
Loading