Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions src/google/adk/tools/_function_parameter_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@

logger = logging.getLogger('google_adk.' + __name__)

_INTEGER_STRING_PATTERN = r'^-?\d+$'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The regex pattern for integer strings is inconsistent with the one used for parsing in function_tool.py. Here, it is r'^-?\d+$', while in function_tool.py it is r'^[-+]?\d+$', which also allows a leading +. While the Gemini model might not generate a leading +, it's best practice to keep the schema pattern and the parser logic consistent. This makes the system more robust. Please consider updating the pattern to allow for an optional positive sign.

Suggested change
_INTEGER_STRING_PATTERN = r'^-?\d+$'
_INTEGER_STRING_PATTERN = r'^[-+]?\d+$'



def _handle_params_as_deferred_annotations(
param: inspect.Parameter, annotation_under_future: dict[str, Any], name: str
Expand Down Expand Up @@ -231,7 +233,7 @@ def _parse_schema_from_parameter(
schema.default = param.default
schema.type = _py_builtin_type_to_schema_type[param.annotation]
_raise_if_schema_unsupported(variant, schema)
return schema
return _sanitize_integer_schema_for_variant(schema, variant)
if isinstance(param.annotation, type) and issubclass(param.annotation, Enum):
schema.type = types.Type.STRING
schema.enum = [e.value for e in param.annotation]
Expand All @@ -245,7 +247,7 @@ def _parse_schema_from_parameter(
raise ValueError(default_value_error_msg)
schema.default = default_value
_raise_if_schema_unsupported(variant, schema)
return schema
return _sanitize_integer_schema_for_variant(schema, variant)
if (
get_origin(param.annotation) is Union
# only parse simple UnionType, example int | str | float | bool
Expand Down Expand Up @@ -286,7 +288,7 @@ def _parse_schema_from_parameter(
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
return _sanitize_integer_schema_for_variant(schema, variant)
if isinstance(param.annotation, _GenericAlias) or isinstance(
param.annotation, typing_types.GenericAlias
):
Expand All @@ -299,7 +301,7 @@ def _parse_schema_from_parameter(
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
return _sanitize_integer_schema_for_variant(schema, variant)
if origin is Literal:
if not all(isinstance(arg, str) for arg in args):
raise ValueError(
Expand All @@ -312,7 +314,7 @@ def _parse_schema_from_parameter(
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
return _sanitize_integer_schema_for_variant(schema, variant)
if origin is list:
schema.type = types.Type.ARRAY
schema.items = _parse_schema_from_parameter(
Expand All @@ -329,7 +331,7 @@ def _parse_schema_from_parameter(
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
return _sanitize_integer_schema_for_variant(schema, variant)
if origin is Union:
schema.any_of = []
schema.type = types.Type.OBJECT
Expand Down Expand Up @@ -375,7 +377,7 @@ def _parse_schema_from_parameter(
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
return _sanitize_integer_schema_for_variant(schema, variant)
# all other generic alias will be invoked in raise branch
if (
inspect.isclass(param.annotation)
Expand All @@ -400,7 +402,7 @@ def _parse_schema_from_parameter(
func_name,
)
_raise_if_schema_unsupported(variant, schema)
return schema
return _sanitize_integer_schema_for_variant(schema, variant)
if inspect.isclass(param.annotation) and issubclass(
param.annotation, ToolContext
):
Expand All @@ -414,7 +416,7 @@ def _parse_schema_from_parameter(
schema.type = types.Type.OBJECT
schema.nullable = True
_raise_if_schema_unsupported(variant, schema)
return schema
return _sanitize_integer_schema_for_variant(schema, variant)
raise ValueError(
f'Failed to parse the parameter {param} of function {func_name} for'
' automatic function calling. Automatic function calling works best with'
Expand All @@ -431,3 +433,42 @@ def _get_required_fields(schema: types.Schema) -> list[str]:
for field_name, field_schema in schema.properties.items()
if not field_schema.nullable and field_schema.default is None
]


def _sanitize_integer_schema_for_variant(
schema: types.Schema, variant: GoogleLLMVariant
) -> types.Schema:
if variant != GoogleLLMVariant.GEMINI_API:
return schema

_convert_integer_schema_to_string(schema)
return schema


def _convert_integer_schema_to_string(schema: types.Schema | None) -> None:
if schema is None:
return

if schema.type == types.Type.INTEGER:
schema.type = types.Type.STRING
if schema.pattern is None:
schema.pattern = _INTEGER_STRING_PATTERN
if schema.default is not None and not isinstance(schema.default, str):
schema.default = str(schema.default)
if schema.enum:
schema.enum = [str(enum_value) for enum_value in schema.enum]

if schema.properties:
for property_schema in schema.properties.values():
_convert_integer_schema_to_string(property_schema)

if schema.items:
_convert_integer_schema_to_string(schema.items)

if schema.any_of:
for nested_schema in schema.any_of:
_convert_integer_schema_to_string(nested_schema)

additional_properties = schema.additional_properties
if isinstance(additional_properties, types.Schema):
_convert_integer_schema_to_string(additional_properties)
55 changes: 55 additions & 0 deletions src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@

from __future__ import annotations

import decimal
import inspect
import logging
import math
import re
from typing import Any
from typing import Callable
from typing import get_args
Expand All @@ -34,6 +37,8 @@

logger = logging.getLogger('google_adk.' + __name__)

_INTEGER_STRING_REGEX = re.compile(r'^[-+]?\d+$')


class FunctionTool(BaseTool):
"""A tool that wraps a user-defined Python function.
Expand Down Expand Up @@ -151,8 +156,58 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
# Keep the original value if conversion fails
pass

if param_name in converted_args:
converted_args[param_name] = self._maybe_convert_builtin_value(
converted_args[param_name], target_type
)

return converted_args

def _maybe_convert_builtin_value(self, value: Any, target_type: Any) -> Any:
if target_type is inspect.Parameter.empty:
return value

if target_type is int:
return self._convert_to_int_value(value)

return value

@staticmethod
def _convert_to_int_value(value: Any) -> Any:
if value is None or isinstance(value, bool):
return value

if isinstance(value, int):
return value

if isinstance(value, str):
trimmed_value = value.strip()
if _INTEGER_STRING_REGEX.match(trimmed_value):
try:
return int(trimmed_value)
except ValueError:
return value
return value

if isinstance(value, decimal.Decimal):
try:
integral_value = value.to_integral_value()
except (decimal.InvalidOperation, ValueError):
return value
if integral_value == value:
try:
return int(integral_value)
except (ValueError, OverflowError):
return value
return value

if isinstance(value, float):
if math.isfinite(value) and value.is_integer():
return int(value)
return value

return value

@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/tools/test_from_function_with_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,23 @@ def test_function(param: str) -> int:
assert declaration.response.type == types.Type.INTEGER


def test_from_function_with_options_int_parameter_gemini():
"""Test that GEMINI int parameters are represented as strings."""

def test_function(count: int) -> None:
"""A test function with an integer parameter."""
_ = count

declaration = _automatic_function_calling_util.from_function_with_options(
test_function, GoogleLLMVariant.GEMINI_API
)

assert declaration.parameters is not None
param_schema = declaration.parameters.properties['count']
assert param_schema.type == types.Type.STRING
assert param_schema.pattern == '^-?\\d+$'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To align with the suggested change to the integer string pattern in _function_parameter_parse_util.py, this test assertion should be updated to reflect the new, more robust pattern.

Suggested change
assert param_schema.pattern == '^-?\\d+$'
assert param_schema.pattern == '^[-+]?\d+$'



def test_from_function_with_options_any_annotation_vertex():
"""Test from_function_with_options with Any type annotation for VERTEX_AI."""

Expand Down
19 changes: 19 additions & 0 deletions tests/unittests/tools/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,22 @@ def explicit_params_func(arg1: str, arg2: int):
assert result == {"arg1": "test", "arg2": 42}
# Explicitly verify that unexpected_param was filtered out and not passed to the function
assert "unexpected_param" not in result


@pytest.mark.asyncio
async def test_run_async_with_large_integer_strings(mock_tool_context):
"""Test that large integers provided as strings are converted losslessly."""

def add_numbers(a: int, b: int) -> int:
return a + b

tool = FunctionTool(add_numbers)
result = await tool.run_async(
args={
"a": "10000000000000001",
"b": "123456789",
},
tool_context=mock_tool_context,
)

assert result == 10000000123456790