From 9e8fb01b8dca8889e52d97c28a427ef1a7e2e3ac Mon Sep 17 00:00:00 2001 From: sarojrout Date: Tue, 18 Nov 2025 00:55:30 -0800 Subject: [PATCH] fix(tools): preserve integer precision in tool calls (#3592) --- .../tools/_function_parameter_parse_util.py | 59 ++++++++++++++++--- src/google/adk/tools/function_tool.py | 55 +++++++++++++++++ .../tools/test_from_function_with_options.py | 17 ++++++ tests/unittests/tools/test_function_tool.py | 19 ++++++ 4 files changed, 141 insertions(+), 9 deletions(-) diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index 1b9559b29c..a1be6843e8 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -48,6 +48,8 @@ logger = logging.getLogger('google_adk.' + __name__) +_INTEGER_STRING_PATTERN = r'^-?\d+$' + def _handle_params_as_deferred_annotations( param: inspect.Parameter, annotation_under_future: dict[str, Any], name: str @@ -231,7 +233,7 @@ def _parse_schema_from_parameter( schema.default = param.default schema.type = _py_builtin_type_to_schema_type[param.annotation] _raise_if_schema_unsupported(variant, schema) - return schema + return _sanitize_integer_schema_for_variant(schema, variant) if isinstance(param.annotation, type) and issubclass(param.annotation, Enum): schema.type = types.Type.STRING schema.enum = [e.value for e in param.annotation] @@ -245,7 +247,7 @@ def _parse_schema_from_parameter( raise ValueError(default_value_error_msg) schema.default = default_value _raise_if_schema_unsupported(variant, schema) - return schema + return _sanitize_integer_schema_for_variant(schema, variant) if ( get_origin(param.annotation) is Union # only parse simple UnionType, example int | str | float | bool @@ -286,7 +288,7 @@ def _parse_schema_from_parameter( raise ValueError(default_value_error_msg) schema.default = param.default _raise_if_schema_unsupported(variant, schema) - return schema + return _sanitize_integer_schema_for_variant(schema, variant) if isinstance(param.annotation, _GenericAlias) or isinstance( param.annotation, typing_types.GenericAlias ): @@ -299,7 +301,7 @@ def _parse_schema_from_parameter( raise ValueError(default_value_error_msg) schema.default = param.default _raise_if_schema_unsupported(variant, schema) - return schema + return _sanitize_integer_schema_for_variant(schema, variant) if origin is Literal: if not all(isinstance(arg, str) for arg in args): raise ValueError( @@ -312,7 +314,7 @@ def _parse_schema_from_parameter( raise ValueError(default_value_error_msg) schema.default = param.default _raise_if_schema_unsupported(variant, schema) - return schema + return _sanitize_integer_schema_for_variant(schema, variant) if origin is list: schema.type = types.Type.ARRAY schema.items = _parse_schema_from_parameter( @@ -329,7 +331,7 @@ def _parse_schema_from_parameter( raise ValueError(default_value_error_msg) schema.default = param.default _raise_if_schema_unsupported(variant, schema) - return schema + return _sanitize_integer_schema_for_variant(schema, variant) if origin is Union: schema.any_of = [] schema.type = types.Type.OBJECT @@ -375,7 +377,7 @@ def _parse_schema_from_parameter( raise ValueError(default_value_error_msg) schema.default = param.default _raise_if_schema_unsupported(variant, schema) - return schema + return _sanitize_integer_schema_for_variant(schema, variant) # all other generic alias will be invoked in raise branch if ( inspect.isclass(param.annotation) @@ -400,7 +402,7 @@ def _parse_schema_from_parameter( func_name, ) _raise_if_schema_unsupported(variant, schema) - return schema + return _sanitize_integer_schema_for_variant(schema, variant) if inspect.isclass(param.annotation) and issubclass( param.annotation, ToolContext ): @@ -414,7 +416,7 @@ def _parse_schema_from_parameter( schema.type = types.Type.OBJECT schema.nullable = True _raise_if_schema_unsupported(variant, schema) - return schema + return _sanitize_integer_schema_for_variant(schema, variant) raise ValueError( f'Failed to parse the parameter {param} of function {func_name} for' ' automatic function calling. Automatic function calling works best with' @@ -431,3 +433,42 @@ def _get_required_fields(schema: types.Schema) -> list[str]: for field_name, field_schema in schema.properties.items() if not field_schema.nullable and field_schema.default is None ] + + +def _sanitize_integer_schema_for_variant( + schema: types.Schema, variant: GoogleLLMVariant +) -> types.Schema: + if variant != GoogleLLMVariant.GEMINI_API: + return schema + + _convert_integer_schema_to_string(schema) + return schema + + +def _convert_integer_schema_to_string(schema: types.Schema | None) -> None: + if schema is None: + return + + if schema.type == types.Type.INTEGER: + schema.type = types.Type.STRING + if schema.pattern is None: + schema.pattern = _INTEGER_STRING_PATTERN + if schema.default is not None and not isinstance(schema.default, str): + schema.default = str(schema.default) + if schema.enum: + schema.enum = [str(enum_value) for enum_value in schema.enum] + + if schema.properties: + for property_schema in schema.properties.values(): + _convert_integer_schema_to_string(property_schema) + + if schema.items: + _convert_integer_schema_to_string(schema.items) + + if schema.any_of: + for nested_schema in schema.any_of: + _convert_integer_schema_to_string(nested_schema) + + additional_properties = schema.additional_properties + if isinstance(additional_properties, types.Schema): + _convert_integer_schema_to_string(additional_properties) diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index d957d1c16b..56b520a951 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -14,8 +14,11 @@ from __future__ import annotations +import decimal import inspect import logging +import math +import re from typing import Any from typing import Callable from typing import get_args @@ -34,6 +37,8 @@ logger = logging.getLogger('google_adk.' + __name__) +_INTEGER_STRING_REGEX = re.compile(r'^[-+]?\d+$') + class FunctionTool(BaseTool): """A tool that wraps a user-defined Python function. @@ -151,8 +156,58 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]: # Keep the original value if conversion fails pass + if param_name in converted_args: + converted_args[param_name] = self._maybe_convert_builtin_value( + converted_args[param_name], target_type + ) + return converted_args + def _maybe_convert_builtin_value(self, value: Any, target_type: Any) -> Any: + if target_type is inspect.Parameter.empty: + return value + + if target_type is int: + return self._convert_to_int_value(value) + + return value + + @staticmethod + def _convert_to_int_value(value: Any) -> Any: + if value is None or isinstance(value, bool): + return value + + if isinstance(value, int): + return value + + if isinstance(value, str): + trimmed_value = value.strip() + if _INTEGER_STRING_REGEX.match(trimmed_value): + try: + return int(trimmed_value) + except ValueError: + return value + return value + + if isinstance(value, decimal.Decimal): + try: + integral_value = value.to_integral_value() + except (decimal.InvalidOperation, ValueError): + return value + if integral_value == value: + try: + return int(integral_value) + except (ValueError, OverflowError): + return value + return value + + if isinstance(value, float): + if math.isfinite(value) and value.is_integer(): + return int(value) + return value + + return value + @override async def run_async( self, *, args: dict[str, Any], tool_context: ToolContext diff --git a/tests/unittests/tools/test_from_function_with_options.py b/tests/unittests/tools/test_from_function_with_options.py index 61670a2678..a0fc91c26f 100644 --- a/tests/unittests/tools/test_from_function_with_options.py +++ b/tests/unittests/tools/test_from_function_with_options.py @@ -153,6 +153,23 @@ def test_function(param: str) -> int: assert declaration.response.type == types.Type.INTEGER +def test_from_function_with_options_int_parameter_gemini(): + """Test that GEMINI int parameters are represented as strings.""" + + def test_function(count: int) -> None: + """A test function with an integer parameter.""" + _ = count + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.GEMINI_API + ) + + assert declaration.parameters is not None + param_schema = declaration.parameters.properties['count'] + assert param_schema.type == types.Type.STRING + assert param_schema.pattern == '^-?\\d+$' + + def test_from_function_with_options_any_annotation_vertex(): """Test from_function_with_options with Any type annotation for VERTEX_AI.""" diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index 78610d330d..c0c362965a 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -428,3 +428,22 @@ def explicit_params_func(arg1: str, arg2: int): assert result == {"arg1": "test", "arg2": 42} # Explicitly verify that unexpected_param was filtered out and not passed to the function assert "unexpected_param" not in result + + +@pytest.mark.asyncio +async def test_run_async_with_large_integer_strings(mock_tool_context): + """Test that large integers provided as strings are converted losslessly.""" + + def add_numbers(a: int, b: int) -> int: + return a + b + + tool = FunctionTool(add_numbers) + result = await tool.run_async( + args={ + "a": "10000000000000001", + "b": "123456789", + }, + tool_context=mock_tool_context, + ) + + assert result == 10000000123456790