@@ -208,10 +208,159 @@ async def test_discover_oauth_metadata_request(self, oauth_provider):
208208 """Test OAuth metadata discovery request building."""
209209 request = await oauth_provider ._discover_oauth_metadata ()
210210
211+ assert request .method == "GET"
212+ assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
213+ assert "mcp-protocol-version" in request .headers
214+
215+ @pytest .mark .anyio
216+ async def test_discover_oauth_metadata_request_no_path (self , client_metadata , mock_storage ):
217+ """Test OAuth metadata discovery request building when server has no path."""
218+
219+ async def redirect_handler (url : str ) -> None :
220+ pass
221+
222+ async def callback_handler () -> tuple [str , str | None ]:
223+ return "test_auth_code" , "test_state"
224+
225+ provider = OAuthClientProvider (
226+ server_url = "https://api.example.com" ,
227+ client_metadata = client_metadata ,
228+ storage = mock_storage ,
229+ redirect_handler = redirect_handler ,
230+ callback_handler = callback_handler ,
231+ )
232+
233+ request = await provider ._discover_oauth_metadata ()
234+
235+ assert request .method == "GET"
236+ assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server"
237+ assert "mcp-protocol-version" in request .headers
238+
239+ @pytest .mark .anyio
240+ async def test_discover_oauth_metadata_request_trailing_slash (self , client_metadata , mock_storage ):
241+ """Test OAuth metadata discovery request building when server path has trailing slash."""
242+
243+ async def redirect_handler (url : str ) -> None :
244+ pass
245+
246+ async def callback_handler () -> tuple [str , str | None ]:
247+ return "test_auth_code" , "test_state"
248+
249+ provider = OAuthClientProvider (
250+ server_url = "https://api.example.com/v1/mcp/" ,
251+ client_metadata = client_metadata ,
252+ storage = mock_storage ,
253+ redirect_handler = redirect_handler ,
254+ callback_handler = callback_handler ,
255+ )
256+
257+ request = await provider ._discover_oauth_metadata ()
258+
259+ assert request .method == "GET"
260+ assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
261+ assert "mcp-protocol-version" in request .headers
262+
263+
264+ class TestOAuthFallback :
265+ """Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers."""
266+
267+ @pytest .mark .anyio
268+ async def test_fallback_discovery_request (self , client_metadata , mock_storage ):
269+ """Test fallback discovery request building."""
270+
271+ async def redirect_handler (url : str ) -> None :
272+ pass
273+
274+ async def callback_handler () -> tuple [str , str | None ]:
275+ return "test_auth_code" , "test_state"
276+
277+ provider = OAuthClientProvider (
278+ server_url = "https://api.example.com/v1/mcp" ,
279+ client_metadata = client_metadata ,
280+ storage = mock_storage ,
281+ redirect_handler = redirect_handler ,
282+ callback_handler = callback_handler ,
283+ )
284+
285+ # Set up discovery state manually as if path-aware discovery was attempted
286+ provider .context .discovery_base_url = "https://api.example.com"
287+ provider .context .discovery_pathname = "/v1/mcp"
288+
289+ # Test fallback request building
290+ request = await provider ._discover_oauth_metadata_fallback ()
291+
211292 assert request .method == "GET"
212293 assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server"
213294 assert "mcp-protocol-version" in request .headers
214295
296+ @pytest .mark .anyio
297+ async def test_should_attempt_fallback (self , oauth_provider ):
298+ """Test fallback decision logic."""
299+ # Should attempt fallback on 404 with non-root path
300+ assert oauth_provider ._should_attempt_fallback (404 , "/v1/mcp" )
301+
302+ # Should NOT attempt fallback on 404 with root path
303+ assert not oauth_provider ._should_attempt_fallback (404 , "/" )
304+
305+ # Should NOT attempt fallback on other status codes
306+ assert not oauth_provider ._should_attempt_fallback (200 , "/v1/mcp" )
307+ assert not oauth_provider ._should_attempt_fallback (500 , "/v1/mcp" )
308+
309+ @pytest .mark .anyio
310+ async def test_handle_metadata_response_success (self , oauth_provider ):
311+ """Test successful metadata response handling."""
312+ # Create minimal valid OAuth metadata
313+ content = b"""{
314+ "issuer": "https://auth.example.com",
315+ "authorization_endpoint": "https://auth.example.com/authorize",
316+ "token_endpoint": "https://auth.example.com/token"
317+ }"""
318+ response = httpx .Response (200 , content = content )
319+
320+ # Should return True (success) and set metadata
321+ result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = False )
322+ assert result is True
323+ assert oauth_provider .context .oauth_metadata is not None
324+ assert str (oauth_provider .context .oauth_metadata .issuer ) == "https://auth.example.com/"
325+
326+ @pytest .mark .anyio
327+ async def test_handle_metadata_response_404_needs_fallback (self , oauth_provider ):
328+ """Test 404 response handling that should trigger fallback."""
329+ # Set up discovery state for non-root path
330+ oauth_provider .context .discovery_base_url = "https://api.example.com"
331+ oauth_provider .context .discovery_pathname = "/v1/mcp"
332+
333+ # Mock 404 response
334+ response = httpx .Response (404 )
335+
336+ # Should return False (needs fallback)
337+ result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = False )
338+ assert result is False
339+
340+ @pytest .mark .anyio
341+ async def test_handle_metadata_response_404_no_fallback_needed (self , oauth_provider ):
342+ """Test 404 response handling when no fallback is needed."""
343+ # Set up discovery state for root path
344+ oauth_provider .context .discovery_base_url = "https://api.example.com"
345+ oauth_provider .context .discovery_pathname = "/"
346+
347+ # Mock 404 response
348+ response = httpx .Response (404 )
349+
350+ # Should return True (no fallback needed)
351+ result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = False )
352+ assert result is True
353+
354+ @pytest .mark .anyio
355+ async def test_handle_metadata_response_404_fallback_attempt (self , oauth_provider ):
356+ """Test 404 response handling during fallback attempt."""
357+ # Mock 404 response during fallback
358+ response = httpx .Response (404 )
359+
360+ # Should return True (fallback attempt complete, no further action needed)
361+ result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = True )
362+ assert result is True
363+
215364 @pytest .mark .anyio
216365 async def test_register_client_request (self , oauth_provider ):
217366 """Test client registration request building."""
0 commit comments