Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
20 changes: 20 additions & 0 deletions src/google/adk/tools/bigquery/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ class BigQueryToolConfig(BaseModel):
locations, see https://cloud.google.com/bigquery/docs/locations.
"""

job_labels: Optional[dict[str, str]] = None
"""Labels to apply to BigQuery jobs for tracking and monitoring.

These labels will be added to all BigQuery jobs executed by the execute_sql
function. Labels must be key-value pairs where both keys and values are
strings. Labels can be used for billing, monitoring, and resource organization.
For more information about labels, see
https://cloud.google.com/bigquery/docs/labels-intro.
"""

@field_validator('maximum_bytes_billed')
@classmethod
def validate_maximum_bytes_billed(cls, v):
Expand All @@ -121,3 +131,13 @@ def validate_application_name(cls, v):
if v and ' ' in v:
raise ValueError('Application name should not contain spaces.')
return v

@field_validator('job_labels')
@classmethod
def validate_job_labels(cls, v):
"""Validate that job_labels keys are not empty."""
if v is not None:
for key in v.keys():
if not key:
raise ValueError('Label keys cannot be empty.')
return v
5 changes: 4 additions & 1 deletion src/google/adk/tools/bigquery/query_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def _execute_sql(
bq_connection_properties = []

# BigQuery job labels if applicable
bq_job_labels = {}
bq_job_labels = (
settings.job_labels.copy() if settings and settings.job_labels else {}
)

if caller_id:
bq_job_labels["adk-bigquery-tool"] = caller_id
if settings and settings.application_name:
Expand Down
237 changes: 237 additions & 0 deletions tests/unittests/tools/bigquery/test_bigquery_query_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,6 +1709,65 @@ def test_execute_sql_job_labels(
}


@pytest.mark.parametrize(
("write_mode", "dry_run", "query_call_count", "query_and_wait_call_count"),
[
pytest.param(WriteMode.ALLOWED, False, 0, 1, id="write-allowed"),
pytest.param(WriteMode.ALLOWED, True, 1, 0, id="write-allowed-dry-run"),
pytest.param(WriteMode.BLOCKED, False, 1, 1, id="write-blocked"),
pytest.param(WriteMode.BLOCKED, True, 2, 0, id="write-blocked-dry-run"),
pytest.param(WriteMode.PROTECTED, False, 2, 1, id="write-protected"),
pytest.param(
WriteMode.PROTECTED, True, 3, 0, id="write-protected-dry-run"
),
],
)
def test_execute_sql_user_job_labels_augment_internal_labels(
write_mode, dry_run, query_call_count, query_and_wait_call_count
):
"""Test execute_sql tool augments user job_labels with internal labels."""
project = "my_project"
query = "SELECT 123 AS num"
statement_type = "SELECT"
credentials = mock.create_autospec(Credentials, instance=True)
user_labels = {"environment": "test", "team": "data"}
tool_settings = BigQueryToolConfig(
write_mode=write_mode,
job_labels=user_labels,
)
tool_context = mock.create_autospec(ToolContext, instance=True)
tool_context.state.get.return_value = None

with mock.patch.object(bigquery, "Client", autospec=True) as Client:
bq_client = Client.return_value

query_job = mock.create_autospec(bigquery.QueryJob)
query_job.statement_type = statement_type
bq_client.query.return_value = query_job

query_tool.execute_sql(
project,
query,
credentials,
tool_settings,
tool_context,
dry_run=dry_run,
)

assert bq_client.query.call_count == query_call_count
assert bq_client.query_and_wait.call_count == query_and_wait_call_count
# Build expected labels from user_labels + internal label
expected_labels = {**user_labels, "adk-bigquery-tool": "execute_sql"}
for call_args_list in [
bq_client.query.call_args_list,
bq_client.query_and_wait.call_args_list,
]:
for call_args in call_args_list:
_, mock_kwargs = call_args
# Verify user labels are preserved and internal label is added
assert mock_kwargs["job_config"].labels == expected_labels


@pytest.mark.parametrize(
("tool_call", "expected_tool_label"),
[
Expand Down Expand Up @@ -1850,6 +1909,94 @@ def test_ml_tool_job_labels_w_application_name(tool_call, expected_tool_label):
assert mock_kwargs["job_config"].labels == expected_labels


@pytest.mark.parametrize(
("tool_call", "expected_labels"),
[
pytest.param(
lambda tool_context: query_tool.forecast(
project_id="test-project",
history_data="SELECT * FROM `test-dataset.test-table`",
timestamp_col="ts_col",
data_col="data_col",
credentials=mock.create_autospec(Credentials, instance=True),
settings=BigQueryToolConfig(
write_mode=WriteMode.ALLOWED,
job_labels={"environment": "prod", "app": "forecaster"},
),
tool_context=tool_context,
),
{
"environment": "prod",
"app": "forecaster",
"adk-bigquery-tool": "forecast",
},
id="forecast",
),
pytest.param(
lambda tool_context: query_tool.analyze_contribution(
project_id="test-project",
input_data="test-dataset.test-table",
dimension_id_cols=["dim1", "dim2"],
contribution_metric="SUM(metric)",
is_test_col="is_test",
credentials=mock.create_autospec(Credentials, instance=True),
settings=BigQueryToolConfig(
write_mode=WriteMode.ALLOWED,
job_labels={"environment": "prod", "app": "analyzer"},
),
tool_context=tool_context,
),
{
"environment": "prod",
"app": "analyzer",
"adk-bigquery-tool": "analyze_contribution",
},
id="analyze-contribution",
),
pytest.param(
lambda tool_context: query_tool.detect_anomalies(
project_id="test-project",
history_data="SELECT * FROM `test-dataset.test-table`",
times_series_timestamp_col="ts_timestamp",
times_series_data_col="ts_data",
credentials=mock.create_autospec(Credentials, instance=True),
settings=BigQueryToolConfig(
write_mode=WriteMode.ALLOWED,
job_labels={"environment": "prod", "app": "detector"},
),
tool_context=tool_context,
),
{
"environment": "prod",
"app": "detector",
"adk-bigquery-tool": "detect_anomalies",
},
id="detect-anomalies",
),
],
)
def test_ml_tool_user_job_labels_augment_internal_labels(
tool_call, expected_labels
):
"""Test ML tools augment user job_labels with internal labels."""

with mock.patch.object(bigquery, "Client", autospec=True) as Client:
bq_client = Client.return_value

tool_context = mock.create_autospec(ToolContext, instance=True)
tool_context.state.get.return_value = None
tool_call(tool_context)

for call_args_list in [
bq_client.query.call_args_list,
bq_client.query_and_wait.call_args_list,
]:
for call_args in call_args_list:
_, mock_kwargs = call_args
# Verify user labels are preserved and internal label is added
assert mock_kwargs["job_config"].labels == expected_labels


def test_execute_sql_max_rows_config():
"""Test execute_sql tool respects max_query_result_rows from config."""
project = "my_project"
Expand Down Expand Up @@ -2014,3 +2161,93 @@ def test_tool_call_doesnt_change_global_settings(tool_call):

# Test settings write mode after
assert settings.write_mode == WriteMode.ALLOWED


@pytest.mark.parametrize(
("tool_call",),
[
pytest.param(
lambda settings, tool_context: query_tool.execute_sql(
project_id="test-project",
query="SELECT * FROM `test-dataset.test-table`",
credentials=mock.create_autospec(Credentials, instance=True),
settings=settings,
tool_context=tool_context,
),
id="execute-sql",
),
pytest.param(
lambda settings, tool_context: query_tool.forecast(
project_id="test-project",
history_data="SELECT * FROM `test-dataset.test-table`",
timestamp_col="ts_col",
data_col="data_col",
credentials=mock.create_autospec(Credentials, instance=True),
settings=settings,
tool_context=tool_context,
),
id="forecast",
),
pytest.param(
lambda settings, tool_context: query_tool.analyze_contribution(
project_id="test-project",
input_data="test-dataset.test-table",
dimension_id_cols=["dim1", "dim2"],
contribution_metric="SUM(metric)",
is_test_col="is_test",
credentials=mock.create_autospec(Credentials, instance=True),
settings=settings,
tool_context=tool_context,
),
id="analyze-contribution",
),
pytest.param(
lambda settings, tool_context: query_tool.detect_anomalies(
project_id="test-project",
history_data="SELECT * FROM `test-dataset.test-table`",
times_series_timestamp_col="ts_timestamp",
times_series_data_col="ts_data",
credentials=mock.create_autospec(Credentials, instance=True),
settings=settings,
tool_context=tool_context,
),
id="detect-anomalies",
),
],
)
def test_tool_call_doesnt_mutate_job_labels(tool_call):
"""Test query tools don't mutate job_labels in global settings."""
original_labels = {"environment": "test", "team": "data"}
settings = BigQueryToolConfig(
write_mode=WriteMode.ALLOWED,
job_labels=original_labels.copy(),
)
tool_context = mock.create_autospec(ToolContext, instance=True)
tool_context.state.get.return_value = (
"test-bq-session-id",
"_anonymous_dataset",
)

