Skip to content

[WIP][POC] Update logic for constructing protected resource metadata URL to include the resource's path component, for compatibility with Oauth RFCs #1179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ cython_debug/
.vscode/
.windsurfrules
**/CLAUDE.local.md
.idea
52 changes: 37 additions & 15 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,6 @@ class OAuthContext:
# State
lock: anyio.Lock = field(default_factory=anyio.Lock)

# Discovery state for fallback support
discovery_base_url: str | None = None
discovery_pathname: str | None = None

def get_authorization_base_url(self, server_url: str) -> str:
"""Extract base URL by removing path component."""
parsed = urlparse(server_url)
Expand Down Expand Up @@ -228,16 +224,23 @@ def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response

return None

async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
url = self._extract_resource_metadata_from_www_auth(init_response)
def _get_protected_resource_discovery_urls(self) -> list[str]:
"""Generate ordered list of URLs for protected resource discovery attempts."""
urls: list[str] = []
parsed = urlparse(self.context.server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"

if not url:
# Fallback to well-known discovery
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
if parsed.path and parsed.path != "/":
# Try path-specific endpoint first
path_component = parsed.path.rstrip("/")
urls.append(urljoin(base_url, f"/.well-known/oauth-protected-resource{path_component}"))
# Then fallback to base endpoint
urls.append(urljoin(base_url, "/.well-known/oauth-protected-resource"))
else:
# No path, just use base endpoint
urls.append(urljoin(base_url, "/.well-known/oauth-protected-resource"))

return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
return urls

async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
"""Handle discovery response."""
Expand Down Expand Up @@ -510,9 +513,28 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
try:
# OAuth flow must be inline due to generator constraints
# Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
discovery_request = await self._discover_protected_resource(response)
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)
# Check if WWW-Authenticate provides resource_metadata URL first
www_auth_url = self._extract_resource_metadata_from_www_auth(response)
if www_auth_url:
discovery_request = httpx.Request(
"GET", www_auth_url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
)
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)
else:
# Try well-known discovery URLs with fallback
discovery_urls = self._get_protected_resource_discovery_urls()
for url in discovery_urls:
discovery_request = httpx.Request(
"GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
)
discovery_response = yield discovery_request

if discovery_response.status_code == 200:
await self._handle_protected_resource_response(discovery_response)
break # Success, stop trying other URLs
elif discovery_response.status_code != 404:
break # Non-404 error, stop trying
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice we may want to keep retrying on other types of errors too, just to be safe


# Step 2: Discover OAuth metadata (with fallback for legacy servers)
discovery_urls = self._get_discovery_urls()
Expand Down
208 changes: 189 additions & 19 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,42 +198,41 @@ class TestOAuthFlow:
"""Test OAuth flow methods."""

@pytest.mark.anyio
async def test_discover_protected_resource_request(self, client_metadata, mock_storage):
"""Test protected resource discovery request building maintains backward compatibility."""
async def test_protected_resource_discovery_urls_generation(self, client_metadata, mock_storage):
"""Test that discovery URL generation works correctly for different server URLs."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

# Test with path component - should have both path-specific and base endpoints
provider = OAuthClientProvider(
server_url="https://api.example.com",
server_url="https://api.example.com/api/2.0/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

# Test without WWW-Authenticate (fallback)
init_response = httpx.Response(
status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com")
)

request = await provider._discover_protected_resource(init_response)
assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
assert "mcp-protocol-version" in request.headers
urls = provider._get_protected_resource_discovery_urls()
assert urls == [
"https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp",
"https://api.example.com/.well-known/oauth-protected-resource",
]

# Test with WWW-Authenticate header
init_response.headers["WWW-Authenticate"] = (
'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"'
# Test without path component - should only have base endpoint
provider = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

request = await provider._discover_protected_resource(init_response)
assert request.method == "GET"
assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path"
assert "mcp-protocol-version" in request.headers
urls = provider._get_protected_resource_discovery_urls()
assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"]

@pytest.mark.anyio
def test_create_oauth_metadata_request(self, oauth_provider):
Expand Down Expand Up @@ -595,6 +594,177 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage):
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
assert oauth_provider.context.token_expiry_time is not None

@pytest.mark.anyio
async def test_auth_flow_protected_resource_fallback(self, client_metadata, mock_storage):
"""Test that the OAuth flow correctly implements fallback from path-specific to base endpoint."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/api/2.0/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

