diff --git a/src/client/csfle/state_machine.rs b/src/client/csfle/state_machine.rs index 39d965631..9ffaf48a9 100644 --- a/src/client/csfle/state_machine.rs +++ b/src/client/csfle/state_machine.rs @@ -14,12 +14,12 @@ use tokio::{ }; use crate::{ - client::{auth::Credential, options::ServerAddress, WeakClient}, + client::{options::ServerAddress, WeakClient}, coll::options::FindOptions, error::{Error, Result}, operation::{RawOutput, RunCommand}, options::ReadConcern, - runtime::{AsyncStream, HttpClient, Process, TlsConfig}, + runtime::{AsyncStream, Process, TlsConfig}, Client, Namespace, }; @@ -209,6 +209,7 @@ impl CryptExecutor { } State::NeedKmsCredentials => { let ctx = result_mut(&mut ctx)?; + #[allow(unused_mut)] let mut out = rawdoc! {}; if self .kms_providers @@ -219,8 +220,8 @@ impl CryptExecutor { #[cfg(feature = "aws-auth")] { let aws_creds = crate::client::auth::aws::AwsCredential::get( - &Credential::default(), - &HttpClient::default(), + &crate::client::auth::Credential::default(), + &crate::runtime::HttpClient::default(), ) .await?; let mut creds = rawdoc! { diff --git a/src/db/mod.rs b/src/db/mod.rs index 6635b5dd5..a39cacd68 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -6,6 +6,8 @@ use std::{fmt::Debug, sync::Arc}; use bson::doc; use futures_util::stream::TryStreamExt; +#[cfg(feature = "in-use-encryption-unstable")] +use crate::client_encryption::{ClientEncryption, MasterKey}; use crate::{ bson::{Bson, Document}, change_stream::{ @@ -421,6 +423,56 @@ impl Database { self.create_collection_common(name, options, session).await } + /// Creates a new collection with encrypted fields, automatically creating new data encryption + /// keys when needed based on the configured [`CreateCollectionOptions::encrypted_fields`]. + /// + /// Returns the potentially updated `encrypted_fields` along with status, as keys may have been + /// created even when a failure occurs. + /// + /// Does not affect any auto encryption settings on existing MongoClients that are already + /// configured with auto encryption. + #[cfg(feature = "in-use-encryption-unstable")] + pub async fn create_encrypted_collection( + &self, + ce: &ClientEncryption, + name: impl AsRef, + options: impl Into>, + master_key: MasterKey, + ) -> (Document, Result<()>) { + let options: Option = options.into(); + let ef = match options.as_ref().and_then(|o| o.encrypted_fields.as_ref()) { + Some(ef) => ef, + None => { + return ( + doc! {}, + Err(Error::invalid_argument( + "no encrypted_fields defined for collection", + )), + ); + } + }; + let mut ef_prime = ef.clone(); + if let Ok(fields) = ef_prime.get_array_mut("fields") { + for f in fields { + let f_doc = if let Some(d) = f.as_document_mut() { + d + } else { + continue; + }; + if f_doc.get("keyId") == Some(&Bson::Null) { + let d = match ce.create_data_key(master_key.clone()).run().await { + Ok(v) => v, + Err(e) => return (ef_prime, Err(e)), + }; + f_doc.insert("keyId", d); + } + } + } + let mut opts_prime = options.unwrap().clone(); // safe unwrap: no options would be caught by the encrypted_fields check + opts_prime.encrypted_fields = Some(ef_prime.clone()); + (ef_prime, self.create_collection(name, opts_prime).await) + } + pub(crate) async fn run_command_common( &self, command: Document, diff --git a/src/error.rs b/src/error.rs index 1f50c2b4c..068900755 100644 --- a/src/error.rs +++ b/src/error.rs @@ -158,7 +158,7 @@ impl Error { } pub(crate) fn is_max_time_ms_expired_error(&self) -> bool { - self.code() == Some(50) + self.sdam_code() == Some(50) } /// Whether a read operation should be retried if this error occurs. @@ -166,7 +166,7 @@ impl Error { if self.is_network_error() { return true; } - match self.code() { + match self.sdam_code() { Some(code) => RETRYABLE_READ_CODES.contains(&code), None => false, } @@ -187,7 +187,7 @@ impl Error { if self.is_network_error() { return true; } - match &self.code() { + match &self.sdam_code() { Some(code) => RETRYABLE_WRITE_CODES.contains(code), None => false, } @@ -201,7 +201,7 @@ impl Error { { return true; } - match self.code() { + match self.sdam_code() { Some(code) => UNKNOWN_TRANSACTION_COMMIT_RESULT_LABEL_CODES.contains(&code), None => false, } @@ -259,7 +259,7 @@ impl Error { /// Gets the code from this error for performing SDAM updates, if applicable. /// Any codes contained in WriteErrors are ignored. - pub(crate) fn code(&self) -> Option { + pub(crate) fn sdam_code(&self) -> Option { match self.kind.as_ref() { ErrorKind::Command(command_error) => Some(command_error.code), // According to SDAM spec, write concern error codes MUST also be checked, and @@ -271,7 +271,22 @@ impl Error { ErrorKind::Write(WriteFailure::WriteConcernError(wc_error)) => Some(wc_error.code), _ => None, } - .or_else(|| self.source.as_ref().and_then(|s| s.code())) + .or_else(|| self.source.as_ref().and_then(|s| s.sdam_code())) + } + + /// Gets the code from this error. + #[allow(unused)] + pub(crate) fn code(&self) -> Option { + match self.kind.as_ref() { + ErrorKind::Command(command_error) => Some(command_error.code), + ErrorKind::BulkWrite(BulkWriteFailure { + write_concern_error: Some(wc_error), + .. + }) => Some(wc_error.code), + ErrorKind::Write(e) => Some(e.code()), + _ => None, + } + .or_else(|| self.source.as_ref().and_then(|s| s.sdam_code())) } /// Gets the message for this error, if applicable, for use in testing. @@ -333,21 +348,21 @@ impl Error { /// If this error corresponds to a "not writable primary" error as per the SDAM spec. pub(crate) fn is_notwritableprimary(&self) -> bool { - self.code() + self.sdam_code() .map(|code| NOTWRITABLEPRIMARY_CODES.contains(&code)) .unwrap_or(false) } /// If this error corresponds to a "node is recovering" error as per the SDAM spec. pub(crate) fn is_recovering(&self) -> bool { - self.code() + self.sdam_code() .map(|code| RECOVERING_CODES.contains(&code)) .unwrap_or(false) } /// If this error corresponds to a "node is shutting down" error as per the SDAM spec. pub(crate) fn is_shutting_down(&self) -> bool { - self.code() + self.sdam_code() .map(|code| SHUTTING_DOWN_CODES.contains(&code)) .unwrap_or(false) } @@ -361,7 +376,7 @@ impl Error { if !self.is_server_error() { return true; } - let code = self.code(); + let code = self.sdam_code(); if code == Some(43) { return true; } @@ -388,6 +403,11 @@ impl Error { matches!(self.kind.as_ref(), ErrorKind::IncompatibleServer { .. }) } + #[allow(unused)] + pub(crate) fn is_invalid_argument(&self) -> bool { + matches!(self.kind.as_ref(), ErrorKind::InvalidArgument { .. }) + } + pub(crate) fn with_source>>(mut self, source: E) -> Self { self.source = source.into().map(Box::new); self @@ -825,6 +845,13 @@ impl WriteFailure { .into()) } } + + pub(crate) fn code(&self) -> i32 { + match self { + Self::WriteConcernError(e) => e.code, + Self::WriteError(e) => e.code, + } + } } /// An error that occurred during a GridFS operation. diff --git a/src/gridfs/upload.rs b/src/gridfs/upload.rs index bca9ea5f5..303114c73 100644 --- a/src/gridfs/upload.rs +++ b/src/gridfs/upload.rs @@ -180,7 +180,7 @@ impl GridFsBucket { .build(); // Ignore NamespaceExists errors if the collection has already been created. if let Err(error) = self.inner.db.create_collection(coll.name(), options).await { - if error.code() != Some(48) { + if error.sdam_code() != Some(48) { return Err(error); } } diff --git a/src/sync/db.rs b/src/sync/db.rs index c97a03fe7..34e0163e0 100644 --- a/src/sync/db.rs +++ b/src/sync/db.rs @@ -219,6 +219,28 @@ impl Database { )) } + /// Creates a new collection with encrypted fields, automatically creating new data encryption + /// keys when needed based on the configured [`CreateCollectionOptions::encrypted_fields`]. + /// + /// Returns the potentially updated `encrypted_fields` along with status, as keys may have been + /// created even when a failure occurs. + /// + /// Does not affect any auto encryption settings on existing MongoClients that are already + /// configured with auto encryption. + #[cfg(feature = "in-use-encryption-unstable")] + pub fn create_encrypted_collection( + &self, + ce: &crate::client_encryption::ClientEncryption, + name: impl AsRef, + options: impl Into>, + master_key: crate::client_encryption::MasterKey, + ) -> (Document, Result<()>) { + runtime::block_on( + self.async_database + .create_encrypted_collection(ce, name, options, master_key), + ) + } + /// Runs a database-level command. /// /// Note that no inspection is done on `doc`, so the command will not use the database's default diff --git a/src/test/csfle.rs b/src/test/csfle.rs index eea1b7520..3b7d94d76 100644 --- a/src/test/csfle.rs +++ b/src/test/csfle.rs @@ -2485,7 +2485,7 @@ async fn unique_index_keyaltnames_add_key_alt_name() -> Result<()> { // `Error::code` skips write errors per the SDAM spec, but we need those. fn write_err_code(err: &crate::error::Error) -> Option { - if let Some(code) = err.code() { + if let Some(code) = err.sdam_code() { return Some(code); } match *err.kind { @@ -2555,7 +2555,7 @@ async fn decryption_events_command_error() -> Result<()> { .aggregate(vec![doc! { "$count": "total" }], None) .await .unwrap_err(); - assert_eq!(Some(123), err.code()); + assert_eq!(Some(123), err.sdam_code()); assert!(td.ev_handler.failed.lock().unwrap().is_some()); Ok(()) @@ -2881,6 +2881,138 @@ async fn bypass_mongocryptd_client() -> Result<()> { Ok(()) } +// Prost test 21. Automatic Data Encryption Keys +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn auto_encryption_keys_local() -> Result<()> { + auto_encryption_keys(MasterKey::Local).await +} + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn auto_encryption_keys_aws() -> Result<()> { + auto_encryption_keys(MasterKey::Aws { + region: "us-east-1".to_string(), + key: "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0" + .to_string(), + endpoint: None, + }) + .await +} + +async fn auto_encryption_keys(master_key: MasterKey) -> Result<()> { + if !check_env("custom_key_material", false) { + return Ok(()); + } + let _guard = LOCK.run_exclusively().await; + + let client = Client::test_builder().build().await; + if client.server_version_lt(6, 0) { + log_uncaptured("Skipping auto_encryption_key test: server < 6.0"); + return Ok(()); + } + if client.is_standalone() { + log_uncaptured("Skipping auto_encryption_key test: standalone server"); + return Ok(()); + } + let db = client.database("test_auto_encryption_keys"); + db.drop(None).await?; + let ce = ClientEncryption::new( + client.into_client(), + KV_NAMESPACE.clone(), + KMS_PROVIDERS + .iter() + .filter(|(p, ..)| p == &KmsProvider::Local || p == &KmsProvider::Aws) + .cloned() + .collect::>(), + )?; + + // Case 1: Simple Creation and Validation + let opts = CreateCollectionOptions::builder() + .encrypted_fields(doc! { + "fields": [{ + "path": "ssn", + "bsonType": "string", + "keyId": Bson::Null, + }], + }) + .build(); + db.create_encrypted_collection(&ce, "case_1", opts, master_key.clone()) + .await + .1?; + let coll = db.collection::("case_1"); + let result = coll.insert_one(doc! { "ssn": "123-45-6789" }, None).await; + assert!( + result.as_ref().unwrap_err().code() == Some(121), + "Expected error 121 (failed validation), got {:?}", + result + ); + + // Case 2: Missing encryptedFields + let result = db + .create_encrypted_collection(&ce, "case_2", None, master_key.clone()) + .await + .1; + assert!( + result.as_ref().unwrap_err().is_invalid_argument(), + "Expected invalid argument error, got {:?}", + result + ); + + // Case 3: Invalid keyId + let opts = CreateCollectionOptions::builder() + .encrypted_fields(doc! { + "fields": [{ + "path": "ssn", + "bsonType": "string", + "keyId": false, + }], + }) + .build(); + let result = db + .create_encrypted_collection(&ce, "case_1", opts, master_key.clone()) + .await + .1; + assert!( + result.as_ref().unwrap_err().code() == Some(14), + "Expected error 14 (type mismatch), got {:?}", + result + ); + + // Case 4: Insert encrypted value + let opts = CreateCollectionOptions::builder() + .encrypted_fields(doc! { + "fields": [{ + "path": "ssn", + "bsonType": "string", + "keyId": Bson::Null, + }], + }) + .build(); + let (ef, result) = db + .create_encrypted_collection(&ce, "case_4", opts, master_key.clone()) + .await; + result?; + let key = match ef.get_array("fields")?[0] + .as_document() + .unwrap() + .get("keyId") + .unwrap() + { + Bson::Binary(bin) => bin.clone(), + v => panic!("invalid keyId {:?}", v), + }; + let encrypted_payload = ce + .encrypt("123-45-6789", key, Algorithm::Unindexed) + .run() + .await?; + let coll = db.collection::("case_1"); + coll.insert_one(doc! { "ssn": encrypted_payload }, None) + .await?; + + Ok(()) +} + // Prose test 22. Range explicit encryption #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] diff --git a/src/test/spec/gridfs.rs b/src/test/spec/gridfs.rs index 586834120..10bfe7311 100644 --- a/src/test/spec/gridfs.rs +++ b/src/test/spec/gridfs.rs @@ -236,7 +236,7 @@ async fn upload_stream_errors() { .unwrap(); let error = get_mongo_error(upload_stream.write_all(&[11]).await); - assert_eq!(error.code(), Some(1234)); + assert_eq!(error.sdam_code(), Some(1234)); assert_closed(&bucket, upload_stream).await; @@ -258,7 +258,7 @@ async fn upload_stream_errors() { .unwrap(); let error = get_mongo_error(upload_stream.close().await); - assert_eq!(error.code(), Some(1234)); + assert_eq!(error.sdam_code(), Some(1234)); assert_closed(&bucket, upload_stream).await; } diff --git a/src/test/spec/transactions.rs b/src/test/spec/transactions.rs index 230e57223..08912dc92 100644 --- a/src/test/spec/transactions.rs +++ b/src/test/spec/transactions.rs @@ -255,7 +255,7 @@ async fn convenient_api_retry_timeout_commit_unknown() { .await; let err = result.unwrap_err(); - assert_eq!(Some(251), err.code()); + assert_eq!(Some(251), err.sdam_code()); } #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] diff --git a/src/test/spec/unified_runner/operation.rs b/src/test/spec/unified_runner/operation.rs index 06ac8fc9e..429213463 100644 --- a/src/test/spec/unified_runner/operation.rs +++ b/src/test/spec/unified_runner/operation.rs @@ -2065,7 +2065,7 @@ impl TestOperation for AssertIndexNotExists { match coll.list_index_names().await { Ok(indexes) => assert!(!indexes.contains(&self.index_name)), // a namespace not found error indicates that the index does not exist - Err(err) => assert_eq!(err.code(), Some(26)), + Err(err) => assert_eq!(err.sdam_code(), Some(26)), } } .boxed() diff --git a/src/test/spec/unified_runner/test_file.rs b/src/test/spec/unified_runner/test_file.rs index 0379aa166..ea18353af 100644 --- a/src/test/spec/unified_runner/test_file.rs +++ b/src/test/spec/unified_runner/test_file.rs @@ -490,7 +490,7 @@ impl ExpectError { } } if let Some(error_code) = self.error_code { - match &error.code() { + match &error.sdam_code() { Some(code) => { if code != &error_code { return Err(format!( diff --git a/src/test/spec/v2_runner/mod.rs b/src/test/spec/v2_runner/mod.rs index 6b3dfdf21..d527c00d2 100644 --- a/src/test/spec/v2_runner/mod.rs +++ b/src/test/spec/v2_runner/mod.rs @@ -467,7 +467,7 @@ async fn run_v2_test(path: std::path::PathBuf, test_file: TestFile) { .await { Ok(_) => {} - Err(err) => match err.code() { + Err(err) => match err.sdam_code() { Some(11601) => {} _ => panic!("{}: killAllSessions failed", test.description), }, diff --git a/src/test/spec/v2_runner/operation.rs b/src/test/spec/v2_runner/operation.rs index 1a7d2564f..e5a18cdce 100644 --- a/src/test/spec/v2_runner/operation.rs +++ b/src/test/spec/v2_runner/operation.rs @@ -137,7 +137,7 @@ impl Operation { ); } if let Some(error_code) = operation_error.error_code { - let code = error.code().unwrap(); + let code = error.sdam_code().unwrap(); assert_eq!(error_code, code); } if let Some(error_labels_contain) = &operation_error.error_labels_contain { @@ -1564,7 +1564,7 @@ impl TestOperation for AssertIndexNotExists { match coll.list_index_names().await { Ok(indexes) => assert!(!indexes.contains(&self.index)), // a namespace not found error indicates that the index does not exist - Err(err) => assert_eq!(err.code(), Some(26)), + Err(err) => assert_eq!(err.sdam_code(), Some(26)), } Ok(None) }