diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 06b95dcaa..e31709e05 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -251,72 +251,32 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> except ValidationError: pass - def _build_well_known_path(self, pathname: str) -> str: - """Construct well-known path for OAuth metadata discovery.""" - well_known_path = f"/.well-known/oauth-authorization-server{pathname}" - if pathname.endswith("/"): - # Strip trailing slash from pathname to avoid double slashes - well_known_path = well_known_path[:-1] - return well_known_path - - def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool: - """Determine if fallback to root discovery should be attempted.""" - return response_status == 404 and pathname != "/" - - async def _try_metadata_discovery(self, url: str) -> httpx.Request: - """Build metadata discovery request for a specific URL.""" - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - - async def _discover_oauth_metadata(self) -> httpx.Request: - """Build OAuth metadata discovery request with fallback support.""" - if self.context.auth_server_url: - auth_server_url = self.context.auth_server_url - else: - auth_server_url = self.context.server_url - - # Per RFC 8414, try path-aware discovery first + def _get_discovery_urls(self) -> list[str]: + """Generate ordered list of (url, type) tuples for discovery attempts.""" + urls: list[str] = [] + auth_server_url = self.context.auth_server_url or self.context.server_url parsed = urlparse(auth_server_url) - well_known_path = self._build_well_known_path(parsed.path) base_url = f"{parsed.scheme}://{parsed.netloc}" - url = urljoin(base_url, well_known_path) - # Store fallback info for use in response handler - self.context.discovery_base_url = base_url - self.context.discovery_pathname = parsed.path + # RFC 8414: Path-aware OAuth discovery + if parsed.path and parsed.path != "/": + oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oauth_path)) - return await self._try_metadata_discovery(url) + # OAuth root fallback + urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) - async def _discover_oauth_metadata_fallback(self) -> httpx.Request: - """Build fallback OAuth metadata discovery request for legacy servers.""" - base_url = getattr(self.context, "discovery_base_url", "") - if not base_url: - raise OAuthFlowError("No base URL available for fallback discovery") - - # Fallback to root discovery for legacy servers - url = urljoin(base_url, "/.well-known/oauth-authorization-server") - return await self._try_metadata_discovery(url) - - async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool: - """Handle OAuth metadata response. Returns True if handled successfully.""" - if response.status_code == 200: - try: - content = await response.aread() - metadata = OAuthMetadata.model_validate_json(content) - self.context.oauth_metadata = metadata - # Apply default scope if none specified - if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: - self.context.client_metadata.scope = " ".join(metadata.scopes_supported) - return True - except ValidationError: - pass + # RFC 8414 section 5: Path-aware OIDC discovery + # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 + if parsed.path and parsed.path != "/": + oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oidc_path)) - # Check if we should attempt fallback (404 on path-aware discovery) - if not is_fallback and self._should_attempt_fallback( - response.status_code, getattr(self.context, "discovery_pathname", "/") - ): - return False # Signal that fallback should be attempted + # OIDC 1.0 fallback (appends to full URL per OIDC spec) + oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" + urls.append(oidc_fallback) - return True # Signal no fallback needed (either success or non-404 error) + return urls async def _register_client(self) -> httpx.Request | None: """Build registration request or skip if already registered.""" @@ -511,6 +471,17 @@ def _add_auth_header(self, request: httpx.Request) -> None: if self.context.current_tokens and self.context.current_tokens.access_token: request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" + def _create_oauth_metadata_request(self, url: str) -> httpx.Request: + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + + async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: + content = await response.aread() + metadata = OAuthMetadata.model_validate_json(content) + self.context.oauth_metadata = metadata + # Apply default scope if needed + if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: + self.context.client_metadata.scope = " ".join(metadata.scopes_supported) + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """HTTPX auth flow integration.""" async with self.context.lock: @@ -544,15 +515,19 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. await self._handle_protected_resource_response(discovery_response) # Step 2: Discover OAuth metadata (with fallback for legacy servers) - oauth_request = await self._discover_oauth_metadata() - oauth_response = yield oauth_request - handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) - - # If path-aware discovery failed with 404, try fallback to root - if not handled: - fallback_request = await self._discover_oauth_metadata_fallback() - fallback_response = yield fallback_request - await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) + discovery_urls = self._get_discovery_urls() + for url in discovery_urls: + request = self._create_oauth_metadata_request(url) + response = yield request + + if response.status_code == 200: + try: + await self._handle_oauth_metadata_response(response) + break + except ValidationError: + continue + elif response.status_code != 404: + break # Non-404 error, stop trying # Step 3: Register client if needed registration_request = await self._register_client() @@ -571,6 +546,6 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. logger.exception("OAuth flow error") raise - # Retry with new tokens - self._add_auth_header(request) - yield request + # Retry with new tokens + self._add_auth_header(request) + yield request diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index ea9c16c78..c47007a4c 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -235,60 +235,13 @@ async def callback_handler() -> tuple[str, str | None]: assert "mcp-protocol-version" in request.headers @pytest.mark.anyio - async def test_discover_oauth_metadata_request(self, oauth_provider): + def test_create_oauth_metadata_request(self, oauth_provider): """Test OAuth metadata discovery request building.""" - request = await oauth_provider._discover_oauth_metadata() + request = oauth_provider._create_oauth_metadata_request("https://example.com") + # Ensure correct method and headers, and that the URL is unmodified assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp" - assert "mcp-protocol-version" in request.headers - - @pytest.mark.anyio - async def test_discover_oauth_metadata_request_no_path(self, client_metadata, mock_storage): - """Test OAuth metadata discovery request building when server has no path.""" - - async def redirect_handler(url: str) -> None: - pass - - async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" - - provider = OAuthClientProvider( - server_url="https://api.example.com", - client_metadata=client_metadata, - storage=mock_storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) - - request = await provider._discover_oauth_metadata() - - assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server" - assert "mcp-protocol-version" in request.headers - - @pytest.mark.anyio - async def test_discover_oauth_metadata_request_trailing_slash(self, client_metadata, mock_storage): - """Test OAuth metadata discovery request building when server path has trailing slash.""" - - async def redirect_handler(url: str) -> None: - pass - - async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" - - provider = OAuthClientProvider( - server_url="https://api.example.com/v1/mcp/", - client_metadata=client_metadata, - storage=mock_storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) - - request = await provider._discover_oauth_metadata() - - assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp" + assert str(request.url) == "https://example.com" assert "mcp-protocol-version" in request.headers @@ -296,46 +249,16 @@ class TestOAuthFallback: """Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers.""" @pytest.mark.anyio - async def test_fallback_discovery_request(self, client_metadata, mock_storage): - """Test fallback discovery request building.""" - - async def redirect_handler(url: str) -> None: - pass - - async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" - - provider = OAuthClientProvider( - server_url="https://api.example.com/v1/mcp", - client_metadata=client_metadata, - storage=mock_storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) - - # Set up discovery state manually as if path-aware discovery was attempted - provider.context.discovery_base_url = "https://api.example.com" - provider.context.discovery_pathname = "/v1/mcp" + async def test_oauth_discovery_fallback_order(self, oauth_provider): + """Test fallback URL construction order.""" + discovery_urls = oauth_provider._get_discovery_urls() - # Test fallback request building - request = await provider._discover_oauth_metadata_fallback() - - assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server" - assert "mcp-protocol-version" in request.headers - - @pytest.mark.anyio - async def test_should_attempt_fallback(self, oauth_provider): - """Test fallback decision logic.""" - # Should attempt fallback on 404 with non-root path - assert oauth_provider._should_attempt_fallback(404, "/v1/mcp") - - # Should NOT attempt fallback on 404 with root path - assert not oauth_provider._should_attempt_fallback(404, "/") - - # Should NOT attempt fallback on other status codes - assert not oauth_provider._should_attempt_fallback(200, "/v1/mcp") - assert not oauth_provider._should_attempt_fallback(500, "/v1/mcp") + assert discovery_urls == [ + "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp", + "https://api.example.com/.well-known/oauth-authorization-server", + "https://api.example.com/.well-known/openid-configuration/v1/mcp", + "https://api.example.com/v1/mcp/.well-known/openid-configuration", + ] @pytest.mark.anyio async def test_handle_metadata_response_success(self, oauth_provider): @@ -348,50 +271,11 @@ async def test_handle_metadata_response_success(self, oauth_provider): }""" response = httpx.Response(200, content=content) - # Should return True (success) and set metadata - result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False) - assert result is True + # Should set metadata + await oauth_provider._handle_oauth_metadata_response(response) assert oauth_provider.context.oauth_metadata is not None assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/" - @pytest.mark.anyio - async def test_handle_metadata_response_404_needs_fallback(self, oauth_provider): - """Test 404 response handling that should trigger fallback.""" - # Set up discovery state for non-root path - oauth_provider.context.discovery_base_url = "https://api.example.com" - oauth_provider.context.discovery_pathname = "/v1/mcp" - - # Mock 404 response - response = httpx.Response(404) - - # Should return False (needs fallback) - result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False) - assert result is False - - @pytest.mark.anyio - async def test_handle_metadata_response_404_no_fallback_needed(self, oauth_provider): - """Test 404 response handling when no fallback is needed.""" - # Set up discovery state for root path - oauth_provider.context.discovery_base_url = "https://api.example.com" - oauth_provider.context.discovery_pathname = "/" - - # Mock 404 response - response = httpx.Response(404) - - # Should return True (no fallback needed) - result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False) - assert result is True - - @pytest.mark.anyio - async def test_handle_metadata_response_404_fallback_attempt(self, oauth_provider): - """Test 404 response handling during fallback attempt.""" - # Mock 404 response during fallback - response = httpx.Response(404) - - # Should return True (fallback attempt complete, no further action needed) - result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=True) - assert result is True - @pytest.mark.anyio async def test_register_client_request(self, oauth_provider): """Test client registration request building.""" diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py new file mode 100644 index 000000000..fd39eb255 --- /dev/null +++ b/tests/shared/test_auth.py @@ -0,0 +1,39 @@ +"""Tests for OAuth 2.0 shared code.""" + +from mcp.shared.auth import OAuthMetadata + + +class TestOAuthMetadata: + """Tests for OAuthMetadata parsing.""" + + def test_oauth(self): + """Should not throw when parsing OAuth metadata.""" + OAuthMetadata.model_validate( + { + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/oauth2/authorize", + "token_endpoint": "https://example.com/oauth2/token", + "scopes_supported": ["read", "write"], + "response_types_supported": ["code", "token"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], + } + ) + + def test_oidc(self): + """Should not throw when parsing OIDC metadata.""" + OAuthMetadata.model_validate( + { + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/oauth2/authorize", + "token_endpoint": "https://example.com/oauth2/token", + "end_session_endpoint": "https://example.com/logout", + "id_token_signing_alg_values_supported": ["RS256"], + "jwks_uri": "https://example.com/.well-known/jwks.json", + "response_types_supported": ["code", "token"], + "revocation_endpoint": "https://example.com/oauth2/revoke", + "scopes_supported": ["openid", "read", "write"], + "subject_types_supported": ["public"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], + "userinfo_endpoint": "https://example.com/oauth2/userInfo", + } + )