From 507be8672717db0543ef618b7108148dd811203f Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Wed, 3 May 2023 10:06:37 -0400 Subject: [PATCH 01/15] wip --- src/client/csfle/state_machine.rs | 50 +++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index b59ac3f1b..e88ef0728 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -1,7 +1,7 @@ use std::{ convert::TryInto, ops::DerefMut, - path::{Path, PathBuf}, + path::{Path, PathBuf}, time::{Duration, Instant}, }; use bson::{rawdoc, Document, RawDocument, RawDocumentBuf}; @@ -35,6 +35,14 @@ pub(crate) struct CryptExecutor { mongocryptd: Option, mongocryptd_client: Option, metadata_client: Option, + cached_azure_access_token: Mutex>, + azure_imds: Box, +} + +#[derive(Debug)] +struct CachedAzureAccessToken { + token_doc: RawDocumentBuf, + expire_time: Instant, } impl CryptExecutor { @@ -56,6 +64,8 @@ impl CryptExecutor { mongocryptd: None, mongocryptd_client: None, metadata_client: None, + cached_azure_access_token: Mutex::new(None), + azure_imds: Box::new(ProdAzureImds), }) } @@ -211,11 +221,10 @@ impl CryptExecutor { let ctx = result_mut(&mut ctx)?; #[allow(unused_mut)] let mut out = rawdoc! {}; - if self - .kms_providers - .credentials() + let credentials = self.kms_providers.credentials(); + if credentials .get(&KmsProvider::Aws) - .map_or(false, |d| d.is_empty()) + .map_or(false, Document::is_empty) { #[cfg(feature = "aws-auth")] { @@ -240,6 +249,22 @@ impl CryptExecutor { )); } } + if credentials + .get(&KmsProvider::Azure) + .map_or(false, Document::is_empty) + { + let mut cached_token = self.cached_azure_access_token.lock().await; + match &*cached_token { + Some(cached) if cached.expire_time.saturating_duration_since(Instant::now()) > Duration::from_secs(60) => { + out.append("azure", cached.token_doc.clone()); + } + _ => { + let token = self.azure_imds.get_token().await?; + out.append("azure", token.token_doc.clone()); + *cached_token = Some(token); + } + } + } ctx.provide_kms_providers(&out)?; } State::Ready => { @@ -346,3 +371,18 @@ fn raw_to_doc(raw: &RawDocument) -> Result { raw.try_into() .map_err(|e| Error::internal(format!("could not parse raw document: {}", e))) } + +#[async_trait::async_trait] +trait AzureImds: std::fmt::Debug { + async fn get_token(&self) -> Result; +} + +#[derive(Debug)] +struct ProdAzureImds; + +#[async_trait::async_trait] +impl AzureImds for ProdAzureImds { + async fn get_token(&self) -> Result { + todo!() + } +} \ No newline at end of file From e7735a3c19dde3bafa73957caf99271ab7d09bb6 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Wed, 3 May 2023 12:51:41 -0400 Subject: [PATCH 02/15] trait for Azure IMDS access --- src/client/csfle/state_machine.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index e88ef0728..a781cc607 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -36,13 +36,7 @@ pub(crate) struct CryptExecutor { mongocryptd_client: Option, metadata_client: Option, cached_azure_access_token: Mutex>, - azure_imds: Box, -} - -#[derive(Debug)] -struct CachedAzureAccessToken { - token_doc: RawDocumentBuf, - expire_time: Instant, + azure_imds: Box, } impl CryptExecutor { @@ -372,6 +366,12 @@ fn raw_to_doc(raw: &RawDocument) -> Result { .map_err(|e| Error::internal(format!("could not parse raw document: {}", e))) } +#[derive(Debug)] +struct CachedAzureAccessToken { + token_doc: RawDocumentBuf, + expire_time: Instant, +} + #[async_trait::async_trait] trait AzureImds: std::fmt::Debug { async fn get_token(&self) -> Result; From 8839f609ee87681e6781655f90272479b439d587 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Wed, 3 May 2023 12:56:13 -0400 Subject: [PATCH 03/15] move http client construction into handshaker --- src/cmap/establish/handshake/mod.rs | 4 ++-- src/cmap/establish/handshake/test.rs | 4 +--- src/cmap/establish/mod.rs | 6 +++--- src/cmap/test/integration.rs | 3 --- src/cmap/test/mod.rs | 1 - src/sdam/topology.rs | 3 +-- 6 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/cmap/establish/handshake/mod.rs b/src/cmap/establish/handshake/mod.rs index c85ae8e47..04b102472 100644 --- a/src/cmap/establish/handshake/mod.rs +++ b/src/cmap/establish/handshake/mod.rs @@ -332,7 +332,7 @@ pub(crate) struct Handshaker { impl Handshaker { /// Creates a new Handshaker. - pub(crate) fn new(http_client: HttpClient, options: HandshakerOptions) -> Self { + pub(crate) fn new(options: HandshakerOptions) -> Self { let mut metadata = BASE_CLIENT_METADATA.clone(); let compressors = options.compressors; @@ -383,7 +383,7 @@ impl Handshaker { command.body.insert("client", metadata.clone()); Self { - http_client, + http_client: HttpClient::default(), command, compressors, server_api: options.server_api, diff --git a/src/cmap/establish/handshake/test.rs b/src/cmap/establish/handshake/test.rs index 5a4616d32..568cf51e7 100644 --- a/src/cmap/establish/handshake/test.rs +++ b/src/cmap/establish/handshake/test.rs @@ -3,13 +3,11 @@ use crate::{ bson::doc, cmap::establish::handshake::HandshakerOptions, options::DriverInfo, - runtime::HttpClient, }; #[test] fn metadata_no_options() { let handshaker = Handshaker::new( - HttpClient::default(), HandshakerOptions { app_name: None, compressors: None, @@ -51,7 +49,7 @@ fn metadata_with_options() { load_balanced: false, }; - let handshaker = Handshaker::new(HttpClient::default(), options); + let handshaker = Handshaker::new(options); let metadata = handshaker.command.body.get_document("client").unwrap(); assert_eq!( diff --git a/src/cmap/establish/mod.rs b/src/cmap/establish/mod.rs index 03603c760..36285fa99 100644 --- a/src/cmap/establish/mod.rs +++ b/src/cmap/establish/mod.rs @@ -15,7 +15,7 @@ use crate::{ }, error::{Error as MongoError, ErrorKind, Result}, hello::HelloReply, - runtime::{self, stream::DEFAULT_CONNECT_TIMEOUT, AsyncStream, HttpClient, TlsConfig}, + runtime::{self, stream::DEFAULT_CONNECT_TIMEOUT, AsyncStream, TlsConfig}, sdam::HandshakePhase, }; @@ -56,8 +56,8 @@ impl EstablisherOptions { impl ConnectionEstablisher { /// Creates a new ConnectionEstablisher from the given options. - pub(crate) fn new(http_client: HttpClient, options: EstablisherOptions) -> Result { - let handshaker = Handshaker::new(http_client, options.handshake_options); + pub(crate) fn new(options: EstablisherOptions) -> Result { + let handshaker = Handshaker::new(options.handshake_options); let tls_config = if let Some(tls_options) = options.tls_options { Some(TlsConfig::new(tls_options)?) diff --git a/src/cmap/test/integration.rs b/src/cmap/test/integration.rs index e8a2be014..bf5ac21e1 100644 --- a/src/cmap/test/integration.rs +++ b/src/cmap/test/integration.rs @@ -51,7 +51,6 @@ async fn acquire_connection_and_send_command() { let pool = ConnectionPool::new( client_options.hosts[0].clone(), ConnectionEstablisher::new( - Default::default(), EstablisherOptions::from_client_options(&client_options), ) .unwrap(), @@ -133,7 +132,6 @@ async fn concurrent_connections() { let pool = ConnectionPool::new( CLIENT_OPTIONS.get().await.hosts[0].clone(), ConnectionEstablisher::new( - Default::default(), EstablisherOptions::from_client_options(&client_options), ) .unwrap(), @@ -226,7 +224,6 @@ async fn connection_error_during_establishment() { let pool = ConnectionPool::new( client_options.hosts[0].clone(), ConnectionEstablisher::new( - Default::default(), EstablisherOptions::from_client_options(&client_options), ) .unwrap(), diff --git a/src/cmap/test/mod.rs b/src/cmap/test/mod.rs index 739cf74f4..38b109719 100644 --- a/src/cmap/test/mod.rs +++ b/src/cmap/test/mod.rs @@ -160,7 +160,6 @@ impl Executor { let pool = ConnectionPool::new( CLIENT_OPTIONS.get().await.hosts[0].clone(), ConnectionEstablisher::new( - Default::default(), EstablisherOptions::from_client_options(CLIENT_OPTIONS.get().await), ) .unwrap(), diff --git a/src/sdam/topology.rs b/src/sdam/topology.rs index e7bbd22fc..f6cb4ac0e 100644 --- a/src/sdam/topology.rs +++ b/src/sdam/topology.rs @@ -35,7 +35,7 @@ use crate::{ TopologyDescriptionChangedEvent, TopologyOpeningEvent, }, - runtime::{self, AcknowledgedMessage, HttpClient, WorkerHandle, WorkerHandleListener}, + runtime::{self, AcknowledgedMessage, WorkerHandle, WorkerHandleListener}, selection_criteria::SelectionCriteria, ClusterTime, ServerInfo, @@ -93,7 +93,6 @@ impl Topology { let (watcher, publisher) = TopologyWatcher::channel(state); let connection_establisher = ConnectionEstablisher::new( - HttpClient::default(), EstablisherOptions::from_client_options(&options), )?; From 283a8363ab28e466c9e308b0084faf8d93be2ac6 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Wed, 3 May 2023 13:19:21 -0400 Subject: [PATCH 04/15] decouple http client from aws-auth --- src/client/auth/mod.rs | 8 ++++---- src/cmap/establish/handshake/mod.rs | 12 +++++++----- src/runtime/http.rs | 4 ---- src/runtime/mod.rs | 2 ++ 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/client/auth/mod.rs b/src/client/auth/mod.rs index b3b953621..4a7a5d03b 100644 --- a/src/client/auth/mod.rs +++ b/src/client/auth/mod.rs @@ -23,7 +23,6 @@ use crate::{ client::options::ServerApi, cmap::{Command, Connection, StreamDescription}, error::{Error, ErrorKind, Result}, - runtime::HttpClient, }; const SCRAM_SHA_1_STR: &str = "SCRAM-SHA-1"; @@ -253,7 +252,7 @@ impl AuthMechanism { stream: &mut Connection, credential: &Credential, server_api: Option<&ServerApi>, - #[cfg_attr(not(feature = "aws-auth"), allow(unused))] http_client: &HttpClient, + #[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient, ) -> Result<()> { self.validate_credential(credential)?; @@ -398,9 +397,10 @@ impl Credential { pub(crate) async fn authenticate_stream( &self, conn: &mut Connection, - http_client: &HttpClient, server_api: Option<&ServerApi>, first_round: Option, + #[cfg(feature = "aws-auth")] + http_client: &crate::runtime::HttpClient, ) -> Result<()> { let stream_description = conn.stream_description()?; @@ -431,7 +431,7 @@ impl Credential { // Authenticate according to the chosen mechanism. mechanism - .authenticate_stream(conn, self, server_api, http_client) + .authenticate_stream(conn, self, server_api, #[cfg(feature = "aws-auth")] http_client) .await } diff --git a/src/cmap/establish/handshake/mod.rs b/src/cmap/establish/handshake/mod.rs index 04b102472..ee6029f80 100644 --- a/src/cmap/establish/handshake/mod.rs +++ b/src/cmap/establish/handshake/mod.rs @@ -13,7 +13,6 @@ use crate::{ error::Result, hello::{hello_command, run_hello, HelloReply}, options::{AuthMechanism, Credential, DriverInfo, ServerApi}, - runtime::HttpClient, }; #[cfg(all(feature = "tokio-runtime", not(feature = "tokio-sync")))] @@ -323,11 +322,12 @@ pub(crate) struct Handshaker { #[allow(dead_code)] compressors: Option>, - http_client: HttpClient, - server_api: Option, metadata: ClientMetadata, + + #[cfg(feature = "aws-auth")] + http_client: crate::runtime::HttpClient, } impl Handshaker { @@ -383,11 +383,12 @@ impl Handshaker { command.body.insert("client", metadata.clone()); Self { - http_client: HttpClient::default(), command, compressors, server_api: options.server_api, metadata, + #[cfg(feature = "aws-auth")] + http_client: crate::runtime::HttpClient::default(), } } @@ -457,9 +458,10 @@ impl Handshaker { credential .authenticate_stream( conn, - &self.http_client, self.server_api.as_ref(), first_round, + #[cfg(feature = "aws-auth")] + &self.http_client, ) .await? } diff --git a/src/runtime/http.rs b/src/runtime/http.rs index e816e2d25..eb0f15849 100644 --- a/src/runtime/http.rs +++ b/src/runtime/http.rs @@ -1,15 +1,11 @@ -#[cfg(feature = "aws-auth")] use reqwest::{Method, Response}; -#[cfg(feature = "aws-auth")] use serde::Deserialize; #[derive(Clone, Debug, Default)] pub(crate) struct HttpClient { - #[cfg(feature = "aws-auth")] inner: reqwest::Client, } -#[cfg(feature = "aws-auth")] impl HttpClient { /// Executes an HTTP GET request and deserializes the JSON response. pub(crate) async fn get_and_deserialize_json<'a, T>( diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 7a6f64332..b625a5101 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -1,4 +1,5 @@ mod acknowledged_message; +#[cfg(feature = "reqwest")] mod http; #[cfg(feature = "async-std-runtime")] mod interval; @@ -28,6 +29,7 @@ pub(crate) use self::{ worker_handle::{WorkerHandle, WorkerHandleListener}, }; use crate::{error::Result, options::ServerAddress}; +#[cfg(feature = "reqwest")] pub(crate) use http::HttpClient; #[cfg(feature = "async-std-runtime")] use interval::Interval; From ea4b5a2e5f218704eba86f5338cf00415e84eccc Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Wed, 3 May 2023 13:37:03 -0400 Subject: [PATCH 05/15] factor out azure kms module --- Cargo.toml | 3 + src/client/csfle/state_machine.rs | 99 +++++++++++++++++++++---------- 2 files changed, 71 insertions(+), 31 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5b28d1a81..517d39fc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,9 @@ bson-uuid-1 = ["bson/uuid-1"] # This can only be used with the tokio-runtime feature flag. aws-auth = ["reqwest"] +# Enable support for on-demand Azure KMS credentials. +azure-kms = ["reqwest"] + zstd-compression = ["zstd"] zlib-compression = ["flate2"] snappy-compression = ["snap"] diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index a781cc607..955e69707 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -1,7 +1,7 @@ use std::{ convert::TryInto, ops::DerefMut, - path::{Path, PathBuf}, time::{Duration, Instant}, + path::{Path, PathBuf}, }; use bson::{rawdoc, Document, RawDocument, RawDocumentBuf}; @@ -35,8 +35,8 @@ pub(crate) struct CryptExecutor { mongocryptd: Option, mongocryptd_client: Option, metadata_client: Option, - cached_azure_access_token: Mutex>, - azure_imds: Box, + #[cfg(feature = "azure-kms")] + azure: azure::ExecutorState, } impl CryptExecutor { @@ -58,8 +58,8 @@ impl CryptExecutor { mongocryptd: None, mongocryptd_client: None, metadata_client: None, - cached_azure_access_token: Mutex::new(None), - azure_imds: Box::new(ProdAzureImds), + #[cfg(feature = "azure-kms")] + azure: azure::ExecutorState::new(), }) } @@ -247,16 +247,15 @@ impl CryptExecutor { .get(&KmsProvider::Azure) .map_or(false, Document::is_empty) { - let mut cached_token = self.cached_azure_access_token.lock().await; - match &*cached_token { - Some(cached) if cached.expire_time.saturating_duration_since(Instant::now()) > Duration::from_secs(60) => { - out.append("azure", cached.token_doc.clone()); - } - _ => { - let token = self.azure_imds.get_token().await?; - out.append("azure", token.token_doc.clone()); - *cached_token = Some(token); - } + #[cfg(feature = "azure-kms")] + { + out.append("azure", self.azure.get_token().await?); + } + #[cfg(not(feature = "azure-kms"))] + { + return Err(Error::invalid_argument( + "On-demand Azure KMS credentials require the `azure-kms` feature.", + )); } } ctx.provide_kms_providers(&out)?; @@ -366,23 +365,61 @@ fn raw_to_doc(raw: &RawDocument) -> Result { .map_err(|e| Error::internal(format!("could not parse raw document: {}", e))) } -#[derive(Debug)] -struct CachedAzureAccessToken { - token_doc: RawDocumentBuf, - expire_time: Instant, -} +#[cfg(feature = "azure-kms")] +mod azure { + use bson::RawDocumentBuf; + use tokio::sync::Mutex; + use std::time::{Instant, Duration}; -#[async_trait::async_trait] -trait AzureImds: std::fmt::Debug { - async fn get_token(&self) -> Result; -} + use crate::error::Result; -#[derive(Debug)] -struct ProdAzureImds; + #[derive(Debug)] + pub(super) struct ExecutorState { + cached_access_token: Mutex>, + token_source: Box, + } + + impl ExecutorState { + pub(super) fn new() -> Self { + Self { + cached_access_token: Mutex::new(None), + token_source: Box::new(AzureImdsTokenSource), + } + } + + pub(super) async fn get_token(&self) -> Result { + let mut cached_token = self.cached_access_token.lock().await; + if let Some(cached) = &*cached_token { + if cached.expire_time.saturating_duration_since(Instant::now()) > Duration::from_secs(60) { + return Ok(cached.token_doc.clone()); + } + } + let token = self.token_source.get_token().await?; + let out = token.token_doc.clone(); + *cached_token = Some(token); + Ok(out) + } + } + + #[derive(Debug)] + struct CachedAccessToken { + token_doc: RawDocumentBuf, + expire_time: Instant, + } -#[async_trait::async_trait] -impl AzureImds for ProdAzureImds { - async fn get_token(&self) -> Result { - todo!() + #[async_trait::async_trait] + trait TokenSource: std::fmt::Debug { + async fn get_token(&self) -> Result; } -} \ No newline at end of file + + #[derive(Debug)] + struct AzureImdsTokenSource; + + #[async_trait::async_trait] + impl TokenSource for AzureImdsTokenSource { + async fn get_token(&self) -> Result { + todo!() + } + } + +} From 2b0c542dd76c250a3fad57f7f0bb39150cf1668a Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Wed, 3 May 2023 13:55:38 -0400 Subject: [PATCH 06/15] token source impl --- src/client/csfle/state_machine.rs | 32 +++++++++++++++++++++++++------ src/runtime/http.rs | 5 ++++- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index 955e69707..4b405a24b 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -367,16 +367,18 @@ fn raw_to_doc(raw: &RawDocument) -> Result { #[cfg(feature = "azure-kms")] mod azure { - use bson::RawDocumentBuf; + use bson::{RawDocumentBuf, rawdoc}; + use serde::Deserialize; use tokio::sync::Mutex; use std::time::{Instant, Duration}; - use crate::error::Result; + use crate::{error::{Result, Error}, runtime::HttpClient}; #[derive(Debug)] pub(super) struct ExecutorState { cached_access_token: Mutex>, token_source: Box, + http: HttpClient, } impl ExecutorState { @@ -384,6 +386,7 @@ mod azure { Self { cached_access_token: Mutex::new(None), token_source: Box::new(AzureImdsTokenSource), + http: HttpClient::default(), } } @@ -394,7 +397,7 @@ mod azure { return Ok(cached.token_doc.clone()); } } - let token = self.token_source.get_token().await?; + let token = self.token_source.get_token(&self.http).await?; let out = token.token_doc.clone(); *cached_token = Some(token); Ok(out) @@ -409,7 +412,7 @@ mod azure { #[async_trait::async_trait] trait TokenSource: std::fmt::Debug { - async fn get_token(&self) -> Result; + async fn get_token(&self, http: &HttpClient) -> Result; } #[derive(Debug)] @@ -417,8 +420,25 @@ mod azure { #[async_trait::async_trait] impl TokenSource for AzureImdsTokenSource { - async fn get_token(&self) -> Result { - todo!() + async fn get_token(&self, http: &HttpClient) -> Result { + #[derive(Deserialize)] + struct Response { + access_token: String, + expires_in: u32, + } + + let now = Instant::now(); + let resp: Response = http.get_and_deserialize_json( + "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net/", + &[("Metadata", "true"), ("Accept", "application/json")], + ) + .await + .map_err(|e| Error::authentication_error("azure imds", &format!("{}", e)))?; + + Ok(CachedAccessToken { + token_doc: rawdoc! { "accessToken": resp.access_token }, + expire_time: now + Duration::from_secs(resp.expires_in as u64), + }) } } diff --git a/src/runtime/http.rs b/src/runtime/http.rs index eb0f15849..c905b3db0 100644 --- a/src/runtime/http.rs +++ b/src/runtime/http.rs @@ -26,6 +26,7 @@ impl HttpClient { } /// Executes an HTTP GET request and returns the response body as a string. + #[allow(unused)] pub(crate) async fn get_and_read_string<'a>( &self, uri: &str, @@ -36,6 +37,7 @@ impl HttpClient { } /// Executes an HTTP PUT request and returns the response body as a string. + #[allow(unused)] pub(crate) async fn put_and_read_string<'a>( &self, uri: &str, @@ -46,6 +48,7 @@ impl HttpClient { } /// Executes an HTTP request and returns the response body as a string. + #[allow(unused)] pub(crate) async fn request_and_read_string<'a>( &self, method: Method, @@ -57,7 +60,7 @@ impl HttpClient { Ok(text) } - /// Executes an HTTP equest and returns the response. + /// Executes an HTTP request and returns the response. pub(crate) async fn request<'a>( &self, method: Method, From 631285dd17163f0aa9e598c6490ec2034f3451f7 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Thu, 4 May 2023 10:34:35 -0400 Subject: [PATCH 07/15] allow setting test host --- src/client/csfle/state_machine.rs | 54 ++++++++++++++++++------------- src/runtime/http.rs | 6 ++-- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index 4b405a24b..4f6b9db72 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -377,16 +377,18 @@ mod azure { #[derive(Debug)] pub(super) struct ExecutorState { cached_access_token: Mutex>, - token_source: Box, http: HttpClient, + #[cfg(test)] + test_host: Option<&'static str>, } impl ExecutorState { pub(super) fn new() -> Self { Self { cached_access_token: Mutex::new(None), - token_source: Box::new(AzureImdsTokenSource), http: HttpClient::default(), + #[cfg(test)] + test_host: None, } } @@ -397,30 +399,13 @@ mod azure { return Ok(cached.token_doc.clone()); } } - let token = self.token_source.get_token(&self.http).await?; + let token = self.fetch_new_token().await?; let out = token.token_doc.clone(); *cached_token = Some(token); Ok(out) } - } - - #[derive(Debug)] - struct CachedAccessToken { - token_doc: RawDocumentBuf, - expire_time: Instant, - } - - #[async_trait::async_trait] - trait TokenSource: std::fmt::Debug { - async fn get_token(&self, http: &HttpClient) -> Result; - } - #[derive(Debug)] - struct AzureImdsTokenSource; - - #[async_trait::async_trait] - impl TokenSource for AzureImdsTokenSource { - async fn get_token(&self, http: &HttpClient) -> Result { + async fn fetch_new_token(&self) -> Result { #[derive(Deserialize)] struct Response { access_token: String, @@ -428,8 +413,26 @@ mod azure { } let now = Instant::now(); - let resp: Response = http.get_and_deserialize_json( - "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net/", + let url = reqwest::Url::parse_with_params( + "http://169.254.169.254/metadata/identity/oauth2/token", + &[ + ("api-version", "2018-02-01"), + ("resource", "https://vault.azure.net/"), + ], + ) + .map_err(|e| Error::internal(format!("invalid Azure IMDS URL: {}", e)))?; + #[cfg(test)] + let url = { + let mut url = url; + if let Some(th) = self.test_host { + url + .set_host(Some(th)) + .map_err(|e| Error::internal(format!("invalid test host: {}", e)))?; + } + url + }; + let resp: Response = self.http.get_and_deserialize_json( + url, &[("Metadata", "true"), ("Accept", "application/json")], ) .await @@ -442,4 +445,9 @@ mod azure { } } + #[derive(Debug)] + struct CachedAccessToken { + token_doc: RawDocumentBuf, + expire_time: Instant, + } } diff --git a/src/runtime/http.rs b/src/runtime/http.rs index c905b3db0..1f1b92633 100644 --- a/src/runtime/http.rs +++ b/src/runtime/http.rs @@ -1,4 +1,4 @@ -use reqwest::{Method, Response}; +use reqwest::{Method, Response, IntoUrl}; use serde::Deserialize; #[derive(Clone, Debug, Default)] @@ -10,7 +10,7 @@ impl HttpClient { /// Executes an HTTP GET request and deserializes the JSON response. pub(crate) async fn get_and_deserialize_json<'a, T>( &self, - uri: &str, + uri: impl IntoUrl, headers: impl IntoIterator, ) -> reqwest::Result where @@ -64,7 +64,7 @@ impl HttpClient { pub(crate) async fn request<'a>( &self, method: Method, - uri: &str, + uri: impl IntoUrl, headers: impl IntoIterator, ) -> reqwest::Result { let response = headers From 9ebcb115c2f8d2d145b6dde4538801d3e5c8a1b6 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Mon, 8 May 2023 11:08:07 -0400 Subject: [PATCH 08/15] passing case 1 --- src/client/csfle.rs | 2 +- src/client/csfle/state_machine.rs | 57 +++++++++++++++++-------------- src/test/csfle.rs | 22 +++++++++++- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/src/client/csfle.rs b/src/client/csfle.rs index 0c508d108..037267d04 100644 --- a/src/client/csfle.rs +++ b/src/client/csfle.rs @@ -1,7 +1,7 @@ pub(crate) mod client_builder; pub mod client_encryption; pub mod options; -mod state_machine; +pub(crate) mod state_machine; use std::{path::Path, time::Duration}; diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index 4f6b9db72..43d8e7463 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -366,7 +366,7 @@ fn raw_to_doc(raw: &RawDocument) -> Result { } #[cfg(feature = "azure-kms")] -mod azure { +pub(crate) mod azure { use bson::{RawDocumentBuf, rawdoc}; use serde::Deserialize; use tokio::sync::Mutex; @@ -375,15 +375,15 @@ mod azure { use crate::{error::{Result, Error}, runtime::HttpClient}; #[derive(Debug)] - pub(super) struct ExecutorState { + pub(crate) struct ExecutorState { cached_access_token: Mutex>, http: HttpClient, #[cfg(test)] - test_host: Option<&'static str>, + pub(crate) test_host: Option<(&'static str, u16)>, } impl ExecutorState { - pub(super) fn new() -> Self { + pub(crate) fn new() -> Self { Self { cached_access_token: Mutex::new(None), http: HttpClient::default(), @@ -392,59 +392,66 @@ mod azure { } } - pub(super) async fn get_token(&self) -> Result { + pub(crate) async fn get_token(&self) -> Result { let mut cached_token = self.cached_access_token.lock().await; if let Some(cached) = &*cached_token { if cached.expire_time.saturating_duration_since(Instant::now()) > Duration::from_secs(60) { return Ok(cached.token_doc.clone()); } } + let now = Instant::now(); let token = self.fetch_new_token().await?; - let out = token.token_doc.clone(); - *cached_token = Some(token); + let expires_in_secs: u64 = token.expires_in + .parse() + .map_err(|_| Error::invalid_authentication_response("azure imds"))?; + let cached = CachedAccessToken { + token_doc: rawdoc! { "accessToken": token.access_token }, + expire_time: now + Duration::from_secs(expires_in_secs), + }; + let out = cached.token_doc.clone(); + *cached_token = Some(cached); Ok(out) } - async fn fetch_new_token(&self) -> Result { - #[derive(Deserialize)] - struct Response { - access_token: String, - expires_in: u32, - } - - let now = Instant::now(); + pub(crate) async fn fetch_new_token(&self) -> Result { let url = reqwest::Url::parse_with_params( "http://169.254.169.254/metadata/identity/oauth2/token", &[ ("api-version", "2018-02-01"), - ("resource", "https://vault.azure.net/"), + ("resource", "https://vault.azure.net"), ], ) .map_err(|e| Error::internal(format!("invalid Azure IMDS URL: {}", e)))?; #[cfg(test)] let url = { let mut url = url; - if let Some(th) = self.test_host { + if let Some((host, port)) = self.test_host { url - .set_host(Some(th)) + .set_host(Some(host)) .map_err(|e| Error::internal(format!("invalid test host: {}", e)))?; + url + .set_port(Some(port)) + .map_err(|()| Error::internal(format!("invalid test port")))?; } url }; - let resp: Response = self.http.get_and_deserialize_json( + self.http.get_and_deserialize_json( url, &[("Metadata", "true"), ("Accept", "application/json")], ) .await - .map_err(|e| Error::authentication_error("azure imds", &format!("{}", e)))?; - - Ok(CachedAccessToken { - token_doc: rawdoc! { "accessToken": resp.access_token }, - expire_time: now + Duration::from_secs(resp.expires_in as u64), - }) + .map_err(|e| Error::authentication_error("azure imds", &format!("{}", e))) } } + #[derive(Debug, Deserialize)] + pub(crate) struct AccessToken { + pub(crate) access_token: String, + pub(crate) expires_in: String, + #[allow(unused)] + pub(crate) resource: String, + } + #[derive(Debug)] struct CachedAccessToken { token_doc: RawDocumentBuf, diff --git a/src/test/csfle.rs b/src/test/csfle.rs index 44162fab6..3af91a199 100644 --- a/src/test/csfle.rs +++ b/src/test/csfle.rs @@ -2858,7 +2858,27 @@ async fn on_demand_aws_success() -> Result<()> { // TODO RUST-1417: implement prose test 17. On-demand GCP Credentials -// TODO RUST-1442: implement prose test 18. Azure IMDS Credentials +// Prose test 18. Azure IMDS Credentials +#[cfg(feature = "azure-kms")] +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn azure_imds() -> Result<()> { + if !check_env("azure_imds", false) { + return Ok(()); + } + let _guard = LOCK.run_concurrently().await; + + let mut azure_exec = crate::client::csfle::state_machine::azure::ExecutorState::new(); + azure_exec.test_host = Some(("localhost", 8080)); + + // Case 1 + let token = azure_exec.fetch_new_token().await?; + assert_eq!(token.access_token, "magic-cookie"); + assert_eq!(token.expires_in, "70"); + assert_eq!(token.resource, "https://vault.azure.net"); + + Ok(()) +} // TODO RUST-1442: implement prose test 19. Azure IMDS Credentials Integration Test From 768cd63fdae0e0e153b93017bdf3afa27a856a22 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Mon, 8 May 2023 11:41:12 -0400 Subject: [PATCH 09/15] better test coverage --- src/client/csfle/state_machine.rs | 47 ++++++++++++++++++------------- src/test/csfle.rs | 11 +++++--- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index 43d8e7463..ab7d82538 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -399,21 +399,14 @@ pub(crate) mod azure { return Ok(cached.token_doc.clone()); } } - let now = Instant::now(); let token = self.fetch_new_token().await?; - let expires_in_secs: u64 = token.expires_in - .parse() - .map_err(|_| Error::invalid_authentication_response("azure imds"))?; - let cached = CachedAccessToken { - token_doc: rawdoc! { "accessToken": token.access_token }, - expire_time: now + Duration::from_secs(expires_in_secs), - }; - let out = cached.token_doc.clone(); - *cached_token = Some(cached); + let out = token.token_doc.clone(); + *cached_token = Some(token); Ok(out) } - pub(crate) async fn fetch_new_token(&self) -> Result { + async fn fetch_new_token(&self) -> Result { + let now = Instant::now(); let url = reqwest::Url::parse_with_params( "http://169.254.169.254/metadata/identity/oauth2/token", &[ @@ -435,26 +428,42 @@ pub(crate) mod azure { } url }; - self.http.get_and_deserialize_json( + let server_response: ServerResponse = self.http.get_and_deserialize_json( url, &[("Metadata", "true"), ("Accept", "application/json")], ) .await - .map_err(|e| Error::authentication_error("azure imds", &format!("{}", e))) + .map_err(|e| Error::authentication_error("azure imds", &format!("{}", e)))?; + let expires_in_secs: u64 = server_response.expires_in + .parse() + .map_err(|_| Error::invalid_authentication_response("azure imds"))?; + Ok(CachedAccessToken { + token_doc: rawdoc! { "accessToken": server_response.access_token.clone() }, + expire_time: now + Duration::from_secs(expires_in_secs), + #[cfg(test)] + server_response, + }) + } + + #[cfg(test)] + pub(crate) async fn cached(&self) -> Option { + self.cached_access_token.lock().await.clone() } } - #[derive(Debug, Deserialize)] - pub(crate) struct AccessToken { + #[derive(Debug, Deserialize, Clone)] + pub(crate) struct ServerResponse { pub(crate) access_token: String, pub(crate) expires_in: String, #[allow(unused)] pub(crate) resource: String, } - #[derive(Debug)] - struct CachedAccessToken { - token_doc: RawDocumentBuf, - expire_time: Instant, + #[derive(Debug, Clone)] + pub(crate) struct CachedAccessToken { + pub(crate) token_doc: RawDocumentBuf, + pub(crate) expire_time: Instant, + #[cfg(test)] + pub(crate) server_response: ServerResponse, } } diff --git a/src/test/csfle.rs b/src/test/csfle.rs index 3af91a199..fbcc59699 100644 --- a/src/test/csfle.rs +++ b/src/test/csfle.rs @@ -2872,10 +2872,13 @@ async fn azure_imds() -> Result<()> { azure_exec.test_host = Some(("localhost", 8080)); // Case 1 - let token = azure_exec.fetch_new_token().await?; - assert_eq!(token.access_token, "magic-cookie"); - assert_eq!(token.expires_in, "70"); - assert_eq!(token.resource, "https://vault.azure.net"); + let now = std::time::Instant::now(); + let token = azure_exec.get_token().await?; + assert_eq!(token, rawdoc! { "accessToken": "magic-cookie" }); + let cached = azure_exec.cached().await.expect("cached token"); + assert_eq!(cached.server_response.expires_in, "70"); + assert_eq!(cached.server_response.resource, "https://vault.azure.net"); + assert!((65..75).contains(&cached.expire_time.duration_since(now).as_secs())); Ok(()) } From 54b62ddcbf757a65ac08891a029e0d826ec9d7b8 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Mon, 8 May 2023 12:31:42 -0400 Subject: [PATCH 10/15] passing unit tests --- src/client/csfle/state_machine.rs | 75 +++++++++++++++++++++---------- src/runtime/http.rs | 10 +++++ src/test/csfle.rs | 60 +++++++++++++++++++++---- 3 files changed, 112 insertions(+), 33 deletions(-) diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index ab7d82538..eac32e60b 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -59,7 +59,7 @@ impl CryptExecutor { mongocryptd_client: None, metadata_client: None, #[cfg(feature = "azure-kms")] - azure: azure::ExecutorState::new(), + azure: azure::ExecutorState::new()?, }) } @@ -380,16 +380,21 @@ pub(crate) mod azure { http: HttpClient, #[cfg(test)] pub(crate) test_host: Option<(&'static str, u16)>, + #[cfg(test)] + pub(crate) test_param: Option<&'static str>, } impl ExecutorState { - pub(crate) fn new() -> Self { - Self { + pub(crate) fn new() -> Result { + const AZURE_IMDS_TIMEOUT: Duration = Duration::from_secs(10); + Ok(Self { cached_access_token: Mutex::new(None), - http: HttpClient::default(), + http: HttpClient::with_timeout(AZURE_IMDS_TIMEOUT)?, #[cfg(test)] test_host: None, - } + #[cfg(test)] + test_param: None, + }) } pub(crate) async fn get_token(&self) -> Result { @@ -407,6 +412,29 @@ pub(crate) mod azure { async fn fetch_new_token(&self) -> Result { let now = Instant::now(); + let server_response: ServerResponse = self.http.get_and_deserialize_json( + self.make_url()?, + &self.make_headers(), + ) + .await + .map_err(|e| Error::authentication_error("azure imds", &format!("{}", e)))?; + let expires_in_secs: u64 = server_response.expires_in + .parse() + .map_err(|e| { + Error::authentication_error( + "azure imds", + &format!("invalid `expires_in` response field: {}", e), + ) + })?; + Ok(CachedAccessToken { + token_doc: rawdoc! { "accessToken": server_response.access_token.clone() }, + expire_time: now + Duration::from_secs(expires_in_secs), + #[cfg(test)] + server_response, + }) + } + + fn make_url(&self) -> Result { let url = reqwest::Url::parse_with_params( "http://169.254.169.254/metadata/identity/oauth2/token", &[ @@ -428,30 +456,29 @@ pub(crate) mod azure { } url }; - let server_response: ServerResponse = self.http.get_and_deserialize_json( - url, - &[("Metadata", "true"), ("Accept", "application/json")], - ) - .await - .map_err(|e| Error::authentication_error("azure imds", &format!("{}", e)))?; - let expires_in_secs: u64 = server_response.expires_in - .parse() - .map_err(|_| Error::invalid_authentication_response("azure imds"))?; - Ok(CachedAccessToken { - token_doc: rawdoc! { "accessToken": server_response.access_token.clone() }, - expire_time: now + Duration::from_secs(expires_in_secs), - #[cfg(test)] - server_response, - }) + Ok(url) + } + + fn make_headers(&self) -> Vec<(&'static str, &'static str)> { + let headers = vec![("Metadata", "true"), ("Accept", "application/json")]; + #[cfg(test)] + let headers = { + let mut headers = headers; + if let Some(p) = self.test_param { + headers.push(("X-MongoDB-HTTP-TestParams", p)); + } + headers + }; + headers } #[cfg(test)] - pub(crate) async fn cached(&self) -> Option { - self.cached_access_token.lock().await.clone() + pub(crate) async fn take_cached(&self) -> Option { + self.cached_access_token.lock().await.take() } } - #[derive(Debug, Deserialize, Clone)] + #[derive(Debug, Deserialize)] pub(crate) struct ServerResponse { pub(crate) access_token: String, pub(crate) expires_in: String, @@ -459,7 +486,7 @@ pub(crate) mod azure { pub(crate) resource: String, } - #[derive(Debug, Clone)] + #[derive(Debug)] pub(crate) struct CachedAccessToken { pub(crate) token_doc: RawDocumentBuf, pub(crate) expire_time: Instant, diff --git a/src/runtime/http.rs b/src/runtime/http.rs index 1f1b92633..a614e3395 100644 --- a/src/runtime/http.rs +++ b/src/runtime/http.rs @@ -1,12 +1,22 @@ use reqwest::{Method, Response, IntoUrl}; use serde::Deserialize; +use crate::error::{Error, Result}; + #[derive(Clone, Debug, Default)] pub(crate) struct HttpClient { inner: reqwest::Client, } impl HttpClient { + pub(crate) fn with_timeout(timeout: std::time::Duration) -> Result { + let inner = reqwest::Client::builder() + .timeout(timeout) + .build() + .map_err(|e| Error::internal(format!("error initializing http client: {}", e)))?; + Ok(Self { inner }) + } + /// Executes an HTTP GET request and deserializes the JSON response. pub(crate) async fn get_and_deserialize_json<'a, T>( &self, diff --git a/src/test/csfle.rs b/src/test/csfle.rs index fbcc59699..65ec30c2b 100644 --- a/src/test/csfle.rs +++ b/src/test/csfle.rs @@ -2868,17 +2868,59 @@ async fn azure_imds() -> Result<()> { } let _guard = LOCK.run_concurrently().await; - let mut azure_exec = crate::client::csfle::state_machine::azure::ExecutorState::new(); + let mut azure_exec = crate::client::csfle::state_machine::azure::ExecutorState::new()?; azure_exec.test_host = Some(("localhost", 8080)); - // Case 1 - let now = std::time::Instant::now(); - let token = azure_exec.get_token().await?; - assert_eq!(token, rawdoc! { "accessToken": "magic-cookie" }); - let cached = azure_exec.cached().await.expect("cached token"); - assert_eq!(cached.server_response.expires_in, "70"); - assert_eq!(cached.server_response.resource, "https://vault.azure.net"); - assert!((65..75).contains(&cached.expire_time.duration_since(now).as_secs())); + // Case 1: Success + { + let now = std::time::Instant::now(); + let token = azure_exec.get_token().await?; + assert_eq!(token, rawdoc! { "accessToken": "magic-cookie" }); + let cached = azure_exec.take_cached().await.expect("cached token"); + assert_eq!(cached.server_response.expires_in, "70"); + assert_eq!(cached.server_response.resource, "https://vault.azure.net"); + assert!((65..75).contains(&cached.expire_time.duration_since(now).as_secs())); + } + + // Case 2: Empty JSON + { + azure_exec.test_param = Some("case=empty-json"); + let result = azure_exec.get_token().await; + assert!(result.is_err(), "expected err got {:?}", result); + assert!(result.unwrap_err().is_auth_error()); + } + + // Case 3: Bad JSON + { + azure_exec.test_param = Some("case=bad-json"); + let result = azure_exec.get_token().await; + assert!(result.is_err(), "expected err got {:?}", result); + assert!(result.unwrap_err().is_auth_error()); + } + + // Case 4: HTTP 404 + { + azure_exec.test_param = Some("case=404"); + let result = azure_exec.get_token().await; + assert!(result.is_err(), "expected err got {:?}", result); + assert!(result.unwrap_err().is_auth_error()); + } + + // Case 5: HTTP 500 + { + azure_exec.test_param = Some("case=500"); + let result = azure_exec.get_token().await; + assert!(result.is_err(), "expected err got {:?}", result); + assert!(result.unwrap_err().is_auth_error()); + } + + // Case 6: Slow Response + { + azure_exec.test_param = Some("case=slow"); + let result = azure_exec.get_token().await; + assert!(result.is_err(), "expected err got {:?}", result); + assert!(result.unwrap_err().is_auth_error()); + } Ok(()) } From e7e679ce168e171f9d1c61119b121ffcdb45b444 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Tue, 9 May 2023 11:55:08 -0400 Subject: [PATCH 11/15] evergreen --- .evergreen/config.yml | 14 ++++++++++++++ .evergreen/feature-combinations.sh | 2 +- .evergreen/run-csfle-mock-azure-imds.sh | 9 +++++++++ .evergreen/run-csfle-tests.sh | 2 +- src/test/csfle.rs | 2 +- 5 files changed, 26 insertions(+), 3 deletions(-) create mode 100755 .evergreen/run-csfle-mock-azure-imds.sh diff --git a/.evergreen/config.yml b/.evergreen/config.yml index b62faf279..e3ebeef59 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -91,6 +91,8 @@ functions: export TOPOLOGY=${TOPOLOGY} export MONGODB_VERSION=${MONGODB_VERSION} + export AZURE_IMDS_MOCK_PORT=44175 + if [ "Windows_NT" != "$OS" ]; then ulimit -n 64000 fi @@ -488,6 +490,16 @@ functions: export TLS_FEATURE=${TLS_FEATURE} .evergreen/run-csfle-kmip-servers.sh + "run mock azure imds server": + - command: shell.exec + params: + shell: bash + working_dir: "src" + background: true + script: | + ${PREPARE_SHELL} + .evergreen/run-csfle-mock-azure-imds.sh + "build csfle expansions": - command: shell.exec params: @@ -1214,6 +1226,7 @@ tasks: - func: "install junit dependencies" - func: "bootstrap mongo-orchestration" - func: "run kmip server" + - func: "run mock azure imds server" - func: "build csfle expansions" - func: "run csfle tests" @@ -1229,6 +1242,7 @@ tasks: - func: "install junit dependencies" - func: "install libmongocrypt" - func: "run kmip server" + - func: "run mock azure imds server" - func: "build csfle expansions" - func: "run csfle serverless tests" diff --git a/.evergreen/feature-combinations.sh b/.evergreen/feature-combinations.sh index 4ec6ec4e2..ac04e1f6b 100755 --- a/.evergreen/feature-combinations.sh +++ b/.evergreen/feature-combinations.sh @@ -5,7 +5,7 @@ export NO_FEATURES='' # async-std-related features that conflict with the library's default features. export ASYNC_STD_FEATURES='--no-default-features --features async-std-runtime,sync' # All additional features that do not conflict with the default features. New features added to the library should also be added to this list. -export ADDITIONAL_FEATURES='--features tokio-sync,zstd-compression,snappy-compression,zlib-compression,openssl-tls,aws-auth,tracing-unstable,in-use-encryption-unstable' +export ADDITIONAL_FEATURES='--features tokio-sync,zstd-compression,snappy-compression,zlib-compression,openssl-tls,aws-auth,tracing-unstable,in-use-encryption-unstable,azure-kms' # Array of feature combinations that, in total, provides complete coverage of the driver. diff --git a/.evergreen/run-csfle-mock-azure-imds.sh b/.evergreen/run-csfle-mock-azure-imds.sh new file mode 100755 index 000000000..691110672 --- /dev/null +++ b/.evergreen/run-csfle-mock-azure-imds.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +. ${DRIVERS_TOOLS}/.evergreen/find-python3.sh +PYTHON=$(find_python3) + +function prepend() { while read line; do echo "${1}${line}"; done; } + +cd ${DRIVERS_TOOLS}/.evergreen/csfle +${PYTHON} bottle.py fake_azure:imds -b localhost:${AZURE_IMDS_MOCK_PORT} 2>&1 | prepend "[MOCK AZURE IMDS] " \ No newline at end of file diff --git a/.evergreen/run-csfle-tests.sh b/.evergreen/run-csfle-tests.sh index ade986bd0..dab04d157 100755 --- a/.evergreen/run-csfle-tests.sh +++ b/.evergreen/run-csfle-tests.sh @@ -7,7 +7,7 @@ source ./.evergreen/env.sh set -o xtrace -FEATURE_FLAGS="in-use-encryption-unstable,aws-auth,${TLS_FEATURE}" +FEATURE_FLAGS="in-use-encryption-unstable,aws-auth,azure-kms,${TLS_FEATURE}" OPTIONS="-- -Z unstable-options --format json --report-time" if [ "$SINGLE_THREAD" = true ]; then diff --git a/src/test/csfle.rs b/src/test/csfle.rs index 65ec30c2b..e766af94e 100644 --- a/src/test/csfle.rs +++ b/src/test/csfle.rs @@ -2869,7 +2869,7 @@ async fn azure_imds() -> Result<()> { let _guard = LOCK.run_concurrently().await; let mut azure_exec = crate::client::csfle::state_machine::azure::ExecutorState::new()?; - azure_exec.test_host = Some(("localhost", 8080)); + azure_exec.test_host = Some(("localhost", std::env::var("AZURE_IMDS_MOCK_PORT").unwrap().parse().unwrap())); // Case 1: Success { From 434150c78d9dfc461c085fb034662c44a5d003c9 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Tue, 9 May 2023 11:57:28 -0400 Subject: [PATCH 12/15] fmt --- src/client/auth/mod.rs | 11 +++++--- src/client/csfle/state_machine.rs | 44 +++++++++++++++---------------- src/runtime/http.rs | 2 +- src/test/csfle.rs | 8 +++++- 4 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/client/auth/mod.rs b/src/client/auth/mod.rs index 4a7a5d03b..edf9a0959 100644 --- a/src/client/auth/mod.rs +++ b/src/client/auth/mod.rs @@ -399,8 +399,7 @@ impl Credential { conn: &mut Connection, server_api: Option<&ServerApi>, first_round: Option, - #[cfg(feature = "aws-auth")] - http_client: &crate::runtime::HttpClient, + #[cfg(feature = "aws-auth")] http_client: &crate::runtime::HttpClient, ) -> Result<()> { let stream_description = conn.stream_description()?; @@ -431,7 +430,13 @@ impl Credential { // Authenticate according to the chosen mechanism. mechanism - .authenticate_stream(conn, self, server_api, #[cfg(feature = "aws-auth")] http_client) + .authenticate_stream( + conn, + self, + server_api, + #[cfg(feature = "aws-auth")] + http_client, + ) .await } diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index eac32e60b..8665f97e9 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -367,12 +367,15 @@ fn raw_to_doc(raw: &RawDocument) -> Result { #[cfg(feature = "azure-kms")] pub(crate) mod azure { - use bson::{RawDocumentBuf, rawdoc}; + use bson::{rawdoc, RawDocumentBuf}; use serde::Deserialize; + use std::time::{Duration, Instant}; use tokio::sync::Mutex; - use std::time::{Instant, Duration}; - use crate::{error::{Result, Error}, runtime::HttpClient}; + use crate::{ + error::{Error, Result}, + runtime::HttpClient, + }; #[derive(Debug)] pub(crate) struct ExecutorState { @@ -400,7 +403,9 @@ pub(crate) mod azure { pub(crate) async fn get_token(&self) -> Result { let mut cached_token = self.cached_access_token.lock().await; if let Some(cached) = &*cached_token { - if cached.expire_time.saturating_duration_since(Instant::now()) > Duration::from_secs(60) { + if cached.expire_time.saturating_duration_since(Instant::now()) + > Duration::from_secs(60) + { return Ok(cached.token_doc.clone()); } } @@ -412,20 +417,17 @@ pub(crate) mod azure { async fn fetch_new_token(&self) -> Result { let now = Instant::now(); - let server_response: ServerResponse = self.http.get_and_deserialize_json( - self.make_url()?, - &self.make_headers(), - ) - .await - .map_err(|e| Error::authentication_error("azure imds", &format!("{}", e)))?; - let expires_in_secs: u64 = server_response.expires_in - .parse() - .map_err(|e| { - Error::authentication_error( - "azure imds", - &format!("invalid `expires_in` response field: {}", e), - ) - })?; + let server_response: ServerResponse = self + .http + .get_and_deserialize_json(self.make_url()?, &self.make_headers()) + .await + .map_err(|e| Error::authentication_error("azure imds", &format!("{}", e)))?; + let expires_in_secs: u64 = server_response.expires_in.parse().map_err(|e| { + Error::authentication_error( + "azure imds", + &format!("invalid `expires_in` response field: {}", e), + ) + })?; Ok(CachedAccessToken { token_doc: rawdoc! { "accessToken": server_response.access_token.clone() }, expire_time: now + Duration::from_secs(expires_in_secs), @@ -447,11 +449,9 @@ pub(crate) mod azure { let url = { let mut url = url; if let Some((host, port)) = self.test_host { - url - .set_host(Some(host)) + url.set_host(Some(host)) .map_err(|e| Error::internal(format!("invalid test host: {}", e)))?; - url - .set_port(Some(port)) + url.set_port(Some(port)) .map_err(|()| Error::internal(format!("invalid test port")))?; } url diff --git a/src/runtime/http.rs b/src/runtime/http.rs index a614e3395..912aa4bd1 100644 --- a/src/runtime/http.rs +++ b/src/runtime/http.rs @@ -1,4 +1,4 @@ -use reqwest::{Method, Response, IntoUrl}; +use reqwest::{IntoUrl, Method, Response}; use serde::Deserialize; use crate::error::{Error, Result}; diff --git a/src/test/csfle.rs b/src/test/csfle.rs index e766af94e..dbb9f0379 100644 --- a/src/test/csfle.rs +++ b/src/test/csfle.rs @@ -2869,7 +2869,13 @@ async fn azure_imds() -> Result<()> { let _guard = LOCK.run_concurrently().await; let mut azure_exec = crate::client::csfle::state_machine::azure::ExecutorState::new()?; - azure_exec.test_host = Some(("localhost", std::env::var("AZURE_IMDS_MOCK_PORT").unwrap().parse().unwrap())); + azure_exec.test_host = Some(( + "localhost", + std::env::var("AZURE_IMDS_MOCK_PORT") + .unwrap() + .parse() + .unwrap(), + )); // Case 1: Success { From 6ebfe557de9c89ba7d7dcc397a9a2a8d49e67b5e Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Tue, 9 May 2023 12:03:29 -0400 Subject: [PATCH 13/15] more fmt --- src/cmap/establish/handshake/test.rs | 22 ++++++++-------------- src/cmap/test/integration.rs | 18 ++++++------------ src/cmap/test/mod.rs | 6 +++--- src/sdam/topology.rs | 5 ++--- 4 files changed, 19 insertions(+), 32 deletions(-) diff --git a/src/cmap/establish/handshake/test.rs b/src/cmap/establish/handshake/test.rs index 568cf51e7..127b9fdf5 100644 --- a/src/cmap/establish/handshake/test.rs +++ b/src/cmap/establish/handshake/test.rs @@ -1,21 +1,15 @@ use super::Handshaker; -use crate::{ - bson::doc, - cmap::establish::handshake::HandshakerOptions, - options::DriverInfo, -}; +use crate::{bson::doc, cmap::establish::handshake::HandshakerOptions, options::DriverInfo}; #[test] fn metadata_no_options() { - let handshaker = Handshaker::new( - HandshakerOptions { - app_name: None, - compressors: None, - driver_info: None, - server_api: None, - load_balanced: false, - }, - ); + let handshaker = Handshaker::new(HandshakerOptions { + app_name: None, + compressors: None, + driver_info: None, + server_api: None, + load_balanced: false, + }); let metadata = handshaker.command.body.get_document("client").unwrap(); assert!(!metadata.contains_key("application")); diff --git a/src/cmap/test/integration.rs b/src/cmap/test/integration.rs index bf5ac21e1..ecfc24c6e 100644 --- a/src/cmap/test/integration.rs +++ b/src/cmap/test/integration.rs @@ -50,10 +50,8 @@ async fn acquire_connection_and_send_command() { let pool = ConnectionPool::new( client_options.hosts[0].clone(), - ConnectionEstablisher::new( - EstablisherOptions::from_client_options(&client_options), - ) - .unwrap(), + ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options)) + .unwrap(), TopologyUpdater::channel().0, bson::oid::ObjectId::new(), Some(pool_options), @@ -131,10 +129,8 @@ async fn concurrent_connections() { let pool = ConnectionPool::new( CLIENT_OPTIONS.get().await.hosts[0].clone(), - ConnectionEstablisher::new( - EstablisherOptions::from_client_options(&client_options), - ) - .unwrap(), + ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options)) + .unwrap(), TopologyUpdater::channel().0, bson::oid::ObjectId::new(), Some(options), @@ -223,10 +219,8 @@ async fn connection_error_during_establishment() { Some(handler.clone() as Arc); let pool = ConnectionPool::new( client_options.hosts[0].clone(), - ConnectionEstablisher::new( - EstablisherOptions::from_client_options(&client_options), - ) - .unwrap(), + ConnectionEstablisher::new(EstablisherOptions::from_client_options(&client_options)) + .unwrap(), TopologyUpdater::channel().0, bson::oid::ObjectId::new(), Some(options), diff --git a/src/cmap/test/mod.rs b/src/cmap/test/mod.rs index 38b109719..8772ac739 100644 --- a/src/cmap/test/mod.rs +++ b/src/cmap/test/mod.rs @@ -159,9 +159,9 @@ impl Executor { let pool = ConnectionPool::new( CLIENT_OPTIONS.get().await.hosts[0].clone(), - ConnectionEstablisher::new( - EstablisherOptions::from_client_options(CLIENT_OPTIONS.get().await), - ) + ConnectionEstablisher::new(EstablisherOptions::from_client_options( + CLIENT_OPTIONS.get().await, + )) .unwrap(), updater, bson::oid::ObjectId::new(), diff --git a/src/sdam/topology.rs b/src/sdam/topology.rs index f6cb4ac0e..1f061da26 100644 --- a/src/sdam/topology.rs +++ b/src/sdam/topology.rs @@ -92,9 +92,8 @@ impl Topology { }; let (watcher, publisher) = TopologyWatcher::channel(state); - let connection_establisher = ConnectionEstablisher::new( - EstablisherOptions::from_client_options(&options), - )?; + let connection_establisher = + ConnectionEstablisher::new(EstablisherOptions::from_client_options(&options))?; let id = ObjectId::new(); From 342a09e2a6dda5e4800df135900cc0789fb84050 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Tue, 9 May 2023 14:02:13 -0400 Subject: [PATCH 14/15] clippy --- src/client/csfle/state_machine.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index 8665f97e9..b0e1fd397 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -428,6 +428,7 @@ pub(crate) mod azure { &format!("invalid `expires_in` response field: {}", e), ) })?; + #[allow(clippy::redundant_clone)] Ok(CachedAccessToken { token_doc: rawdoc! { "accessToken": server_response.access_token.clone() }, expire_time: now + Duration::from_secs(expires_in_secs), @@ -452,7 +453,7 @@ pub(crate) mod azure { url.set_host(Some(host)) .map_err(|e| Error::internal(format!("invalid test host: {}", e)))?; url.set_port(Some(port)) - .map_err(|()| Error::internal(format!("invalid test port")))?; + .map_err(|()| Error::internal(format!("invalid test port {}", port)))?; } url }; From 93603882acf3addf2c1be6cc9f6ccaa78be91b7b Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Fri, 12 May 2023 11:23:53 -0400 Subject: [PATCH 15/15] update feature doc --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index 517d39fc6..e97439fad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,7 @@ bson-uuid-1 = ["bson/uuid-1"] aws-auth = ["reqwest"] # Enable support for on-demand Azure KMS credentials. +# This can only be used with the tokio-runtime feature flag. azure-kms = ["reqwest"] zstd-compression = ["zstd"]