provider.context.current_tokens = None
provider.context.token_expiry_time = None
provider._initialized = True

test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp")
auth_flow = provider.async_auth_flow(test_request)

# Step 1: Original request without auth
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# Step 2: 401 triggers protected resource discovery - should try path-specific first
response = httpx.Response(401, request=test_request)
path_discovery_request = await auth_flow.asend(response)
assert (
str(path_discovery_request.url)
== "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp"
)

# Step 3: Path-specific fails with 404 - should trigger fallback
path_404_response = httpx.Response(404, request=path_discovery_request)
base_discovery_request = await auth_flow.asend(path_404_response)
assert str(base_discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"

# Step 4: Base endpoint succeeds - should store metadata and continue to OAuth discovery
successful_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com", "authorization_servers": ["https://api.example.com"]}',
request=base_discovery_request,
)

# Verify the fallback worked and metadata was stored
await auth_flow.asend(successful_response)
assert provider.context.protected_resource_metadata is not None
assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/"

# Clean up the generator
try:
await auth_flow.aclose()
except Exception:
pass

@pytest.mark.anyio
async def test_auth_flow_www_authenticate_no_fallback(self, client_metadata, mock_storage):
"""Test that WWW-Authenticate header skips fallback logic entirely."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/api/2.0/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

provider.context.current_tokens = None
provider.context.token_expiry_time = None
provider._initialized = True

test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp")
auth_flow = provider.async_auth_flow(test_request)

# Step 1: Original request without auth
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# Step 2: 401 with WWW-Authenticate should use that URL directly
response = httpx.Response(
401,
headers={
"WWW-Authenticate": 'Bearer resource_metadata="https://custom.example.com/.well-known/oauth-protected-resource"'
},
request=test_request,
)

www_auth_request = await auth_flow.asend(response)
assert str(www_auth_request.url) == "https://custom.example.com/.well-known/oauth-protected-resource"

# Step 3: Should proceed directly to OAuth metadata discovery (no fallback attempted)
successful_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}',
request=www_auth_request,
)

await auth_flow.asend(successful_response)
assert provider.context.protected_resource_metadata is not None

# Clean up the generator
try:
await auth_flow.aclose()
except Exception:
pass

@pytest.mark.anyio
async def test_auth_flow_no_fallback_on_success(self, client_metadata, mock_storage):
"""Test that first successful discovery response stops the fallback process."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/api/2.0/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

provider.context.current_tokens = None
provider.context.token_expiry_time = None
provider._initialized = True

test_request = httpx.Request("GET", "https://api.example.com/api/2.0/mcp")
auth_flow = provider.async_auth_flow(test_request)

# Step 1: Original request without auth
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# Step 2: 401 triggers path-specific discovery
response = httpx.Response(401, request=test_request)
path_discovery_request = await auth_flow.asend(response)
assert (
str(path_discovery_request.url)
== "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp"
)

# Step 3: Path-specific succeeds - should skip fallback and go to OAuth discovery
successful_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}',
request=path_discovery_request,
)

await auth_flow.asend(successful_response)
assert provider.context.protected_resource_metadata is not None
assert str(provider.context.protected_resource_metadata.resource) == "https://api.example.com/api/2.0/mcp"

# Clean up the generator
try:
await auth_flow.aclose()
except Exception:
pass


@pytest.mark.parametrize(
(
Expand Down
Loading