diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 5f16a6e58..c24042da4 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -295,9 +295,9 @@ buildvariants: - name: oidc display_name: OIDC - patchable: false + patchable: true run_on: - - rhel87-small + - ubuntu2204-small expansions: AUTH: auth SSL: ssl diff --git a/src/client/auth.rs b/src/client/auth.rs index 97ae1e89e..7f81eaa80 100644 --- a/src/client/auth.rs +++ b/src/client/auth.rs @@ -195,7 +195,7 @@ impl AuthMechanism { )); } // TODO RUST-1660: Handle specific provider validation, perhaps also do Azure as - // part of this ticket. + // part of this ticket. Specific providers will add predefined oidc_callback here if credential .source .as_ref() @@ -279,7 +279,9 @@ impl AuthMechanism { x509::build_speculative_client_first(credential), )))), Self::Plain => Ok(None), - Self::MongoDbOidc => Ok(None), + Self::MongoDbOidc => Ok(Some(ClientFirst::Oidc(Box::new( + oidc::build_speculative_client_first(credential), + )))), #[cfg(feature = "aws-auth")] AuthMechanism::MongoDbAws => Ok(None), AuthMechanism::MongoDbCr => Err(ErrorKind::Authentication { @@ -332,7 +334,7 @@ impl AuthMechanism { } .into()), AuthMechanism::MongoDbOidc => { - oidc::authenticate_stream(stream, credential, server_api).await + oidc::authenticate_stream(stream, credential, server_api, None).await } _ => Err(ErrorKind::Authentication { message: format!("Authentication mechanism {:?} not yet implemented.", self), @@ -459,6 +461,9 @@ impl Credential { FirstRound::X509(server_first) => { x509::authenticate_stream(conn, self, server_api, server_first).await } + FirstRound::Oidc(server_first) => { + oidc::authenticate_stream(conn, self, server_api, server_first).await + } }; } @@ -517,6 +522,7 @@ impl Debug for Credential { pub(crate) enum ClientFirst { Scram(ScramVersion, scram::ClientFirst), X509(Box), + Oidc(Box), } impl ClientFirst { @@ -524,6 +530,7 @@ impl ClientFirst { match self { Self::Scram(version, client_first) => client_first.to_command(version).body, Self::X509(command) => command.body.clone(), + Self::Oidc(command) => command.body.clone(), } } @@ -537,6 +544,7 @@ impl ClientFirst { }, ), Self::X509(..) => FirstRound::X509(server_first), + Self::Oidc(..) => FirstRound::Oidc(server_first), } } } @@ -547,6 +555,7 @@ impl ClientFirst { pub(crate) enum FirstRound { Scram(ScramVersion, scram::FirstRound), X509(Document), + Oidc(Document), } pub(crate) fn generate_nonce_bytes() -> [u8; 32] { diff --git a/src/client/auth/oidc.rs b/src/client/auth/oidc.rs index c7554f0b9..16be3a816 100644 --- a/src/client/auth/oidc.rs +++ b/src/client/auth/oidc.rs @@ -1,11 +1,8 @@ +use serde::Deserialize; use std::{ - sync::Arc, + sync::{Arc, RwLock}, time::{Duration, Instant}, }; -use tokio::sync::RwLock; - -use bson::rawdoc; -use serde::Deserialize; use typed_builder::TypedBuilder; use crate::{ @@ -16,13 +13,19 @@ use crate::{ }, options::ServerApi, }, - cmap::Connection, + cmap::{Command, Connection}, error::{Error, Result}, BoxFuture, }; +use bson::{doc, rawdoc, Document}; use super::{sasl::SaslContinue, Credential, MONGODB_OIDC_STR}; +const HUMAN_CALLBACK_TIMEOUT: Duration = Duration::from_secs(5 * 60); +const MACHINE_CALLBACK_TIMEOUT: Duration = Duration::from_secs(60); +const MACHINE_INVALIDATE_SLEEP_TIMEOUT: Duration = Duration::from_millis(100); +const API_VERSION: u32 = 1; + /// The user-supplied callbacks for OIDC authentication. #[derive(Clone)] pub struct State { @@ -30,18 +33,6 @@ pub struct State { cache: Arc>, } -impl State { - pub(crate) async fn get_refresh_token(&self) -> Option { - self.cache.read().await.refresh_token.clone() - } - - // TODO RUST-1662: This function will actually be used. - #[allow(dead_code)] - pub(crate) async fn get_access_token(&self) -> Option { - self.cache.read().await.access_token.clone() - } -} - #[derive(Clone)] #[non_exhaustive] pub struct Callback { @@ -120,28 +111,19 @@ pub struct CallbackInner { f: Box BoxFuture<'static, Result> + Send + Sync>, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Cache { + idp_server_info: Option, refresh_token: Option, access_token: Option, - token_gen_id: i32, + token_gen_id: u32, last_call_time: Instant, } -impl Clone for Cache { - fn clone(&self) -> Self { - Self { - refresh_token: self.refresh_token.clone(), - access_token: self.access_token.clone(), - token_gen_id: self.token_gen_id, - last_call_time: self.last_call_time, - } - } -} - impl Cache { fn new() -> Self { Self { + idp_server_info: None, refresh_token: None, access_token: None, token_gen_id: 0, @@ -163,7 +145,7 @@ pub struct IdpServerInfo { #[non_exhaustive] pub struct CallbackContext { pub timeout_seconds: Option, - pub version: i32, + pub version: u32, pub refresh_token: Option, pub idp_info: Option, } @@ -177,13 +159,72 @@ pub struct IdpServerResponse { pub refresh_token: Option, } +pub(crate) fn build_speculative_client_first(credential: &Credential) -> Command { + self::build_client_first(credential, None) +} + +/// Constructs the first client message in the OIDC handshake for speculative authentication +pub(crate) fn build_client_first( + credential: &Credential, + server_api: Option<&ServerApi>, +) -> Command { + let mut auth_command_doc = doc! { + "authenticate": 1, + "mechanism": MONGODB_OIDC_STR, + }; + + if credential.oidc_callback.is_none() { + auth_command_doc.insert("jwt", ""); + } else if let Some(access_token) = get_access_token(credential) { + auth_command_doc.insert("jwt", access_token); + } + + let mut command = Command::new("authenticate", "$external", auth_command_doc); + if let Some(server_api) = server_api { + command.set_server_api(server_api); + } + + command +} + +fn get_access_token(credential: &Credential) -> Option { + credential + .oidc_callback + .as_ref() + .unwrap() + .cache + .read() + .unwrap() + .access_token + .clone() +} + +fn get_refresh_token_and_idp_info( + credential: &Credential, +) -> (Option, Option) { + let cache = credential + .oidc_callback + .as_ref() + // this unwrap is safe because this function is only called from within authenticate_human + .unwrap() + .cache + .read() + .unwrap(); + let refresh_token = cache.refresh_token.clone(); + let idp_info = cache.idp_server_info.clone(); + (refresh_token, idp_info) +} + pub(crate) async fn authenticate_stream( conn: &mut Connection, credential: &Credential, server_api: Option<&ServerApi>, + server_first: impl Into>, ) -> Result<()> { - // RUST-1662: Attempt speculative auth first, only works with a cache. - // First handle speculative authentication. If that succeeds, we are done. + if server_first.into().is_some() { + // speculative authentication succeeded, no need to authenticate again + return Ok(()); + } let Callback { inner, kind } = credential .oidc_callback @@ -197,38 +238,63 @@ pub(crate) async fn authenticate_stream( } } -async fn update_oidc_cache( +fn update_caches( + conn: &Connection, credential: &Credential, response: &IdpServerResponse, - token_gen_id: i32, + idp_server_info: Option, ) { - { - let mut cache = credential + let mut token_gen_id = conn.oidc_token_gen_id.write().unwrap(); + let mut cred_cache = credential .oidc_callback .as_ref() // unwrap() is safe here because authenticate_human is only called if oidc_callback is Some .unwrap() .cache .write() - .await; - cache.access_token = Some(response.access_token.clone()); - cache.refresh_token = response.refresh_token.clone(); - cache.last_call_time = Instant::now(); - cache.token_gen_id = token_gen_id; + .unwrap(); + + if idp_server_info.is_some() { + cred_cache.idp_server_info = idp_server_info; } + cred_cache.access_token = Some(response.access_token.clone()); + cred_cache.refresh_token = response.refresh_token.clone(); + cred_cache.last_call_time = Instant::now(); + cred_cache.token_gen_id += 1; + *token_gen_id = cred_cache.token_gen_id; } -async fn authenticate_human( +fn invalidate_caches(conn: &Connection, credential: &Credential) { + let mut token_gen_id = conn.oidc_token_gen_id.write().unwrap(); + let mut cred_cache = credential + .oidc_callback + .as_ref() + // unwrap() is safe here because authenticate_human/machine is only called if oidc_callback is Some + .unwrap() + .cache + .write() + .unwrap(); + // It should be impossible for token_gen_id to be > cache.token_gen_id, but we check just in + // case + if *token_gen_id >= cred_cache.token_gen_id { + cred_cache.access_token = None; + *token_gen_id = 0; + } +} + +// send_sasl_start_command creates and sends a sasl_start command handling either +// one step or two step sasl based on whether or not the access token is Some. +async fn send_sasl_start_command( + source: &str, conn: &mut Connection, credential: &Credential, server_api: Option<&ServerApi>, - callback: Arc, -) -> Result<()> { - // TODO RUST-1662: Use the Cached credential and add Cache invalidation - // this differs from the machine flow in that we will also try the refresh token - let source = credential.source.as_deref().unwrap_or("$external"); + access_token: Option, +) -> Result { let mut start_doc = rawdoc! {}; - if let Some(username) = credential.username.as_deref() { + if let Some(access_token) = access_token { + start_doc.append("jwt", access_token); + } else if let Some(username) = credential.username.as_deref() { start_doc.append("n", username); } let sasl_start = SaslStart::new( @@ -238,33 +304,37 @@ async fn authenticate_human( server_api.cloned(), ) .into_command(); - let response = send_sasl_command(conn, sasl_start).await?; + send_sasl_command(conn, sasl_start).await +} + +async fn do_two_step_auth( + source: &str, + conn: &mut Connection, + credential: &Credential, + server_api: Option<&ServerApi>, + callback: Arc, + timeout: Duration, +) -> Result<()> { + let response = send_sasl_start_command(source, conn, credential, server_api, None).await?; if response.done { return Err(invalid_auth_response()); } - // Even though most caching will be handled in RUST-1662, the refresh token only exists in the - // cache, so we need to access the cache to get it - let refresh_token = credential.oidc_callback.as_ref() - // this unwrap is safe because we are in the authenticate_human function which only gets called if oidc_callback is Some - .unwrap().get_refresh_token().await; - + let server_info: IdpServerInfo = + bson::from_slice(&response.payload).map_err(|_| invalid_auth_response())?; let idp_response = { - let server_info: IdpServerInfo = - bson::from_slice(&response.payload).map_err(|_| invalid_auth_response())?; - const CALLBACK_TIMEOUT: Duration = Duration::from_secs(5 * 60); let cb_context = CallbackContext { - timeout_seconds: Some(Instant::now() + CALLBACK_TIMEOUT), - version: 1, - refresh_token, - idp_info: Some(server_info), + timeout_seconds: Some(Instant::now() + timeout), + version: API_VERSION, + refresh_token: None, + idp_info: Some(server_info.clone()), }; (callback.f)(cb_context).await? }; - // we'll go ahead and update the cache, also, - // TODO RUST 1662: Modify this comment to just say we are updating the cache - update_oidc_cache(credential, &idp_response, 1).await; + // Update the credential and connection caches with the access token and the credential cache + // with the refresh token and token_gen_id + update_caches(conn, credential, &idp_response, Some(server_info)); let sasl_continue = SaslContinue::new( source.to_string(), @@ -281,59 +351,107 @@ async fn authenticate_human( Ok(()) } -async fn authenticate_machine( +async fn authenticate_human( conn: &mut Connection, credential: &Credential, server_api: Option<&ServerApi>, callback: Arc, ) -> Result<()> { - // TODO RUST-1662: Use the Cached credential and add Cache invalidation let source = credential.source.as_deref().unwrap_or("$external"); - let mut start_doc = rawdoc! {}; - if let Some(username) = credential.username.as_deref() { - start_doc.append("n", username); - } - let sasl_start = SaslStart::new( - source.to_string(), - AuthMechanism::MongoDbOidc, - start_doc.into_bytes(), - server_api.cloned(), - ) - .into_command(); - let response = send_sasl_command(conn, sasl_start).await?; - if response.done { - return Err(invalid_auth_response()); + + // If the access token is in the cache, we can use it to send the sasl start command and avoid + // the callback and sasl_continue + if let Some(access_token) = get_access_token(credential) { + let response = send_sasl_start_command( + source, + conn, + credential, + server_api, + Some(access_token.clone()), + ) + .await?; + if response.done { + return Ok(()); + } + invalidate_caches(conn, credential); } - let idp_response = { - let server_info: IdpServerInfo = - bson::from_slice(&response.payload).map_err(|_| invalid_auth_response())?; - const CALLBACK_TIMEOUT: Duration = Duration::from_secs(5 * 60); - let cb_context = CallbackContext { - timeout_seconds: Some(Instant::now() + CALLBACK_TIMEOUT), - version: 1, - refresh_token: None, - idp_info: Some(server_info), - }; - (callback.f)(cb_context).await? - }; - // we'll go ahead and update the cache, also, - // TODO RUST 1662: Modify this comment to just say we are updating the cache - update_oidc_cache(credential, &idp_response, 1).await; + // If the cache has a refresh token, we can avoid asking for the server info. + if let (refresh_token @ Some(_), idp_info) = get_refresh_token_and_idp_info(credential) { + let idp_response = { + let cb_context = CallbackContext { + timeout_seconds: Some(Instant::now() + HUMAN_CALLBACK_TIMEOUT), + version: API_VERSION, + refresh_token, + idp_info, + }; + (callback.f)(cb_context).await? + }; + // Update the credential and connection caches with the access token and the credential + // cache with the refresh token and token_gen_id + update_caches(conn, credential, &idp_response, None); + + let access_token = idp_response.access_token; + let response = send_sasl_start_command( + source, + conn, + credential, + server_api, + Some(access_token.clone()), + ) + .await?; + if response.done { + return Ok(()); + } + invalidate_caches(conn, credential); + } - let sasl_continue = SaslContinue::new( - source.to_string(), - response.conversation_id, - rawdoc! { "jwt": idp_response.access_token }.into_bytes(), - server_api.cloned(), + do_two_step_auth( + source, + conn, + credential, + server_api, + callback, + HUMAN_CALLBACK_TIMEOUT, ) - .into_command(); - let response = send_sasl_command(conn, sasl_continue).await?; - if !response.done { - return Err(invalid_auth_response()); + .await +} + +async fn authenticate_machine( + conn: &mut Connection, + credential: &Credential, + server_api: Option<&ServerApi>, + callback: Arc, +) -> Result<()> { + let source = credential.source.as_deref().unwrap_or("$external"); + + // If the access token is in the cache, we can use it to send the sasl start command and avoid + // the callback and sasl_continue + if let Some(access_token) = get_access_token(credential) { + let response = send_sasl_start_command( + source, + conn, + credential, + server_api, + Some(access_token.clone()), + ) + .await?; + if response.done { + return Ok(()); + } + invalidate_caches(conn, credential); + tokio::time::sleep(MACHINE_INVALIDATE_SLEEP_TIMEOUT).await; } - Ok(()) + do_two_step_auth( + source, + conn, + credential, + server_api, + callback, + MACHINE_CALLBACK_TIMEOUT, + ) + .await } fn auth_error(s: impl AsRef) -> Error { diff --git a/src/cmap/conn.rs b/src/cmap/conn.rs index c7d7000a7..a1a7a06e1 100644 --- a/src/cmap/conn.rs +++ b/src/cmap/conn.rs @@ -123,7 +123,7 @@ pub(crate) struct Connection { /// The token callback for OIDC authentication. #[derivative(Debug = "ignore")] - pub(crate) oidc_access_token: Option, + pub(crate) oidc_token_gen_id: std::sync::RwLock, } impl Connection { @@ -150,7 +150,7 @@ impl Connection { pinned_sender: None, compressor: None, more_to_come: false, - oidc_access_token: None, + oidc_token_gen_id: std::sync::RwLock::new(0), } } @@ -440,7 +440,7 @@ impl Connection { pinned_sender: self.pinned_sender.clone(), compressor: self.compressor.clone(), more_to_come: false, - oidc_access_token: self.oidc_access_token.take(), + oidc_token_gen_id: std::sync::RwLock::new(0), } } diff --git a/src/test/spec/oidc.rs b/src/test/spec/oidc.rs index e7c57692a..d886264ab 100644 --- a/src/test/spec/oidc.rs +++ b/src/test/spec/oidc.rs @@ -40,3 +40,48 @@ async fn single_principal_implicit_username() -> anyhow::Result<()> { .await?; Ok(()) } + +// TODO RUST-1497: The following test will be removed because it is not an actual test in the spec, +// but just showing that the human flow is still working for two_step (nothing in caching is +// correctly exercised here) +#[tokio::test] +async fn human_flow() -> anyhow::Result<()> { + use crate::{ + client::{ + auth::{oidc, AuthMechanism, Credential}, + options::ClientOptions, + }, + test::log_uncaptured, + Client, + }; + use bson::Document; + use futures_util::FutureExt; + + if std::env::var("OIDC_TOKEN_DIR").is_err() { + log_uncaptured("Skipping OIDC test"); + return Ok(()); + } + + let mut opts = ClientOptions::parse("mongodb://localhost/?authMechanism=MONGODB-OIDC").await?; + opts.credential = Credential::builder() + .mechanism(AuthMechanism::MongoDbOidc) + .oidc_callback(oidc::Callback::human(|_| { + async move { + Ok(oidc::IdpServerResponse { + access_token: tokio::fs::read_to_string("/tmp/tokens/test_user1").await?, + expires: None, + refresh_token: None, + }) + } + .boxed() + })) + .build() + .into(); + let client = Client::with_options(opts)?; + client + .database("test") + .collection::("test") + .find_one(None, None) + .await?; + Ok(()) +}