Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
Expand Down Expand Up @@ -70,6 +71,9 @@ pub struct AuthorizationMetadata {
pub issuer: Option<String>,
pub jwks_uri: Option<String>,
pub scopes_supported: Option<Vec<String>>,
// allow additional fields
#[serde(flatten)]
pub additional_fields: HashMap<String, serde_json::Value>,
}

/// oauth2 client config
Expand Down Expand Up @@ -100,6 +104,7 @@ type OAuthClient = oauth2::Client<
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>;
type Credentials = (String, Option<OAuthTokenResponse>);

/// oauth2 auth manager
pub struct AuthorizationManager {
Expand All @@ -124,9 +129,12 @@ pub struct ClientRegistrationRequest {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationResponse {
pub client_id: String,
pub client_secret: String,
pub client_secret: Option<String>,
pub client_name: String,
pub redirect_uris: Vec<String>,
// allow additional fields
#[serde(flatten)]
pub additional_fields: HashMap<String, serde_json::Value>,
}

impl AuthorizationManager {
Expand Down Expand Up @@ -191,10 +199,22 @@ impl AuthorizationManager {
issuer: None,
jwks_uri: None,
scopes_supported: None,
additional_fields: HashMap::new(),
})
}
}

/// get client id and credentials
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
let credentials = self.credentials.read().await;
let client_id = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?
.client_id();
Ok((client_id.to_string(), credentials.clone()))
}

/// configure oauth2 client with client credentials
pub fn configure_client(&mut self, config: OAuthClientConfig) -> Result<(), AuthError> {
if self.metadata.is_none() {
Expand Down Expand Up @@ -287,6 +307,7 @@ impl AuthorizationManager {
status, error_text
)));
}

debug!("registration response: {:?}", response);
let reg_response = match response.json::<ClientRegistrationResponse>().await {
Ok(response) => response,
Expand All @@ -301,7 +322,7 @@ impl AuthorizationManager {

let config = OAuthClientConfig {
client_id: reg_response.client_id,
client_secret: Some(reg_response.client_secret),
client_secret: reg_response.client_secret,
redirect_uri: redirect_uri.to_string(),
scopes: vec![],
};
Expand All @@ -310,6 +331,18 @@ impl AuthorizationManager {
Ok(config)
}

/// use provided client id to configure oauth2 client instead of dynamic registration
/// this is useful when you have a stored client id from previous registration
pub fn configure_client_id(&mut self, client_id: &str) -> Result<(), AuthError> {
let config = OAuthClientConfig {
client_id: client_id.to_string(),
client_secret: None,
scopes: vec![],
redirect_uri: self.base_url.to_string(),
};
self.configure_client(config)
}

/// generate authorization url
pub async fn get_authorization_url(&self, scopes: &[&str]) -> Result<String, AuthError> {
let oauth_client = self
Expand Down Expand Up @@ -513,6 +546,11 @@ impl AuthorizationSession {
})
}

/// get client_id and credentials
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
self.auth_manager.get_credentials().await
}

/// get authorization url
pub fn get_authorization_url(&self) -> &str {
&self.auth_url
Expand Down Expand Up @@ -590,9 +628,54 @@ impl OAuthState {
if let Some(client) = client {
manager.with_client(client)?;
}

Ok(OAuthState::Unauthorized(manager))
}

/// Get client_id and OAuth credentials
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
// return client_id and credentials
match self {
OAuthState::Unauthorized(manager) | OAuthState::Authorized(manager) => {
manager.get_credentials().await
}
OAuthState::Session(session) => session.get_credentials().await,
OAuthState::AuthorizedHttpClient(client) => client.auth_manager.get_credentials().await,
}
}

/// Manually set credentials and move into authorized state
/// Useful if you're caching credentials externally and wish to reuse them
pub async fn set_credentials(
&mut self,
client_id: &str,
credentials: OAuthTokenResponse,
) -> Result<(), AuthError> {
if let OAuthState::Unauthorized(manager) = self {
let mut manager = std::mem::replace(
manager,
AuthorizationManager::new("http://localhost").await?,
);

// write credentials
*manager.credentials.write().await = Some(credentials);

// discover metadata
let metadata = manager.discover_metadata().await?;
manager.metadata = Some(metadata);

// set client id and secret
manager.configure_client_id(client_id)?;

*self = OAuthState::Authorized(manager);
Ok(())
} else {
Err(AuthError::InternalError(
"Cannot set credentials in this state".to_string(),
))
}
}

/// start authorization
pub async fn start_authorization(
&mut self,
Expand Down
4 changes: 3 additions & 1 deletion examples/servers/src/mcp_oauth_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ async fn oauth_authorization_server() -> impl IntoResponse {
registration_endpoint: format!("http://{}/oauth/register", BIND_ADDRESS),
issuer: Some(BIND_ADDRESS.to_string()),
jwks_uri: Some(format!("http://{}/oauth/jwks", BIND_ADDRESS)),
additional_fields: HashMap::new(),
};
debug!("metadata: {:?}", metadata);
(StatusCode::OK, Json(metadata))
Expand Down Expand Up @@ -567,9 +568,10 @@ async fn oauth_register(
// return client information
let response = ClientRegistrationResponse {
client_id,
client_secret,
client_secret: Some(client_secret),
client_name: req.client_name,
redirect_uris: req.redirect_uris,
additional_fields: HashMap::new(),
};

(StatusCode::CREATED, Json(response)).into_response()
Expand Down
Loading