Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions packages/toolbox-core/tests/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,70 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
tool_instance.add_auth_token_getters(new_auth_getters_causing_conflict)


@pytest.mark.asyncio
async def test_auth_token_getter_overrides_client_header_conflict(
http_session: ClientSession,
sample_tool_description: str,
sample_tool_auth_params: list[ParameterSchema],
auth_token_value: str, # Use fixture for auth token value
auth_getters: dict, # Use fixture for auth getters
):
"""
This test verifies that when both client headers and auth token getters
produce the same header name, the auth token getter value takes precedence
during actual tool invocation.
"""

tool_name = TEST_TOOL_NAME
base_url = HTTPS_BASE_URL
invoke_url = f"{base_url}/api/tool/{tool_name}/invoke"

# Define the conflicting header name and client-provided value
conflicting_header_name = "X-Conflict-Header"
client_header_value = "client-provided-value"
client_headers = {conflicting_header_name: client_header_value}

params_with_auth_source = sample_tool_auth_params # Renaming for clarity

input_args = {"target": "test_target", "token": "dummy_token"}
mock_server_response = {"result": "Auth success"}

with aioresponses() as m:
m.post(invoke_url, status=200, payload=mock_server_response)

tool_instance = ToolboxTool(
session=http_session,
base_url=base_url,
name=tool_name,
description=sample_tool_description,
params=params_with_auth_source,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters=auth_getters,
bound_params={},
client_headers=client_headers,
)

original_get_auth_header = tool_instance._ToolboxTool__get_auth_header
tool_instance._ToolboxTool__get_auth_header = (
lambda auth_service: conflicting_header_name
)
result = await tool_instance(**input_args)

assert result == mock_server_response["result"]

# Verify the request was made with the auth token value overriding the client header
m.assert_called_once_with(
invoke_url,
method="POST",
json=input_args, # The payload is the input_args
headers={conflicting_header_name: auth_token_value},
)

# Restore original method
tool_instance._ToolboxTool__get_auth_header = original_get_auth_header


def test_add_auth_token_getters_unused_token(
http_session: ClientSession,
sample_tool_params: list[ParameterSchema],
Expand Down