From 9681fc46945ec303029bd44de7b921f9fcc093ce Mon Sep 17 00:00:00 2001 From: the-rooster Date: Mon, 5 May 2025 08:57:49 -0600 Subject: [PATCH] feat(oauth): fixes + cache client credentials --- crates/rmcp/src/transport/auth.rs | 87 +++++++++++++++++++++++- examples/servers/src/mcp_oauth_server.rs | 4 +- 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index cfe7b9a9..ea1d96ae 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -1,4 +1,5 @@ use std::{ + collections::HashMap, sync::Arc, time::{Duration, Instant}, }; @@ -70,6 +71,9 @@ pub struct AuthorizationMetadata { pub issuer: Option, pub jwks_uri: Option, pub scopes_supported: Option>, + // allow additional fields + #[serde(flatten)] + pub additional_fields: HashMap, } /// oauth2 client config @@ -100,6 +104,7 @@ type OAuthClient = oauth2::Client< oauth2::EndpointNotSet, oauth2::EndpointSet, >; +type Credentials = (String, Option); /// oauth2 auth manager pub struct AuthorizationManager { @@ -124,9 +129,12 @@ pub struct ClientRegistrationRequest { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClientRegistrationResponse { pub client_id: String, - pub client_secret: String, + pub client_secret: Option, pub client_name: String, pub redirect_uris: Vec, + // allow additional fields + #[serde(flatten)] + pub additional_fields: HashMap, } impl AuthorizationManager { @@ -191,10 +199,22 @@ impl AuthorizationManager { issuer: None, jwks_uri: None, scopes_supported: None, + additional_fields: HashMap::new(), }) } } + /// get client id and credentials + pub async fn get_credentials(&self) -> Result { + let credentials = self.credentials.read().await; + let client_id = self + .oauth_client + .as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))? + .client_id(); + Ok((client_id.to_string(), credentials.clone())) + } + /// configure oauth2 client with client credentials pub fn configure_client(&mut self, config: OAuthClientConfig) -> Result<(), AuthError> { if self.metadata.is_none() { @@ -287,6 +307,7 @@ impl AuthorizationManager { status, error_text ))); } + debug!("registration response: {:?}", response); let reg_response = match response.json::().await { Ok(response) => response, @@ -301,7 +322,7 @@ impl AuthorizationManager { let config = OAuthClientConfig { client_id: reg_response.client_id, - client_secret: Some(reg_response.client_secret), + client_secret: reg_response.client_secret, redirect_uri: redirect_uri.to_string(), scopes: vec![], }; @@ -310,6 +331,18 @@ impl AuthorizationManager { Ok(config) } + /// use provided client id to configure oauth2 client instead of dynamic registration + /// this is useful when you have a stored client id from previous registration + pub fn configure_client_id(&mut self, client_id: &str) -> Result<(), AuthError> { + let config = OAuthClientConfig { + client_id: client_id.to_string(), + client_secret: None, + scopes: vec![], + redirect_uri: self.base_url.to_string(), + }; + self.configure_client(config) + } + /// generate authorization url pub async fn get_authorization_url(&self, scopes: &[&str]) -> Result { let oauth_client = self @@ -513,6 +546,11 @@ impl AuthorizationSession { }) } + /// get client_id and credentials + pub async fn get_credentials(&self) -> Result { + self.auth_manager.get_credentials().await + } + /// get authorization url pub fn get_authorization_url(&self) -> &str { &self.auth_url @@ -590,9 +628,54 @@ impl OAuthState { if let Some(client) = client { manager.with_client(client)?; } + Ok(OAuthState::Unauthorized(manager)) } + /// Get client_id and OAuth credentials + pub async fn get_credentials(&self) -> Result { + // return client_id and credentials + match self { + OAuthState::Unauthorized(manager) | OAuthState::Authorized(manager) => { + manager.get_credentials().await + } + OAuthState::Session(session) => session.get_credentials().await, + OAuthState::AuthorizedHttpClient(client) => client.auth_manager.get_credentials().await, + } + } + + /// Manually set credentials and move into authorized state + /// Useful if you're caching credentials externally and wish to reuse them + pub async fn set_credentials( + &mut self, + client_id: &str, + credentials: OAuthTokenResponse, + ) -> Result<(), AuthError> { + if let OAuthState::Unauthorized(manager) = self { + let mut manager = std::mem::replace( + manager, + AuthorizationManager::new("http://localhost").await?, + ); + + // write credentials + *manager.credentials.write().await = Some(credentials); + + // discover metadata + let metadata = manager.discover_metadata().await?; + manager.metadata = Some(metadata); + + // set client id and secret + manager.configure_client_id(client_id)?; + + *self = OAuthState::Authorized(manager); + Ok(()) + } else { + Err(AuthError::InternalError( + "Cannot set credentials in this state".to_string(), + )) + } + } + /// start authorization pub async fn start_authorization( &mut self, diff --git a/examples/servers/src/mcp_oauth_server.rs b/examples/servers/src/mcp_oauth_server.rs index 2a6472d3..8047434f 100644 --- a/examples/servers/src/mcp_oauth_server.rs +++ b/examples/servers/src/mcp_oauth_server.rs @@ -525,6 +525,7 @@ async fn oauth_authorization_server() -> impl IntoResponse { registration_endpoint: format!("http://{}/oauth/register", BIND_ADDRESS), issuer: Some(BIND_ADDRESS.to_string()), jwks_uri: Some(format!("http://{}/oauth/jwks", BIND_ADDRESS)), + additional_fields: HashMap::new(), }; debug!("metadata: {:?}", metadata); (StatusCode::OK, Json(metadata)) @@ -567,9 +568,10 @@ async fn oauth_register( // return client information let response = ClientRegistrationResponse { client_id, - client_secret, + client_secret: Some(client_secret), client_name: req.client_name, redirect_uris: req.redirect_uris, + additional_fields: HashMap::new(), }; (StatusCode::CREATED, Json(response)).into_response()