src/google/adk/flows/llm_flows/base_llm_flow.py (46 additions, 3 deletions)
@@ -66,6 +66,7 @@
DEFAULT_REQUEST_QUEUE_TIMEOUT = 0.25
DEFAULT_TRANSFER_AGENT_DELAY = 1.0
DEFAULT_TASK_COMPLETION_DELAY = 1.0
DEFAULT_LLM_CLEANUP_TIMEOUT = 5.0

# Statistics configuration
DEFAULT_ENABLE_CACHE_STATISTICS = False
@@ -751,6 +752,22 @@ async def _call_llm_async(
# Calls the LLM.
llm = self.__get_llm(invocation_context)

# Determine if this LLM instance was created just for this request
# (needs cleanup) or is a reused instance from the agent (no cleanup).
from ...agents.llm_agent import LlmAgent
from ...models.base_llm import BaseLlm

needs_cleanup = False
if isinstance(invocation_context.agent, LlmAgent):
agent_model = invocation_context.agent.model
# If agent.model is a string, canonical_model creates a new instance
# that needs cleanup. If agent.model is a BaseLlm instance, it's reused.
needs_cleanup = not isinstance(agent_model, BaseLlm)
logger.debug(
f'LLM cleanup check: agent.model type={type(agent_model).__name__}, '
f'needs_cleanup={needs_cleanup}, llm type={type(llm).__name__}'
)

async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
with tracer.start_as_current_span('call_llm'):
if invocation_context.run_config.support_cfc:
@@ -812,9 +829,35 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:

yield llm_response

- async with Aclosing(_call_llm_with_tracing()) as agen:
- async for event in agen:
- yield event
try:
async with Aclosing(_call_llm_with_tracing()) as agen:
async for event in agen:
yield event
finally:
# Clean up the LLM instance if it was created for this request
if needs_cleanup:
try:
import asyncio

logger.info(f'Cleaning up LLM instance: {type(llm).__name__}')
# Use timeout to prevent hanging on cleanup
await asyncio.wait_for(
llm.aclose(), timeout=DEFAULT_LLM_CLEANUP_TIMEOUT
)
logger.info(
f'Successfully cleaned up LLM instance: {type(llm).__name__}'
)
except asyncio.TimeoutError:
logger.warning(
'LLM cleanup timed out after'
f' {DEFAULT_LLM_CLEANUP_TIMEOUT} seconds'
)
except Exception as e:
logger.warning(f'Error closing LLM instance: {e}')
else:
logger.debug(
f'Skipping LLM cleanup (reused instance): {type(llm).__name__}'
)

async def _handle_before_model_callback(
self,
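For context on the needs_cleanup check above, a minimal sketch of the two ways an LlmAgent can hold a model. It assumes ADK's public LlmAgent and Gemini classes; the agent names and model string are illustrative.

from google.adk.agents.llm_agent import LlmAgent
from google.adk.models.google_llm import Gemini

# Model given as a string: canonical_model resolves it into a fresh BaseLlm
# instance for each request, so the flow owns that instance and must close it.
per_request_agent = LlmAgent(name='per_request', model='gemini-2.0-flash')

# Model given as a BaseLlm instance: the same object is reused across
# requests, so the flow skips cleanup and Runner.close() owns its lifecycle.
shared_agent = LlmAgent(name='shared', model=Gemini(model='gemini-2.0-flash'))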
src/google/adk/models/base_llm.py (11 additions, 0 deletions)
@@ -203,3 +203,14 @@ def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
raise NotImplementedError(
f'Live connection is not supported for {self.model}.'
)

async def aclose(self) -> None:
"""Closes the LLM and releases resources.

This method provides a lifecycle hook for cleanup when the LLM is no longer
needed. The default implementation is a no-op for backward compatibility.

Subclasses that manage resources (e.g., HTTP clients) should override this
method to perform proper cleanup.
"""
pass
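A hedged sketch of the override this docstring describes, for a subclass that lazily creates an HTTP client. The class, the httpx dependency, and the field names are illustrative assumptions, not ADK code.

import httpx
from pydantic import PrivateAttr

from google.adk.models.base_llm import BaseLlm


class HttpBackedLlm(BaseLlm):
  """Illustrative BaseLlm subclass that owns an HTTP transport."""

  _client: httpx.AsyncClient | None = PrivateAttr(default=None)

  @property
  def client(self) -> httpx.AsyncClient:
    # Create the client lazily, mirroring the cached_property pattern
    # used by google_llm.py below.
    if self._client is None:
      self._client = httpx.AsyncClient()
    return self._client

  async def aclose(self) -> None:
    # Close only what this instance actually created.
    if self._client is not None:
      await self._client.aclose()
      self._client = None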
src/google/adk/models/google_llm.py (40 additions, 0 deletions)
@@ -394,6 +394,46 @@ def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
headers[key] = ' '.join(value_parts)
return headers

@override
async def aclose(self) -> None:
"""Closes API clients if they were accessed.

Checks if the cached_property clients have been instantiated and closes
them if necessary. Uses asyncio.gather to ensure all cleanup attempts
complete even if some fail.
"""
import asyncio

_CLIENT_CLOSE_TIMEOUT = 10.0
close_tasks = []

def _add_close_task(client):
"""Appends the appropriate aclose coroutine to close_tasks."""
if hasattr(client, 'aio') and hasattr(client.aio, 'aclose'):
close_tasks.append(client.aio.aclose())
elif hasattr(client, 'aclose'):
close_tasks.append(client.aclose())

# Check if api_client was accessed and close it
if 'api_client' in self.__dict__:
_add_close_task(self.__dict__['api_client'])

# Check if _live_api_client was accessed and close it
if '_live_api_client' in self.__dict__:
_add_close_task(self.__dict__['_live_api_client'])

# Execute all close operations concurrently with timeout
if close_tasks:
try:
await asyncio.wait_for(
asyncio.gather(*close_tasks, return_exceptions=True),
timeout=_CLIENT_CLOSE_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning('Timeout waiting for API clients to close')
except Exception as e:
logger.warning(f'Error during API client cleanup: {e}')


def _build_function_declaration_log(
func_decl: types.FunctionDeclaration,
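Why the 'api_client' in self.__dict__ checks above are sufficient: functools.cached_property stores its computed value in the instance __dict__ under the property's name, so its presence there means the client was actually built and may hold open connections. A self-contained demonstration:

import functools


class Demo:

  @functools.cached_property
  def api_client(self):
    return object()  # stand-in for an expensive client


d = Demo()
assert 'api_client' not in d.__dict__  # never accessed: nothing to close
_ = d.api_client
assert 'api_client' in d.__dict__  # accessed: cleanup is needed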
src/google/adk/runners.py (77 additions, 2 deletions)
@@ -65,6 +65,9 @@

logger = logging.getLogger('google_adk.' + __name__)

# LLM cleanup configuration
_LLM_MODEL_CLEANUP_TIMEOUT = 5.0


class Runner:
"""The Runner class is used to run agents.
@@ -1311,6 +1314,40 @@ def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]:
toolsets.update(self._collect_toolset(sub_agent))
return toolsets

def _collect_llm_models(self, agent: BaseAgent) -> list:
"""Recursively collects all LLM model instances from the agent tree.

Args:
agent: The root agent to collect LLM models from.

Returns:
A list of unique BaseLlm instances found in the agent tree.
"""
from google.adk.models.base_llm import BaseLlm

llm_models = []
seen_ids = set()

def _collect(current_agent: BaseAgent):
"""Helper to recursively collect models."""
if isinstance(current_agent, LlmAgent):
try:
canonical = current_agent.canonical_model
if isinstance(canonical, BaseLlm):
model_id = id(canonical)
if model_id not in seen_ids:
llm_models.append(canonical)
seen_ids.add(model_id)
except (ValueError, AttributeError):
# The agent might not have a model configured, or canonical_model may raise.
pass

for sub_agent in current_agent.sub_agents:
_collect(sub_agent)

_collect(agent)
return llm_models

async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):
"""Clean up toolsets with proper task context management."""
if not toolsets_to_close:
@@ -1341,12 +1378,50 @@ async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):
except Exception as e:
logger.error('Error closing toolset %s: %s', type(toolset).__name__, e)

async def _cleanup_llm_models(self, llm_models_to_close: list):
"""Clean up LLM models with proper error handling and timeout.

Args:
llm_models_to_close: List of BaseLlm instances to close.
"""
if not llm_models_to_close:
return

for llm_model in llm_models_to_close:
try:
logger.info('Closing LLM model: %s', type(llm_model).__name__)
# Use asyncio.wait_for to add timeout protection
await asyncio.wait_for(
llm_model.aclose(), timeout=_LLM_MODEL_CLEANUP_TIMEOUT
)
logger.info(
'Successfully closed LLM model: %s', type(llm_model).__name__
)
except asyncio.TimeoutError:
logger.warning(
'LLM model %s cleanup timed out after %s seconds',
type(llm_model).__name__,
_LLM_MODEL_CLEANUP_TIMEOUT,
)
except Exception as e:
logger.error(
'Error closing LLM model %s: %s', type(llm_model).__name__, e
)

async def close(self):
"""Closes the runner."""
"""Closes the runner and cleans up all resources.

Cleans up toolsets first, then LLM models, to ensure proper resource
cleanup order.
"""
logger.info('Closing runner...')
- # Close Toolsets
# Clean up toolsets first
await self._cleanup_toolsets(self._collect_toolset(self.agent))

# Then clean up LLM models
llm_models_to_close = self._collect_llm_models(self.agent)
await self._cleanup_llm_models(llm_models_to_close)

# Close Plugins
if self.plugin_manager:
await self.plugin_manager.close()
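A hedged usage sketch of the new shutdown path: close() now walks the agent tree and awaits aclose() on each unique BaseLlm after toolsets are cleaned up, so callers should still guarantee that close() runs even when a run fails. The session-service wiring and message construction below are illustrative and may differ across ADK versions.

import asyncio

from google.adk.agents.llm_agent import LlmAgent
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types


async def main() -> None:
  session_service = InMemorySessionService()
  agent = LlmAgent(name='assistant', model='gemini-2.0-flash')
  runner = Runner(agent=agent, app_name='demo', session_service=session_service)
  try:
    session = await session_service.create_session(app_name='demo', user_id='u1')
    message = types.Content(role='user', parts=[types.Part(text='hello')])
    async for event in runner.run_async(
        user_id='u1', session_id=session.id, new_message=message
    ):
      pass  # handle events here
  finally:
    # Toolsets are closed first, then LLM models, then plugins.
    await runner.close()


if __name__ == '__main__':
  asyncio.run(main())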