Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 41 additions & 26 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,37 +629,52 @@ async def _postprocess_handle_function_calls_async(
function_call_event: Event,
llm_request: LlmRequest,
) -> AsyncGenerator[Event, None]:
if function_response_event := await functions.handle_function_calls_async(
invocation_context, function_call_event, llm_request.tools_dict
):
auth_event = functions.generate_auth_event(
invocation_context, function_response_event
# First, stream progressive tools if present (partial events + final event)
final_event_from_progressive = None
async with Aclosing(
functions.iter_progressive_function_calls_async(
invocation_context, function_call_event, llm_request.tools_dict
)
) as agen:
async for event in agen:
final_event_from_progressive = event
yield event

# If progressive produced a final event, continue with it; otherwise fallback
# to the default async handler (non-progressive tools and parallel merge)
function_response_event = final_event_from_progressive
if not function_response_event:
function_response_event = await functions.handle_function_calls_async(
invocation_context, function_call_event, llm_request.tools_dict
)
if auth_event:
yield auth_event
if not function_response_event:
return

# Always yield the function response event first
yield function_response_event

# Check if this is a set_model_response function response
if json_response := _output_schema_processor.get_structured_model_response(
function_response_event
):
# Create and yield a final model response event
final_event = (
_output_schema_processor.create_final_model_response_event(
invocation_context, json_response
)
)
yield final_event
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
agent_to_run = self._get_agent_to_run(
invocation_context, transfer_to_agent
)
async with Aclosing(agent_to_run.run_async(invocation_context)) as agen:
async for event in agen:
yield event
# Common path: auth event, structured response, agent transfer
auth_event = functions.generate_auth_event(
invocation_context, function_response_event
)
if auth_event:
yield auth_event

if json_response := _output_schema_processor.get_structured_model_response(
function_response_event
):
final_event = _output_schema_processor.create_final_model_response_event(
invocation_context, json_response
)
yield final_event
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
agent_to_run = self._get_agent_to_run(
invocation_context, transfer_to_agent
)
async with Aclosing(agent_to_run.run_async(invocation_context)) as agen:
async for event in agen:
yield event

def _get_agent_to_run(
self, invocation_context: InvocationContext, agent_name: str
Expand Down
74 changes: 74 additions & 0 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ...telemetry import trace_tool_call
from ...telemetry import tracer
from ...tools.base_tool import BaseTool
from ...tools.progressive_function_tool import ProgressiveFunctionTool
from ...tools.tool_context import ToolContext
from ...utils.context_utils import Aclosing

Expand Down Expand Up @@ -193,6 +194,79 @@ async def handle_function_calls_async(
return merged_event


async def iter_progressive_function_calls_async(
invocation_context: InvocationContext,
function_call_event: Event,
tools_dict: dict[str, BaseTool],
) -> AsyncGenerator[Event, None]:
"""Streams progress for ProgressiveFunctionTool, then yields final result.

This is async-run only and independent of LiveRequestQueue.
For each function call that maps to a ProgressiveFunctionTool:
- yield partial Events for each progress update
- then run the tool's run_async for the final result and yield a final Event
Non-progressive tools are ignored by this iterator.
"""
function_calls = function_call_event.get_function_calls()
if not function_calls:
return

for function_call in function_calls:
name = function_call.name
if name not in tools_dict:
continue
tool = tools_dict[name]
if not isinstance(tool, ProgressiveFunctionTool):
continue

tool_context = ToolContext(
invocation_context=invocation_context,
function_call_id=function_call.id,
)
function_args = (
copy.deepcopy(function_call.args) if function_call.args else {}
)

# Progress stream
try:
async with Aclosing(
tool.progress_stream(args=function_args, tool_context=tool_context)
) as agen:
async for progress in agen:
partial_event = __build_response_event(
tool, progress, tool_context, invocation_context
)
partial_event.partial = True
yield partial_event
except Exception as tool_error:
# Let on_tool_error callbacks decide if they want to convert error to result
error_response = (
await invocation_context.plugin_manager.run_on_tool_error_callback(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
)
if error_response is None:
raise
# Treat handled error as final function response
final_event = __build_response_event(
tool, error_response, tool_context, invocation_context
)
yield final_event
continue

# Final result for the model
final_result = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
final_event = __build_response_event(
tool, final_result, tool_context, invocation_context
)
yield final_event


async def _execute_single_function_call_async(
invocation_context: InvocationContext,
function_call: types.FunctionCall,
Expand Down
2 changes: 2 additions & 0 deletions src/google/adk/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .load_memory_tool import load_memory_tool as load_memory
from .long_running_tool import LongRunningFunctionTool
from .preload_memory_tool import preload_memory_tool as preload_memory
from .progressive_tool import ProgressiveTool
from .tool_context import ToolContext
from .transfer_to_agent_tool import transfer_to_agent
from .url_context_tool import url_context
Expand All @@ -45,6 +46,7 @@
'ExampleTool',
'exit_loop',
'FunctionTool',
'ProgressiveTool',
'get_user_choice',
'load_artifacts',
'load_memory',
Expand Down
45 changes: 45 additions & 0 deletions src/google/adk/tools/progressive_function_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any
from typing import AsyncGenerator

from .function_tool import FunctionTool
from .tool_context import ToolContext


class ProgressiveFunctionTool(FunctionTool):
"""A FunctionTool that can stream progress updates during run_async.

Implement `progress_stream` to yield intermediate progress payloads.
The final result for model consumption must be returned by `run_async`.
"""

async def progress_stream(
self,
*,
args: dict[str, Any],
tool_context: ToolContext,
) -> AsyncGenerator[Any, None]:
"""Yields progress updates while the tool is executing.

Subclasses should override this method to emit progress objects. The last
item yielded here does not need to be the final result; the final result
should be returned by `run_async`.
"""
raise NotImplementedError(
f"{type(self).__name__}.progress_stream is not implemented"
)
152 changes: 152 additions & 0 deletions src/google/adk/tools/progressive_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import inspect
from typing import Any
from typing import Optional

from ..utils.context_utils import Aclosing
from .function_tool import FunctionTool
from .progressive_function_tool import ProgressiveFunctionTool
from .tool_context import ToolContext


class ProgressiveTool(ProgressiveFunctionTool):
"""Wraps a regular async function to emit progress during run_async.

Usage:
from google.adk.tools.progressive_tool import ProgressiveTool
ProgressiveTool(my_async_function)

Supported function shapes:
- async generator function: yields are treated as progress; last yielded
value is treated as the final result.
- async function with optional `progress` or `progress_callback` parameter:
the wrapper injects a reporter callable that streams progress; the return
value of the function is treated as the final result.
- async function without any progress parameter: no progress is emitted; the
return value is treated as the final result.
"""

def __init__(self, func):
# Initialize as FunctionTool to extract name/description and signature logic
FunctionTool.__init__(self, func)
self._results_by_call_id: dict[str, Any] = {}
# Hide internal progress params from function declaration so the model is
# never prompted for them and schema parsing doesn't fail.
try:
ignore_list = list(getattr(self, '_ignore_params', []))
except Exception:
ignore_list = []
for p in ('progress', 'progress_callback'):
if p not in ignore_list:
ignore_list.append(p)
self._ignore_params = ignore_list

async def progress_stream(
self,
*,
args: dict[str, Any],
tool_context: ToolContext,
) -> asyncio.AsyncGenerator[Any, None]:
signature = inspect.signature(self.func)
valid_params = {param for param in signature.parameters}

# Build args for the wrapped function
args_to_call = {k: v for k, v in args.items() if k in valid_params}
if 'tool_context' in valid_params:
args_to_call['tool_context'] = tool_context

call_id: Optional[str] = tool_context.function_call_id

# Async generator function: yield directly and capture last item
if inspect.isasyncgenfunction(self.func):
last: Any = None
async with Aclosing(self.func(**args_to_call)) as agen:
async for item in agen:
last = item
yield item
if call_id:
self._results_by_call_id[call_id] = last
return

# Coroutine function: run in background, capture progress via callback
# Determine which progress parameter to use if present
progress_param: Optional[str] = None
if 'progress' in valid_params:
progress_param = 'progress'
elif 'progress_callback' in valid_params:
progress_param = 'progress_callback'

queue: asyncio.Queue[Any] = asyncio.Queue()

async def _report_progress(payload: Any):
await queue.put(payload)

if progress_param:
args_to_call[progress_param] = _report_progress

result_box: dict[str, Any] = {}

async def _run_and_capture():
result_box['value'] = await self.func(**args_to_call)

task = asyncio.create_task(_run_and_capture())

# Drain progress while task runs
try:
while True:
if task.done() and queue.empty():
break
try:
item = await asyncio.wait_for(queue.get(), timeout=0.1)
yield item
except asyncio.TimeoutError:
await asyncio.sleep(0)
continue
finally:
# Ensure task completion / propagate exception
await task

if call_id:
self._results_by_call_id[call_id] = result_box.get('value')

async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
"""Return final result. If progress_stream already ran, use captured value."""
call_id: Optional[str] = tool_context.function_call_id
if call_id and call_id in self._results_by_call_id:
return self._results_by_call_id.pop(call_id)

# Fallback: invoke function directly if progress_stream wasn't used
signature = inspect.signature(self.func)
valid_params = {param for param in signature.parameters}
args_to_call = {k: v for k, v in args.items() if k in valid_params}
if 'tool_context' in valid_params:
args_to_call['tool_context'] = tool_context

if inspect.isasyncgenfunction(self.func):
# Consume generator fully; return last item
last: Any = None
async with Aclosing(self.func(**args_to_call)) as agen:
async for item in agen:
last = item
return last

# Coroutine function
return await self.func(**args_to_call)
Loading