Skip to content

Add dataclass prettyprinting. #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions google/generativeai/discuss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from google.generativeai.client import get_default_discuss_client
from google.generativeai.client import get_default_discuss_async_client
from google.generativeai import string_utils
from google.generativeai.types import discuss_types
from google.generativeai.types import model_types
from google.generativeai.types import safety_types
Expand Down Expand Up @@ -445,6 +446,7 @@ async def chat_async(
DATACLASS_KWARGS = {}


@string_utils.prettyprint
@set_doc(discuss_types.ChatResponse.__doc__)
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
class ChatResponse(discuss_types.ChatResponse):
Expand Down
22 changes: 0 additions & 22 deletions google/generativeai/docstring_utils.py

This file was deleted.

1 change: 1 addition & 0 deletions google/generativeai/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Iterator

from google.ai import generativelanguage as glm

from google.generativeai import client as client_lib
from google.generativeai.types import model_types
from google.api_core import operation as operation_lib
Expand Down
74 changes: 74 additions & 0 deletions google/generativeai/string_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
import pprint
import re
import reprlib
import textwrap


def strip_oneof(docstring):
lines = docstring.splitlines()
lines = [line for line in lines if ".. _oneof:" not in line]
lines = [line for line in lines if "This field is a member of `oneof`_" not in line]
return "\n".join(lines)


def prettyprint(cls):
cls.__str__ = _prettyprint
cls.__repr__ = _prettyprint
return cls


repr = reprlib.Repr()


@reprlib.recursive_repr()
def _prettyprint(self):
"""A dataclass prettyprint function you can use in __str__or __repr__.

Note: You can't set `__str__ = pprint.pformat` because it causes a recursion error.

Mostly identical to pprint but:

* This will contract long lists and dicts (> 10lines) to [...] and {...}.
* This will contract long object reprs to ClassName(...).
"""
fields = []
for f in dataclasses.fields(self):
s = pprint.pformat(getattr(self, f.name))
class_re = r"^(\w+)\(.*\)$"
if s.count("\n") >= 10:
if s.startswith("["):
s = "[...]"
elif s.startswith("{"):
s = "{...}"
elif re.match(class_re, s, flags=re.DOTALL):
s = re.sub(class_re, r"\1(...)", s, flags=re.DOTALL)
else:
s = "..."
else:
width = len(f.name) + 1
s = textwrap.indent(s, " " * width).lstrip(" ")
fields.append(f"{f.name}={s}")
attrs = ",\n".join(fields)

name = self.__class__.__name__
width = len(name) + 1

attrs = textwrap.indent(attrs, " " * width).lstrip(" ")
return f"{name}({attrs})"
2 changes: 2 additions & 0 deletions google/generativeai/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import google.ai.generativelanguage as glm

from google.generativeai.client import get_default_text_client
from google.generativeai import string_utils
from google.generativeai.types import text_types
from google.generativeai.types import model_types
from google.generativeai.types import safety_types
Expand Down Expand Up @@ -175,6 +176,7 @@ def generate_text(
return _generate_response(client=client, request=request)


@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class Completion(text_types.Completion):
def __init__(self, **kwargs):
Expand Down
7 changes: 4 additions & 3 deletions google/generativeai/types/citation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from typing import Optional, List

from google.ai import generativelanguage as glm
from google.generativeai import docstring_utils
from google.generativeai import string_utils

from typing import TypedDict

__all__ = [
Expand All @@ -32,10 +33,10 @@ class CitationSourceDict(TypedDict):
uri: str | None
license: str | None

__doc__ = docstring_utils.strip_oneof(glm.CitationSource.__doc__)
__doc__ = string_utils.strip_oneof(glm.CitationSource.__doc__)


class CitationMetadataDict(TypedDict):
citation_sources: List[CitationSourceDict | None]

__doc__ = docstring_utils.strip_oneof(glm.CitationMetadata.__doc__)
__doc__ = string_utils.strip_oneof(glm.CitationMetadata.__doc__)
3 changes: 3 additions & 0 deletions google/generativeai/types/discuss_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from typing import Any, Dict, TypedDict, Union, Iterable, Optional, Tuple, List

import google.ai.generativelanguage as glm
from google.generativeai import string_utils

from google.generativeai.types import safety_types
from google.generativeai.types import citation_types

Expand Down Expand Up @@ -97,6 +99,7 @@ class ResponseDict(TypedDict):
candidates: List[MessageDict]


@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class ChatResponse(abc.ABC):
"""A chat response from the model.
Expand Down
7 changes: 7 additions & 0 deletions google/generativeai/types/model_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Any, Iterable, TypedDict, Union

import google.ai.generativelanguage as glm
from google.generativeai import string_utils

__all__ = [
"Model",
Expand Down Expand Up @@ -65,6 +66,7 @@ def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState:
return _TUNED_MODEL_STATES[x]


@string_utils.prettyprint
@dataclasses.dataclass
class Model:
"""A dataclass representation of a `glm.Model`.
Expand Down Expand Up @@ -152,6 +154,7 @@ def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedM
return TunedModel(**tuned_model)


@string_utils.prettyprint
@dataclasses.dataclass
class TunedModel:
"""A dataclass representation of a `glm.TunedModel`."""
Expand All @@ -170,6 +173,7 @@ class TunedModel:
tuning_task: TuningTask | None = None


@string_utils.prettyprint
@dataclasses.dataclass
class TuningTask:
start_time: datetime.datetime | None = None
Expand Down Expand Up @@ -208,6 +212,7 @@ def encode_tuning_example(example: TuningExampleOptions):
return example


@string_utils.prettyprint
@dataclasses.dataclass
class TuningSnapshot:
step: int
Expand All @@ -216,6 +221,7 @@ class TuningSnapshot:
compute_time: datetime.datetime


@string_utils.prettyprint
@dataclasses.dataclass
class Hyperparameters:
epoch_count: int = 0
Expand Down Expand Up @@ -246,6 +252,7 @@ def make_model_name(name: AnyModelNameOptions):
TunedModelsIterable = Iterable[TunedModel]


@string_utils.prettyprint
@dataclasses.dataclass
class TokenCount:
"""A dataclass representation of a `glm.TokenCountResponse`.
Expand Down
11 changes: 6 additions & 5 deletions google/generativeai/types/safety_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from collections.abc import Mapping

from google.ai import generativelanguage as glm
from google.generativeai import docstring_utils
from google.generativeai import string_utils

import typing
from typing import Iterable, Dict, Iterable, List, TypedDict, Union

Expand Down Expand Up @@ -134,7 +135,7 @@ class ContentFilterDict(TypedDict):
reason: BlockedReason
message: str

__doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__)
__doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__)


def convert_filters_to_enums(
Expand All @@ -153,7 +154,7 @@ class SafetyRatingDict(TypedDict):
category: HarmCategory
probability: HarmProbability

__doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__)
__doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__)


def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
Expand All @@ -174,7 +175,7 @@ class SafetySettingDict(TypedDict):
category: HarmCategory
threshold: HarmBlockThreshold

__doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__)
__doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__)


class LooseSafetySettingDict(TypedDict):
Expand Down Expand Up @@ -220,7 +221,7 @@ class SafetyFeedbackDict(TypedDict):
rating: SafetyRatingDict
setting: SafetySettingDict

__doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__)
__doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__)


def convert_safety_feedback_to_enums(
Expand Down
2 changes: 2 additions & 0 deletions google/generativeai/types/text_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
from typing import Any, Dict, List, TypedDict

from google.generativeai import string_utils
from google.generativeai.types import safety_types
from google.generativeai.types import citation_types

Expand All @@ -39,6 +40,7 @@ class TextCompletion(TypedDict, total=False):
citation_metadata: citation_types.CitationMetadataDict | None


@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class Completion(abc.ABC):
"""The result returned by `generativeai.generate_text`.
Expand Down
Loading