with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
# The mock instance
bq_client = Client.return_value

# Simulate the result of query API
query_job = mock.create_autospec(bigquery.QueryJob)
query_job.destination.dataset_id = "_anonymous_dataset"
bq_client.query.return_value = query_job
bq_client.query_and_wait.return_value = []

# Test job_labels before
assert settings.job_labels == original_labels
assert "adk-bigquery-tool" not in settings.job_labels

# Call the tool
result = tool_call(settings, tool_context)

# Test successful execution of the tool
assert result == {"status": "SUCCESS", "rows": []}

# Test job_labels remain unchanged after tool call
assert settings.job_labels == original_labels
assert "adk-bigquery-tool" not in settings.job_labels
58 changes: 58 additions & 0 deletions tests/unittests/tools/bigquery/test_bigquery_tool_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,61 @@ def test_bigquery_tool_config_invalid_maximum_bytes_billed():
),
):
BigQueryToolConfig(maximum_bytes_billed=10_485_759)


@pytest.mark.parametrize(
"labels",
[
pytest.param(
{"environment": "test", "team": "data"},
id="valid-labels",
),
pytest.param(
{},
id="empty-labels",
),
pytest.param(
None,
id="none-labels",
),
],
)
def test_bigquery_tool_config_valid_labels(labels):
"""Test BigQueryToolConfig accepts valid labels."""
with pytest.warns(UserWarning):
config = BigQueryToolConfig(job_labels=labels)
assert config.job_labels == labels


@pytest.mark.parametrize(
("labels", "message"),
[
pytest.param(
"invalid",
"Input should be a valid dictionary",
id="invalid-type",
),
pytest.param(
{123: "value"},
"Input should be a valid string",
id="non-str-key",
),
pytest.param(
{"key": 123},
"Input should be a valid string",
id="non-str-value",
),
pytest.param(
{"": "value"},
"Label keys cannot be empty",
id="empty-label-key",
),
],
)
def test_bigquery_tool_config_invalid_labels(labels, message):
"""Test BigQueryToolConfig raises an exception with invalid labels."""
with pytest.raises(
ValueError,
match=message,
):
BigQueryToolConfig(job_labels=labels)