Skip to content

Commit 22efa99

Browse files
committed
feat: implement enterprise search agent tool and related functionality
:
1 parent ffbb0b3 commit 22efa99

File tree

5 files changed

+276
-3
lines changed

5 files changed

+276
-3
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ async def _convert_tool_union_to_tools(
138138
model: Union[str, BaseLlm],
139139
multiple_tools: bool = False,
140140
) -> list[BaseTool]:
141+
from ..tools.enterprise_search_tool import EnterpriseWebSearchTool
141142
from ..tools.google_search_tool import GoogleSearchTool
142143
from ..tools.vertex_ai_search_tool import VertexAiSearchTool
143144

@@ -171,6 +172,17 @@ async def _convert_tool_union_to_tools(
171172
)
172173
]
173174

175+
# Wrap enterprise_web_search tool with AgentTool if there are multiple tools
176+
# because the built-in tools cannot be used together with other tools.
177+
# TODO(b/448114567): Remove once the workaround is no longer needed.
178+
if multiple_tools and isinstance(tool_union, EnterpriseWebSearchTool):
179+
from ..tools.enterprise_search_agent_tool import create_enterprise_search_agent
180+
from ..tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool
181+
182+
enterprise_tool = cast(EnterpriseWebSearchTool, tool_union)
183+
if enterprise_tool.bypass_multi_tools_limit:
184+
return [EnterpriseSearchAgentTool(create_enterprise_search_agent(model))]
185+
174186
if isinstance(tool_union, BaseTool):
175187
return [tool_union]
176188
if callable(tool_union):

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,11 @@ async def _maybe_add_grounding_metadata(
846846
tools = await agent.canonical_tools(readonly_context)
847847
invocation_context.canonical_tools_cache = tools
848848

849-
if not any(tool.name == 'google_search_agent' for tool in tools):
849+
if not any(
850+
tool.name == 'google_search_agent'
851+
or tool.name == 'enterprise_search_agent'
852+
for tool in tools
853+
):
850854
return response
851855
ground_metadata = invocation_context.session.state.get(
852856
'temp:_adk_grounding_metadata', None
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any
18+
from typing import Union
19+
20+
from google.genai import types
21+
from typing_extensions import override
22+
23+
from ..agents.llm_agent import LlmAgent
24+
from ..memory.in_memory_memory_service import InMemoryMemoryService
25+
from ..models.base_llm import BaseLlm
26+
from ..runners import Runner
27+
from ..sessions.in_memory_session_service import InMemorySessionService
28+
from ..utils.context_utils import Aclosing
29+
from ._forwarding_artifact_service import ForwardingArtifactService
30+
from .agent_tool import AgentTool
31+
from .enterprise_search_tool import enterprise_web_search_tool
32+
from .tool_context import ToolContext
33+
34+
35+
def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
36+
"""Create a sub-agent that only uses enterprise_web_search tool."""
37+
return LlmAgent(
38+
name='enterprise_search_agent',
39+
model=model,
40+
description=(
41+
'An agent for performing Enterprise search using the'
42+
' `enterprise_web_search` tool'
43+
),
44+
instruction="""
45+
You are a specialized Enterprise search agent.
46+
47+
When given a search query, use the `enterprise_web_search` tool to find the related information.
48+
""",
49+
tools=[enterprise_web_search_tool],
50+
)
51+
52+
53+
class EnterpriseSearchAgentTool(AgentTool):
54+
"""A tool that wraps a sub-agent that only uses enterprise_web_search tool.
55+
56+
This is a workaround to support using enterprise_web_search tool with other tools.
57+
TODO(b/448114567): Remove once the workaround is no longer needed.
58+
59+
Attributes:
60+
model: The model to use for the sub-agent.
61+
"""
62+
63+
def __init__(self, agent: LlmAgent):
64+
self.agent = agent
65+
super().__init__(agent=self.agent)
66+
67+
@override
68+
async def run_async(
69+
self,
70+
*,
71+
args: dict[str, Any],
72+
tool_context: ToolContext,
73+
) -> Any:
74+
from ..agents.llm_agent import LlmAgent
75+
76+
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
77+
input_value = self.agent.input_schema.model_validate(args)
78+
content = types.Content(
79+
role='user',
80+
parts=[
81+
types.Part.from_text(
82+
text=input_value.model_dump_json(exclude_none=True)
83+
)
84+
],
85+
)
86+
else:
87+
content = types.Content(
88+
role='user',
89+
parts=[types.Part.from_text(text=args['request'])],
90+
)
91+
runner = Runner(
92+
app_name=self.agent.name,
93+
agent=self.agent,
94+
artifact_service=ForwardingArtifactService(tool_context),
95+
session_service=InMemorySessionService(),
96+
memory_service=InMemoryMemoryService(),
97+
credential_service=tool_context._invocation_context.credential_service,
98+
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
99+
)
100+
101+
state_dict = {
102+
k: v
103+
for k, v in tool_context.state.to_dict().items()
104+
if not k.startswith('_adk') # Filter out adk internal states
105+
}
106+
session = await runner.session_service.create_session(
107+
app_name=self.agent.name,
108+
user_id=tool_context._invocation_context.user_id,
109+
state=state_dict,
110+
)
111+
112+
last_content = None
113+
last_grounding_metadata = None
114+
async with Aclosing(
115+
runner.run_async(
116+
user_id=session.user_id, session_id=session.id, new_message=content
117+
)
118+
) as agen:
119+
async for event in agen:
120+
# Forward state delta to parent session.
121+
if event.actions.state_delta:
122+
tool_context.state.update(event.actions.state_delta)
123+
if event.content:
124+
last_content = event.content
125+
last_grounding_metadata = event.grounding_metadata
126+
127+
if not last_content:
128+
return ''
129+
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
130+
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
131+
tool_result = self.agent.output_schema.model_validate_json(
132+
merged_text
133+
).model_dump(exclude_none=True)
134+
else:
135+
tool_result = merged_text
136+
137+
if last_grounding_metadata:
138+
tool_context.state['temp:_adk_grounding_metadata'] = (
139+
last_grounding_metadata
140+
)
141+
return tool_result

src/google/adk/tools/enterprise_search_tool.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,18 @@ class EnterpriseWebSearchTool(BaseTool):
3535
https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise.
3636
"""
3737

38-
def __init__(self):
39-
"""Initializes the Vertex AI Search tool."""
38+
def __init__(self, *, bypass_multi_tools_limit: bool = False):
39+
"""Initializes the Google search tool.
40+
41+
Args:
42+
bypass_multi_tools_limit: Whether to bypass the multi tools limitation,
43+
so that the tool can be used with other tools in the same agent.
44+
"""
4045
# Name and description are not used because this is a model built-in tool.
4146
super().__init__(
4247
name='enterprise_web_search', description='enterprise_web_search'
4348
)
49+
self.bypass_multi_tools_limit = bypass_multi_tools_limit
4450

4551
@override
4652
async def process_llm_request(
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from unittest import mock
18+
19+
from google.adk.agents.invocation_context import InvocationContext
20+
from google.adk.agents.llm_agent import LlmAgent
21+
from google.adk.agents.run_config import RunConfig
22+
from google.adk.agents.sequential_agent import SequentialAgent
23+
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
24+
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
25+
from google.adk.plugins.plugin_manager import PluginManager
26+
from google.adk.sessions.in_memory_session_service import InMemorySessionService
27+
from google.adk.tools.enterprise_search_agent_tool import create_enterprise_search_agent
28+
from google.adk.tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool
29+
from google.adk.tools.tool_context import ToolContext
30+
from pytest import mark
31+
32+
33+
async def _create_tool_context() -> ToolContext:
34+
session_service = InMemorySessionService()
35+
session = await session_service.create_session(
36+
app_name='test_app', user_id='test_user'
37+
)
38+
agent = SequentialAgent(name='test_agent')
39+
invocation_context = InvocationContext(
40+
invocation_id='invocation_id',
41+
agent=agent,
42+
session=session,
43+
session_service=session_service,
44+
artifact_service=InMemoryArtifactService(),
45+
memory_service=InMemoryMemoryService(),
46+
plugin_manager=PluginManager(),
47+
run_config=RunConfig(),
48+
)
49+
return ToolContext(invocation_context=invocation_context)
50+
51+
52+
class TestEnterpriseSearchAgentTool:
53+
"""Test the EnterpriseSearchAgentTool class."""
54+
55+
def test_create_enterprise_search_agent(self):
56+
"""Test that create_enterprise_search_agent creates a valid agent."""
57+
agent = create_enterprise_search_agent('gemini-pro')
58+
assert isinstance(agent, LlmAgent)
59+
assert agent.name == 'enterprise_search_agent'
60+
assert 'enterprise_web_search' in [t.name for t in agent.tools]
61+
62+
def test_enterprise_search_agent_tool_init(self):
63+
"""Test initialization of EnterpriseSearchAgentTool."""
64+
mock_agent = mock.MagicMock(spec=LlmAgent)
65+
mock_agent.name = 'test_agent'
66+
mock_agent.description = 'test_description'
67+
tool = EnterpriseSearchAgentTool(mock_agent)
68+
assert tool.agent == mock_agent
69+
70+
@mark.asyncio
71+
@mock.patch('google.adk.tools.enterprise_search_agent_tool.Runner')
72+
async def test_run_async_succeeds(self, mock_runner_class):
73+
"""Test that run_async executes the sub-agent and returns the result."""
74+
# Arrange
75+
mock_agent = mock.MagicMock(spec=LlmAgent)
76+
mock_agent.name = 'enterprise_search_agent'
77+
mock_agent.description = 'test_description'
78+
mock_agent.input_schema = None
79+
mock_agent.output_schema = None
80+
81+
tool = EnterpriseSearchAgentTool(mock_agent)
82+
tool_context = await _create_tool_context()
83+
84+
async def mock_run_async_gen():
85+
yield mock.MagicMock(
86+
actions=mock.MagicMock(state_delta={'key': 'value'}), content=None
87+
)
88+
yield mock.MagicMock(
89+
actions=mock.MagicMock(state_delta=None),
90+
content=mock.MagicMock(parts=[mock.MagicMock(text='test response')]),
91+
)
92+
93+
mock_runner_instance = mock.MagicMock()
94+
mock_runner_instance.run_async.return_value = mock_run_async_gen()
95+
mock_runner_instance.session_service = mock.AsyncMock()
96+
mock_runner_instance.session_service.create_session.return_value = (
97+
tool_context._invocation_context.session
98+
)
99+
mock_runner_class.return_value = mock_runner_instance
100+
101+
# Act
102+
result = await tool.run_async(
103+
args={'request': 'test query'}, tool_context=tool_context
104+
)
105+
106+
# Assert
107+
mock_runner_class.assert_called_once()
108+
mock_runner_instance.run_async.assert_called_once()
109+
assert tool_context.state['key'] == 'value'
110+
assert result == 'test response'

0 commit comments

Comments
 (0)