diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/tools/bigquery/config.py b/src/google/adk/tools/bigquery/config.py index 7768f214ed..047c075ea1 100644 --- a/src/google/adk/tools/bigquery/config.py +++ b/src/google/adk/tools/bigquery/config.py @@ -101,6 +101,16 @@ class BigQueryToolConfig(BaseModel): locations, see https://cloud.google.com/bigquery/docs/locations. """ + job_labels: Optional[dict[str, str]] = None + """Labels to apply to BigQuery jobs for tracking and monitoring. + + These labels will be added to all BigQuery jobs executed by the execute_sql + function. Labels must be key-value pairs where both keys and values are + strings. Labels can be used for billing, monitoring, and resource organization. + For more information about labels, see + https://cloud.google.com/bigquery/docs/labels-intro. + """ + @field_validator('maximum_bytes_billed') @classmethod def validate_maximum_bytes_billed(cls, v): @@ -121,3 +131,13 @@ def validate_application_name(cls, v): if v and ' ' in v: raise ValueError('Application name should not contain spaces.') return v + + @field_validator('job_labels') + @classmethod + def validate_job_labels(cls, v): + """Validate that job_labels keys are not empty.""" + if v is not None: + for key in v.keys(): + if not key: + raise ValueError('Label keys cannot be empty.') + return v diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 666dc3c5a1..5bcd734e70 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -68,7 +68,10 @@ def _execute_sql( bq_connection_properties = [] # BigQuery job labels if applicable - bq_job_labels = {} + bq_job_labels = ( + settings.job_labels.copy() if settings and settings.job_labels else {} + ) + if caller_id: bq_job_labels["adk-bigquery-tool"] = caller_id if settings and settings.application_name: diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index eef83a1f5e..1791100e1f 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -1709,6 +1709,65 @@ def test_execute_sql_job_labels( } +@pytest.mark.parametrize( + ("write_mode", "dry_run", "query_call_count", "query_and_wait_call_count"), + [ + pytest.param(WriteMode.ALLOWED, False, 0, 1, id="write-allowed"), + pytest.param(WriteMode.ALLOWED, True, 1, 0, id="write-allowed-dry-run"), + pytest.param(WriteMode.BLOCKED, False, 1, 1, id="write-blocked"), + pytest.param(WriteMode.BLOCKED, True, 2, 0, id="write-blocked-dry-run"), + pytest.param(WriteMode.PROTECTED, False, 2, 1, id="write-protected"), + pytest.param( + WriteMode.PROTECTED, True, 3, 0, id="write-protected-dry-run" + ), + ], +) +def test_execute_sql_user_job_labels_augment_internal_labels( + write_mode, dry_run, query_call_count, query_and_wait_call_count +): + """Test execute_sql tool augments user job_labels with internal labels.""" + project = "my_project" + query = "SELECT 123 AS num" + statement_type = "SELECT" + credentials = mock.create_autospec(Credentials, instance=True) + user_labels = {"environment": "test", "team": "data"} + tool_settings = BigQueryToolConfig( + write_mode=write_mode, + job_labels=user_labels, + ) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = None + + with mock.patch.object(bigquery, "Client", autospec=True) as Client: + bq_client = Client.return_value + + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + bq_client.query.return_value = query_job + + query_tool.execute_sql( + project, + query, + credentials, + tool_settings, + tool_context, + dry_run=dry_run, + ) + + assert bq_client.query.call_count == query_call_count + assert bq_client.query_and_wait.call_count == query_and_wait_call_count + # Build expected labels from user_labels + internal label + expected_labels = {**user_labels, "adk-bigquery-tool": "execute_sql"} + for call_args_list in [ + bq_client.query.call_args_list, + bq_client.query_and_wait.call_args_list, + ]: + for call_args in call_args_list: + _, mock_kwargs = call_args + # Verify user labels are preserved and internal label is added + assert mock_kwargs["job_config"].labels == expected_labels + + @pytest.mark.parametrize( ("tool_call", "expected_tool_label"), [ @@ -1850,6 +1909,94 @@ def test_ml_tool_job_labels_w_application_name(tool_call, expected_tool_label): assert mock_kwargs["job_config"].labels == expected_labels +@pytest.mark.parametrize( + ("tool_call", "expected_labels"), + [ + pytest.param( + lambda tool_context: query_tool.forecast( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + timestamp_col="ts_col", + data_col="data_col", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "forecaster"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "forecaster", + "adk-bigquery-tool": "forecast", + }, + id="forecast", + ), + pytest.param( + lambda tool_context: query_tool.analyze_contribution( + project_id="test-project", + input_data="test-dataset.test-table", + dimension_id_cols=["dim1", "dim2"], + contribution_metric="SUM(metric)", + is_test_col="is_test", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "analyzer"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "analyzer", + "adk-bigquery-tool": "analyze_contribution", + }, + id="analyze-contribution", + ), + pytest.param( + lambda tool_context: query_tool.detect_anomalies( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "detector"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "detector", + "adk-bigquery-tool": "detect_anomalies", + }, + id="detect-anomalies", + ), + ], +) +def test_ml_tool_user_job_labels_augment_internal_labels( + tool_call, expected_labels +): + """Test ML tools augment user job_labels with internal labels.""" + + with mock.patch.object(bigquery, "Client", autospec=True) as Client: + bq_client = Client.return_value + + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = None + tool_call(tool_context) + + for call_args_list in [ + bq_client.query.call_args_list, + bq_client.query_and_wait.call_args_list, + ]: + for call_args in call_args_list: + _, mock_kwargs = call_args + # Verify user labels are preserved and internal label is added + assert mock_kwargs["job_config"].labels == expected_labels + + def test_execute_sql_max_rows_config(): """Test execute_sql tool respects max_query_result_rows from config.""" project = "my_project" @@ -2014,3 +2161,93 @@ def test_tool_call_doesnt_change_global_settings(tool_call): # Test settings write mode after assert settings.write_mode == WriteMode.ALLOWED + + +@pytest.mark.parametrize( + ("tool_call",), + [ + pytest.param( + lambda settings, tool_context: query_tool.execute_sql( + project_id="test-project", + query="SELECT * FROM `test-dataset.test-table`", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="execute-sql", + ), + pytest.param( + lambda settings, tool_context: query_tool.forecast( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + timestamp_col="ts_col", + data_col="data_col", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="forecast", + ), + pytest.param( + lambda settings, tool_context: query_tool.analyze_contribution( + project_id="test-project", + input_data="test-dataset.test-table", + dimension_id_cols=["dim1", "dim2"], + contribution_metric="SUM(metric)", + is_test_col="is_test", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="analyze-contribution", + ), + pytest.param( + lambda settings, tool_context: query_tool.detect_anomalies( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="detect-anomalies", + ), + ], +) +def test_tool_call_doesnt_mutate_job_labels(tool_call): + """Test query tools don't mutate job_labels in global settings.""" + original_labels = {"environment": "test", "team": "data"} + settings = BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels=original_labels.copy(), + ) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = ( + "test-bq-session-id", + "_anonymous_dataset", + ) + + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + # The mock instance + bq_client = Client.return_value + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.destination.dataset_id = "_anonymous_dataset" + bq_client.query.return_value = query_job + bq_client.query_and_wait.return_value = [] + + # Test job_labels before + assert settings.job_labels == original_labels + assert "adk-bigquery-tool" not in settings.job_labels + + # Call the tool + result = tool_call(settings, tool_context) + + # Test successful execution of the tool + assert result == {"status": "SUCCESS", "rows": []} + + # Test job_labels remain unchanged after tool call + assert settings.job_labels == original_labels + assert "adk-bigquery-tool" not in settings.job_labels diff --git a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py index 5854c97797..072ccea7d0 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py +++ b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py @@ -77,3 +77,61 @@ def test_bigquery_tool_config_invalid_maximum_bytes_billed(): ), ): BigQueryToolConfig(maximum_bytes_billed=10_485_759) + + +@pytest.mark.parametrize( + "labels", + [ + pytest.param( + {"environment": "test", "team": "data"}, + id="valid-labels", + ), + pytest.param( + {}, + id="empty-labels", + ), + pytest.param( + None, + id="none-labels", + ), + ], +) +def test_bigquery_tool_config_valid_labels(labels): + """Test BigQueryToolConfig accepts valid labels.""" + with pytest.warns(UserWarning): + config = BigQueryToolConfig(job_labels=labels) + assert config.job_labels == labels + + +@pytest.mark.parametrize( + ("labels", "message"), + [ + pytest.param( + "invalid", + "Input should be a valid dictionary", + id="invalid-type", + ), + pytest.param( + {123: "value"}, + "Input should be a valid string", + id="non-str-key", + ), + pytest.param( + {"key": 123}, + "Input should be a valid string", + id="non-str-value", + ), + pytest.param( + {"": "value"}, + "Label keys cannot be empty", + id="empty-label-key", + ), + ], +) +def test_bigquery_tool_config_invalid_labels(labels, message): + """Test BigQueryToolConfig raises an exception with invalid labels.""" + with pytest.raises( + ValueError, + match=message, + ): + BigQueryToolConfig(job_labels=labels)