From 18275c2acf2ee7b2459f45e8c9ad0925eb48c913 Mon Sep 17 00:00:00 2001 From: Isabel Atkinson Date: Tue, 8 Apr 2025 15:07:30 -0600 Subject: [PATCH 1/8] perform getMore manually --- src/client/executor.rs | 4 +- src/operation/bulk_write.rs | 158 +++++++++++++++++++++++------------- 2 files changed, 104 insertions(+), 58 deletions(-) diff --git a/src/client/executor.rs b/src/client/executor.rs index 937d78ba9..08408c70d 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -189,7 +189,7 @@ impl Client { Ok(SessionCursor::new(self.clone(), details.output, pinned)) } - fn is_load_balanced(&self) -> bool { + pub(crate) fn is_load_balanced(&self) -> bool { self.inner.options.load_balanced.unwrap_or(false) } @@ -484,7 +484,7 @@ impl Client { } /// Executes an operation on a given connection, optionally using a provided session. - async fn execute_operation_on_connection( + pub(crate) async fn execute_operation_on_connection( &self, op: &mut T, connection: &mut PooledConnection, diff --git a/src/operation/bulk_write.rs b/src/operation/bulk_write.rs index dee2ad412..3c68265dd 100644 --- a/src/operation/bulk_write.rs +++ b/src/operation/bulk_write.rs @@ -12,7 +12,7 @@ use crate::{ cmap::{Command, RawCommandResponse, StreamDescription}, cursor::CursorSpecification, error::{BulkWriteError, Error, ErrorKind, Result}, - operation::OperationWithDefaults, + operation::{GetMore, OperationWithDefaults}, options::{BulkWriteOptions, OperationType, WriteModel}, results::{BulkWriteResult, DeleteResult, InsertOneResult, UpdateResult}, BoxFuture, @@ -75,49 +75,94 @@ where error: &mut BulkWriteError, ) -> Result<()> { while let Some(response) = stream.try_next().await? { - let index = response.index + self.offset; - match response.result { - SingleOperationResult::Success { - n, - n_modified, - upserted, - } => { - let model = self.get_model(response.index)?; - match model.operation_type() { - OperationType::Insert => { - let inserted_id = self.get_inserted_id(index)?; - let insert_result = InsertOneResult { inserted_id }; - result.add_insert_result(index, insert_result); - } - OperationType::Update => { - let modified_count = - n_modified.ok_or_else(|| ErrorKind::InvalidResponse { - message: "nModified value not returned for update bulkWrite \ - operation" - .into(), - })?; - let update_result = UpdateResult { - matched_count: n, - modified_count, - upserted_id: upserted.map(|upserted| upserted.id), - }; - result.add_update_result(index, update_result); - } - OperationType::Delete => { - let delete_result = DeleteResult { deleted_count: n }; - result.add_delete_result(index, delete_result); - } + self.handle_individual_response(response, result, error)?; + } + Ok(()) + } + + fn handle_individual_response( + &self, + response: SingleOperationResponse, + result: &mut impl BulkWriteResult, + error: &mut BulkWriteError, + ) -> Result<()> { + let index = response.index + self.offset; + match response.result { + SingleOperationResult::Success { + n, + n_modified, + upserted, + } => { + let model = self.get_model(response.index)?; + match model.operation_type() { + OperationType::Insert => { + let inserted_id = self.get_inserted_id(index)?; + let insert_result = InsertOneResult { inserted_id }; + result.add_insert_result(index, insert_result); + } + OperationType::Update => { + let modified_count = + n_modified.ok_or_else(|| ErrorKind::InvalidResponse { + message: "nModified value not returned for update bulkWrite \ + operation" + .into(), + })?; + let update_result = UpdateResult { + matched_count: n, + 
modified_count, + upserted_id: upserted.map(|upserted| upserted.id), + }; + result.add_update_result(index, update_result); + } + OperationType::Delete => { + let delete_result = DeleteResult { deleted_count: n }; + result.add_delete_result(index, delete_result); } - } - SingleOperationResult::Error(write_error) => { - error.write_errors.insert(index, write_error); } } + SingleOperationResult::Error(write_error) => { + error.write_errors.insert(index, write_error); + } } - Ok(()) } + async fn do_get_mores<'b>( + &self, + context: &mut ExecutionContext<'b>, + cursor_specification: CursorSpecification, + result: &mut impl BulkWriteResult, + error: &mut BulkWriteError, + ) -> Result<()> { + let mut responses = cursor_specification.initial_buffer; + let mut more_responses = cursor_specification.info.id != 0; + loop { + for response_document in &responses { + let response: SingleOperationResponse = + bson::from_slice(response_document.as_bytes())?; + self.handle_individual_response(response, result, error)?; + } + + if !more_responses { + return Ok(()); + } + + let mut get_more = GetMore::new(cursor_specification.info.clone(), None); + let get_more_response = self + .client + .execute_operation_on_connection( + &mut get_more, + context.connection, + &mut context.session, + None, + Retryability::None, + ) + .await?; + responses = get_more_response.batch; + more_responses = get_more_response.id != 0; + } + } + fn get_model(&self, index: usize) -> Result<&WriteModel> { self.models.get(index).ok_or_else(|| { ErrorKind::InvalidResponse { @@ -293,27 +338,28 @@ where self.options.and_then(|options| options.comment.clone()), ); - let pinned_connection = self.client.pin_connection_for_cursor( - &specification, - context.connection, - context.session.as_deref_mut(), - )?; - let iteration_result = match context.session { - Some(session) => { - let mut session_cursor = - SessionCursor::new(self.client.clone(), specification, pinned_connection); - self.iterate_results_cursor( - session_cursor.stream(session), - &mut result, - &mut error, - ) + let iteration_result = if self.client.is_load_balanced() { + // Using a cursor with a pinned connection is not feasible here; see RUST-2131 for + // more details. 
+ self.do_get_mores(&mut context, specification, &mut result, &mut error) .await - } - None => { - let cursor = - Cursor::new(self.client.clone(), specification, None, pinned_connection); - self.iterate_results_cursor(cursor, &mut result, &mut error) + } else { + match context.session { + Some(session) => { + let mut session_cursor = + SessionCursor::new(self.client.clone(), specification, None); + self.iterate_results_cursor( + session_cursor.stream(session), + &mut result, + &mut error, + ) .await + } + None => { + let cursor = Cursor::new(self.client.clone(), specification, None, None); + self.iterate_results_cursor(cursor, &mut result, &mut error) + .await + } } }; From 4e8bf43c93ccf6bb2763d4d46605dc8071507898 Mon Sep 17 00:00:00 2001 From: Isabel Atkinson Date: Tue, 8 Apr 2025 15:08:25 -0600 Subject: [PATCH 2/8] unskip tests --- src/test/bulk_write.rs | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/src/test/bulk_write.rs b/src/test/bulk_write.rs index c37de8dd7..4772778bc 100644 --- a/src/test/bulk_write.rs +++ b/src/test/bulk_write.rs @@ -13,7 +13,6 @@ use crate::{ log_uncaptured, server_version_gte, server_version_lt, - topology_is_load_balanced, topology_is_sharded, topology_is_standalone, util::fail_point::{FailPoint, FailPointMode}, @@ -179,11 +178,6 @@ async fn write_error_batches() { log_uncaptured("skipping write_error_batches: bulkWrite requires 8.0+"); return; } - // TODO RUST-2131 - if topology_is_load_balanced().await { - log_uncaptured("skipping write_error_batches: load-balanced topology"); - return; - } let mut client = Client::for_test().monitor_events().await; @@ -241,11 +235,6 @@ async fn successful_cursor_iteration() { log_uncaptured("skipping successful_cursor_iteration: bulkWrite requires 8.0+"); return; } - // TODO RUST-2131 - if topology_is_load_balanced().await { - log_uncaptured("skipping successful_cursor_iteration: load-balanced topology"); - return; - } let client = Client::for_test().monitor_events().await; @@ -287,11 +276,6 @@ async fn cursor_iteration_in_a_transaction() { ); return; } - // TODO RUST-2131 - if topology_is_load_balanced().await { - log_uncaptured("skipping cursor_iteration_in_a_transaction: load-balanced topology"); - return; - } let client = Client::for_test().monitor_events().await; @@ -338,11 +322,6 @@ async fn failed_cursor_iteration() { log_uncaptured("skipping failed_cursor_iteration: bulkWrite requires 8.0+"); return; } - // TODO RUST-2131 - if topology_is_load_balanced().await { - log_uncaptured("skipping failed_cursor_iteration: load-balanced topology"); - return; - } let mut options = get_client_options().await.clone(); if topology_is_sharded().await { From eec84e8eb281d9f163c2f58be466c5529c75ce75 Mon Sep 17 00:00:00 2001 From: Isabel Atkinson Date: Wed, 9 Apr 2025 15:20:29 -0600 Subject: [PATCH 3/8] propagate txn number --- src/client/executor.rs | 31 ++++----------- src/client/session.rs | 19 ++++++++-- src/operation/bulk_write.rs | 76 +++++++++++++++++++------------------ 3 files changed, 63 insertions(+), 63 deletions(-) diff --git a/src/client/executor.rs b/src/client/executor.rs index 08408c70d..a221803b1 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -382,10 +382,14 @@ impl Client { retry.first_error()?; } - let txn_number = retry - .as_ref() - .and_then(|r| r.prior_txn_number) - .or_else(|| get_txn_number(&mut session, retryability)); + let txn_number = + if let Some(txn_number) = retry.as_ref().and_then(|r| r.prior_txn_number) { + Some(txn_number) + } else { + session + 
.as_mut() + .and_then(|s| s.get_txn_number_for_operation(retryability)) + }; let details = match self .execute_operation_on_connection( @@ -957,25 +961,6 @@ async fn get_connection( } } -fn get_txn_number( - session: &mut Option<&mut ClientSession>, - retryability: Retryability, -) -> Option { - match session { - Some(ref mut session) => { - if session.transaction.state != TransactionState::None { - Some(session.txn_number()) - } else { - match retryability { - Retryability::Write => Some(session.get_and_increment_txn_number()), - _ => None, - } - } - } - None => None, - } -} - impl Error { /// Adds the necessary labels to this Error, and unpins the session if needed. /// diff --git a/src/client/session.rs b/src/client/session.rs index 071bceb5d..b87a26207 100644 --- a/src/client/session.rs +++ b/src/client/session.rs @@ -16,6 +16,7 @@ use uuid::Uuid; use crate::{ bson::{doc, spec::BinarySubtype, Binary, Bson, Document, Timestamp}, cmap::conn::PinnedConnectionHandle, + operation::Retryability, options::{SessionOptions, TransactionOptions}, sdam::ServerInfo, selection_criteria::SelectionCriteria, @@ -310,10 +311,20 @@ impl ClientSession { self.server_session.txn_number += 1; } - /// Increments the txn_number and returns the new value. - pub(crate) fn get_and_increment_txn_number(&mut self) -> i64 { - self.increment_txn_number(); - self.server_session.txn_number + /// Gets the txn_number to use for an operation based on the current transaction status and the + /// operation's retryability. + pub(crate) fn get_txn_number_for_operation( + &mut self, + retryability: Retryability, + ) -> Option { + if self.transaction.state != TransactionState::None { + Some(self.txn_number()) + } else if retryability == Retryability::Write { + self.increment_txn_number(); + Some(self.txn_number()) + } else { + None + } } /// Pin mongos to session. 
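
The change above centralizes transaction-number selection in ClientSession::get_txn_number_for_operation so that both the executor's retry loop and the manually issued getMore in the bulk-write path apply the same rule. A minimal standalone sketch of that rule follows; Session, TransactionState, and Retryability here are simplified stand-ins for the driver's internal types, not its real API.

    // Sketch of the txnNumber selection rule: reuse the current number inside a
    // transaction, increment it for a retryable write outside one, omit it otherwise.
    #[derive(PartialEq)]
    enum Retryability {
        Write,
        Read,
    }

    enum TransactionState {
        None,
        InProgress,
    }

    struct Session {
        txn_number: i64,
        transaction_state: TransactionState,
    }

    impl Session {
        fn txn_number_for_operation(&mut self, retryability: Retryability) -> Option<i64> {
            match self.transaction_state {
                TransactionState::None if retryability == Retryability::Write => {
                    // Retryable write outside a transaction: assign a fresh txnNumber.
                    self.txn_number += 1;
                    Some(self.txn_number)
                }
                // Other operations outside a transaction send no txnNumber at all.
                TransactionState::None => None,
                // Inside a transaction: every command reuses the transaction's number.
                TransactionState::InProgress => Some(self.txn_number),
            }
        }
    }

    fn main() {
        let mut session = Session {
            txn_number: 0,
            transaction_state: TransactionState::None,
        };
        assert_eq!(session.txn_number_for_operation(Retryability::Write), Some(1));
        assert_eq!(session.txn_number_for_operation(Retryability::Read), None);
    }
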
diff --git a/src/operation/bulk_write.rs b/src/operation/bulk_write.rs index 3c68265dd..0ccea4316 100644 --- a/src/operation/bulk_write.rs +++ b/src/operation/bulk_write.rs @@ -80,6 +80,46 @@ where Ok(()) } + async fn do_get_mores( + &self, + context: &mut ExecutionContext<'_>, + cursor_specification: CursorSpecification, + result: &mut impl BulkWriteResult, + error: &mut BulkWriteError, + ) -> Result<()> { + let mut responses = cursor_specification.initial_buffer; + let mut more_responses = cursor_specification.info.id != 0; + loop { + for response_document in &responses { + let response: SingleOperationResponse = + bson::from_slice(response_document.as_bytes())?; + self.handle_individual_response(response, result, error)?; + } + + if !more_responses { + return Ok(()); + } + + let mut get_more = GetMore::new(cursor_specification.info.clone(), None); + let txn_number = context + .session + .as_mut() + .and_then(|s| s.get_txn_number_for_operation(get_more.retryability())); + let get_more_response = self + .client + .execute_operation_on_connection( + &mut get_more, + context.connection, + &mut context.session, + txn_number, + Retryability::None, + ) + .await?; + responses = get_more_response.batch; + more_responses = get_more_response.id != 0; + } + } + fn handle_individual_response( &self, response: SingleOperationResponse, @@ -127,42 +167,6 @@ where Ok(()) } - async fn do_get_mores<'b>( - &self, - context: &mut ExecutionContext<'b>, - cursor_specification: CursorSpecification, - result: &mut impl BulkWriteResult, - error: &mut BulkWriteError, - ) -> Result<()> { - let mut responses = cursor_specification.initial_buffer; - let mut more_responses = cursor_specification.info.id != 0; - loop { - for response_document in &responses { - let response: SingleOperationResponse = - bson::from_slice(response_document.as_bytes())?; - self.handle_individual_response(response, result, error)?; - } - - if !more_responses { - return Ok(()); - } - - let mut get_more = GetMore::new(cursor_specification.info.clone(), None); - let get_more_response = self - .client - .execute_operation_on_connection( - &mut get_more, - context.connection, - &mut context.session, - None, - Retryability::None, - ) - .await?; - responses = get_more_response.batch; - more_responses = get_more_response.id != 0; - } - } - fn get_model(&self, index: usize) -> Result<&WriteModel> { self.models.get(index).ok_or_else(|| { ErrorKind::InvalidResponse { From 0dba7d39b3a1cbdf6fe5dfd75857f095ac982a8d Mon Sep 17 00:00:00 2001 From: Isabel Atkinson Date: Wed, 9 Apr 2025 16:18:47 -0600 Subject: [PATCH 4/8] manual killCursors --- src/operation/bulk_write.rs | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/src/operation/bulk_write.rs b/src/operation/bulk_write.rs index 0ccea4316..cabf1a096 100644 --- a/src/operation/bulk_write.rs +++ b/src/operation/bulk_write.rs @@ -6,13 +6,13 @@ use futures_core::TryStream; use futures_util::{FutureExt, TryStreamExt}; use crate::{ - bson::{rawdoc, Bson, RawDocumentBuf}, + bson::{doc, rawdoc, Bson, RawDocumentBuf}, bson_util::{self, extend_raw_document_buf}, checked::Checked, cmap::{Command, RawCommandResponse, StreamDescription}, cursor::CursorSpecification, error::{BulkWriteError, Error, ErrorKind, Result}, - operation::{GetMore, OperationWithDefaults}, + operation::{run_command::RunCommand, GetMore, OperationWithDefaults}, options::{BulkWriteOptions, OperationType, WriteModel}, results::{BulkWriteResult, DeleteResult, InsertOneResult, UpdateResult}, 
BoxFuture, @@ -89,6 +89,7 @@ where ) -> Result<()> { let mut responses = cursor_specification.initial_buffer; let mut more_responses = cursor_specification.info.id != 0; + let mut namespace = cursor_specification.info.ns.clone(); loop { for response_document in &responses { let response: SingleOperationResponse = @@ -105,7 +106,7 @@ where .session .as_mut() .and_then(|s| s.get_txn_number_for_operation(get_more.retryability())); - let get_more_response = self + let get_more_result = self .client .execute_operation_on_connection( &mut get_more, @@ -114,9 +115,37 @@ where txn_number, Retryability::None, ) - .await?; + .await; + + let get_more_response = match get_more_result { + Ok(response) => response, + Err(error) => { + if !error.is_network_error() { + let kill_cursors = doc! { + "killCursors": &namespace.db, + "cursors": [cursor_specification.info.id], + }; + let mut run_command = + RunCommand::new(namespace.db.clone(), kill_cursors, None, None)?; + let result = self + .client + .execute_operation_on_connection( + &mut run_command, + context.connection, + &mut context.session, + txn_number, + Retryability::None, + ) + .await; + debug_assert!(result.is_ok()); + } + return Err(error); + } + }; + responses = get_more_response.batch; more_responses = get_more_response.id != 0; + namespace = get_more_response.ns; } } From 98ba175492d661e723c11dcd430d9b6821ceaf6d Mon Sep 17 00:00:00 2001 From: Isabel Atkinson Date: Thu, 10 Apr 2025 15:44:55 -0600 Subject: [PATCH 5/8] shuffling, including reauthentication --- src/client/executor.rs | 730 +++++++++++++++++++++-------------------- 1 file changed, 373 insertions(+), 357 deletions(-) diff --git a/src/client/executor.rs b/src/client/executor.rs index a221803b1..99256cd4c 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -31,6 +31,7 @@ use crate::{ }, ConnectionPool, RawCommandResponse, + StreamDescription, }, cursor::{session::SessionCursor, Cursor, CursorSpecification}, error::{ @@ -105,42 +106,55 @@ impl Client { op: &mut T, session: impl Into>, ) -> Result> { + // Validate inputs that can be checked before server selection and connection checkout. if self.inner.shutdown.executed.load(Ordering::SeqCst) { return Err(ErrorKind::Shutdown.into()); } - Box::pin(async { - // TODO RUST-9: allow unacknowledged write concerns - if !op.is_acknowledged() { - return Err(ErrorKind::InvalidArgument { - message: "Unacknowledged write concerns are not supported".to_string(), - } - .into()); + // TODO RUST-9: remove this validation + if !op.is_acknowledged() { + return Err(ErrorKind::InvalidArgument { + message: "Unacknowledged write concerns are not supported".to_string(), } - let session = session.into(); - if let Some(session) = &session { - if !TrackingArc::ptr_eq(&self.inner, &session.client().inner) { - return Err(ErrorKind::InvalidArgument { - message: "the session provided to an operation must be created from the \ - same client as the collection/database" - .into(), - } - .into()); - } + .into()); + } + if let Some(write_concern) = op.write_concern() { + write_concern.validate()?; + } - if let Some(SelectionCriteria::ReadPreference(read_preference)) = - op.selection_criteria() - { - if session.in_transaction() && read_preference != &ReadPreference::Primary { - return Err(ErrorKind::Transaction { - message: "read preference in a transaction must be primary".into(), - } - .into()); - } + // Validate the session and update its transaction status if needed. 
+ let mut session = session.into(); + if let Some(ref mut session) = session { + if !TrackingArc::ptr_eq(&self.inner, &session.client().inner) { + return Err(Error::invalid_argument( + "the session provided to an operation must be created from the same client as \ + the collection/database on which the operation is being performed", + )); + } + if op + .selection_criteria() + .and_then(|sc| sc.as_read_pref()) + .is_some_and(|rp| rp != &ReadPreference::Primary) + && session.in_transaction() + { + return Err(ErrorKind::Transaction { + message: "read preference in a transaction must be primary".into(), } + .into()); } - self.execute_operation_with_retry(op, session).await - }) - .await + // If the current transaction has been committed/aborted and it is not being + // re-committed/re-aborted, reset the transaction's state to None. + if matches!( + session.transaction.state, + TransactionState::Committed { .. } + ) && op.name() != CommitTransaction::NAME + || session.transaction.state == TransactionState::Aborted + && op.name() != AbortTransaction::NAME + { + session.transaction.reset(); + } + } + + Box::pin(async { self.execute_operation_with_retry(op, session).await }).await } /// Execute the given operation, returning the cursor created by the operation. @@ -292,20 +306,6 @@ impl Client { op: &mut T, mut session: Option<&mut ClientSession>, ) -> Result> { - // If the current transaction has been committed/aborted and it is not being - // re-committed/re-aborted, reset the transaction's state to TransactionState::None. - if let Some(ref mut session) = session { - if matches!( - session.transaction.state, - TransactionState::Committed { .. } - ) && op.name() != CommitTransaction::NAME - || session.transaction.state == TransactionState::Aborted - && op.name() != AbortTransaction::NAME - { - session.transaction.reset(); - } - } - let mut retry: Option = None; let mut implicit_session: Option = None; loop { @@ -330,7 +330,7 @@ impl Client { Err(mut err) => { retry.first_error()?; - err.add_labels_and_update_pin(None, &mut session, None)?; + err.add_labels_and_update_pin(None, &mut session, None); return Err(err); } }; @@ -341,7 +341,7 @@ impl Client { Err(mut err) => { retry.first_error()?; - err.add_labels_and_update_pin(None, &mut session, None)?; + err.add_labels_and_update_pin(None, &mut session, None); if err.is_read_retryable() && self.inner.options.retry_writes != Some(false) { err.add_label(RETRYABLE_WRITE_ERROR); } @@ -407,30 +407,6 @@ impl Client { implicit_session, }, Err(mut err) => { - // If the error is a reauthentication required error, we reauthenticate and - // retry the operation. - if err.is_reauthentication_required() { - let credential = self.inner.options.credential.as_ref().ok_or( - ErrorKind::Authentication { - message: "No Credential when reauthentication required error \ - occured" - .to_string(), - }, - )?; - let server_api = self.inner.options.server_api.as_ref(); - - credential - .mechanism - .as_ref() - .ok_or(ErrorKind::Authentication { - message: "No AuthMechanism when reauthentication required error \ - occured" - .to_string(), - })? 
- .reauthenticate_stream(&mut conn, credential, server_api) - .await?; - continue; - } err.wire_version = conn.stream_description()?.max_wire_version; // Retryable writes are only supported by storage engines with document-level @@ -496,321 +472,278 @@ impl Client { txn_number: Option, retryability: Retryability, ) -> Result { - if let Some(wc) = op.write_concern() { - wc.validate()?; - } - - let stream_description = connection.stream_description()?; - let is_sharded = stream_description.initial_server_type == ServerType::Mongos; - let mut cmd = op.build(stream_description)?; - self.inner.topology.update_command_with_read_pref( - connection.address(), - &mut cmd, - op.selection_criteria(), - ); - - match session { - Some(ref mut session) if op.supports_sessions() && op.is_acknowledged() => { - cmd.set_session(session); - if let Some(txn_number) = txn_number { - cmd.set_txn_number(txn_number); - } - if session - .options() - .and_then(|opts| opts.snapshot) - .unwrap_or(false) - { - if connection - .stream_description()? - .max_wire_version - .unwrap_or(0) - < 13 + Box::pin(async move { + let stream_description = connection.stream_description()?; + let is_sharded = stream_description.initial_server_type == ServerType::Mongos; + let mut cmd = op.build(stream_description)?; + self.inner.topology.update_command_with_read_pref( + connection.address(), + &mut cmd, + op.selection_criteria(), + ); + + match session { + Some(ref mut session) if op.supports_sessions() && op.is_acknowledged() => { + cmd.set_session(session); + if let Some(txn_number) = txn_number { + cmd.set_txn_number(txn_number); + } + if session + .options() + .and_then(|opts| opts.snapshot) + .unwrap_or(false) { - let labels: Option> = None; - return Err(Error::new( - ErrorKind::IncompatibleServer { - message: "Snapshot reads require MongoDB 5.0 or later".into(), - }, - labels, - )); + if connection + .stream_description()? + .max_wire_version + .unwrap_or(0) + < 13 + { + let labels: Option> = None; + return Err(Error::new( + ErrorKind::IncompatibleServer { + message: "Snapshot reads require MongoDB 5.0 or later".into(), + }, + labels, + )); + } + cmd.set_snapshot_read_concern(session); + } + // If this is a causally consistent session, set `readConcern.afterClusterTime`. + // Causal consistency defaults to true, unless snapshot is true. + else if session.causal_consistency() + && matches!( + session.transaction.state, + TransactionState::None | TransactionState::Starting + ) + && op.supports_read_concern(stream_description) + { + cmd.set_after_cluster_time(session); } - cmd.set_snapshot_read_concern(session); - } - // If this is a causally consistent session, set `readConcern.afterClusterTime`. - // Causal consistency defaults to true, unless snapshot is true. 
- else if session.causal_consistency() - && matches!( - session.transaction.state, - TransactionState::None | TransactionState::Starting - ) - && op.supports_read_concern(stream_description) - { - cmd.set_after_cluster_time(session); - } - match session.transaction.state { - TransactionState::Starting => { - cmd.set_start_transaction(); - cmd.set_autocommit(); - if session.causal_consistency() { - cmd.set_after_cluster_time(session); - } + match session.transaction.state { + TransactionState::Starting => { + cmd.set_start_transaction(); + cmd.set_autocommit(); + if session.causal_consistency() { + cmd.set_after_cluster_time(session); + } - if let Some(ref options) = session.transaction.options { - if let Some(ref read_concern) = options.read_concern { - cmd.set_read_concern_level(read_concern.level.clone()); + if let Some(ref options) = session.transaction.options { + if let Some(ref read_concern) = options.read_concern { + cmd.set_read_concern_level(read_concern.level.clone()); + } } + if self.is_load_balanced() { + session.pin_connection(connection.pin()?); + } else if is_sharded { + session.pin_mongos(connection.address().clone()); + } + session.transaction.state = TransactionState::InProgress; } - if self.is_load_balanced() { - session.pin_connection(connection.pin()?); - } else if is_sharded { - session.pin_mongos(connection.address().clone()); - } - session.transaction.state = TransactionState::InProgress; - } - TransactionState::InProgress => cmd.set_autocommit(), - TransactionState::Committed { .. } | TransactionState::Aborted => { - cmd.set_autocommit(); - - // Append the recovery token to the command if we are committing or aborting - // on a sharded transaction. - if is_sharded { - if let Some(ref recovery_token) = session.transaction.recovery_token { - cmd.set_recovery_token(recovery_token); + TransactionState::InProgress => cmd.set_autocommit(), + TransactionState::Committed { .. } | TransactionState::Aborted => { + cmd.set_autocommit(); + + // Append the recovery token to the command if we are committing or + // aborting on a sharded transaction. 
+ if is_sharded { + if let Some(ref recovery_token) = session.transaction.recovery_token + { + cmd.set_recovery_token(recovery_token); + } } } + _ => {} } - _ => {} + session.update_last_use(); } - session.update_last_use(); - } - Some(ref session) if !op.supports_sessions() && !session.is_implicit() => { - return Err(ErrorKind::InvalidArgument { - message: format!("{} does not support sessions", cmd.name), + Some(ref session) if !op.supports_sessions() && !session.is_implicit() => { + return Err(ErrorKind::InvalidArgument { + message: format!("{} does not support sessions", cmd.name), + } + .into()); } - .into()); - } - Some(ref session) if !op.is_acknowledged() && !session.is_implicit() => { - return Err(ErrorKind::InvalidArgument { - message: "Cannot use ClientSessions with unacknowledged write concern" - .to_string(), + Some(ref session) if !op.is_acknowledged() && !session.is_implicit() => { + return Err(ErrorKind::InvalidArgument { + message: "Cannot use ClientSessions with unacknowledged write concern" + .to_string(), + } + .into()); } - .into()); + _ => {} } - _ => {} - } - let session_cluster_time = session.as_ref().and_then(|session| session.cluster_time()); - let client_cluster_time = self.inner.topology.cluster_time(); - let max_cluster_time = std::cmp::max(session_cluster_time, client_cluster_time.as_ref()); - if let Some(cluster_time) = max_cluster_time { - cmd.set_cluster_time(cluster_time); - } + let session_cluster_time = session.as_ref().and_then(|session| session.cluster_time()); + let client_cluster_time = self.inner.topology.cluster_time(); + let max_cluster_time = + std::cmp::max(session_cluster_time, client_cluster_time.as_ref()); + if let Some(cluster_time) = max_cluster_time { + cmd.set_cluster_time(cluster_time); + } - let connection_info = connection.info(); - let service_id = connection.service_id(); - let request_id = next_request_id(); + let connection_info = connection.info(); + let service_id = connection.service_id(); + let request_id = next_request_id(); - if let Some(ref server_api) = self.inner.options.server_api { - cmd.set_server_api(server_api); - } + if let Some(ref server_api) = self.inner.options.server_api { + cmd.set_server_api(server_api); + } - let should_redact = cmd.should_redact(); + let should_redact = cmd.should_redact(); - let cmd_name = cmd.name.clone(); - let target_db = cmd.target_db.clone(); + let cmd_name = cmd.name.clone(); + let target_db = cmd.target_db.clone(); - let mut message = Message::try_from(cmd)?; - message.request_id = Some(request_id); - #[cfg(feature = "in-use-encryption")] - { - let guard = self.inner.csfle.read().await; - if let Some(ref csfle) = *guard { - if csfle.opts().bypass_auto_encryption != Some(true) { - let encrypted_payload = self - .auto_encrypt(csfle, &message.document_payload, &target_db) - .await?; - message.document_payload = encrypted_payload; + let mut message = Message::try_from(cmd)?; + message.request_id = Some(request_id); + #[cfg(feature = "in-use-encryption")] + { + let guard = self.inner.csfle.read().await; + if let Some(ref csfle) = *guard { + if csfle.opts().bypass_auto_encryption != Some(true) { + let encrypted_payload = self + .auto_encrypt(csfle, &message.document_payload, &target_db) + .await?; + message.document_payload = encrypted_payload; + } } } - } - self.emit_command_event(|| { - let command_body = if should_redact { - Document::new() - } else { - message.get_command_document() - }; - CommandEvent::Started(CommandStartedEvent { - command: command_body, - db: target_db.clone(), - 
command_name: cmd_name.clone(), - request_id, - connection: connection_info.clone(), - service_id, + self.emit_command_event(|| { + let command_body = if should_redact { + Document::new() + } else { + message.get_command_document() + }; + CommandEvent::Started(CommandStartedEvent { + command: command_body, + db: target_db.clone(), + command_name: cmd_name.clone(), + request_id, + connection: connection_info.clone(), + service_id, + }) }) - }) - .await; - - let start_time = Instant::now(); - let command_result = match connection.send_message(message).await { - Ok(response) => { - async fn handle_response( - client: &Client, - op: &T, - session: &mut Option<&mut ClientSession>, - is_sharded: bool, - response: RawCommandResponse, - ) -> Result { - let raw_doc = RawDocument::from_bytes(response.as_bytes())?; - - let ok = match raw_doc.get("ok")? { - Some(b) => crate::bson_util::get_int_raw(b).ok_or_else(|| { - ErrorKind::InvalidResponse { - message: format!( - "expected ok value to be a number, instead got {:?}", - b - ), - } - })?, - None => { - return Err(ErrorKind::InvalidResponse { - message: "missing 'ok' value in response".to_string(), - } - .into()) - } - }; - - let cluster_time: Option = raw_doc - .get("$clusterTime")? - .and_then(RawBsonRef::as_document) - .map(|d| bson::from_slice(d.as_bytes())) - .transpose()?; + .await; - let at_cluster_time = op.extract_at_cluster_time(raw_doc)?; - - client - .update_cluster_time(cluster_time, at_cluster_time, session) - .await; + let start_time = Instant::now(); + let command_result = match connection.send_message(message).await { + Ok(response) => { + self.handle_response(op, session, is_sharded, response) + .await + } + Err(err) => Err(err), + }; - if let (Some(session), Some(ts)) = ( - session.as_mut(), - raw_doc - .get("operationTime")? - .and_then(RawBsonRef::as_timestamp), - ) { - session.advance_operation_time(ts); - } + let duration = start_time.elapsed(); - if ok == 1 { - if let Some(ref mut session) = session { - if is_sharded && session.in_transaction() { - let recovery_token = raw_doc - .get("recoveryToken")? 
- .and_then(RawBsonRef::as_document) - .map(|d| bson::from_slice(d.as_bytes())) - .transpose()?; - session.transaction.recovery_token = recovery_token; - } + let result = match command_result { + Err(mut err) => { + self.emit_command_event(|| { + let mut err = err.clone(); + if should_redact { + err.redact(); } - Ok(response) - } else { - Err(response - .body::() - .map(|error_response| error_response.into()) - .unwrap_or_else(|e| { - Error::from(ErrorKind::InvalidResponse { - message: format!("error deserializing command error: {}", e), - }) - })) - } - } - handle_response(self, op, session, is_sharded, response).await - } - Err(err) => Err(err), - }; - - let duration = start_time.elapsed(); + CommandEvent::Failed(CommandFailedEvent { + duration, + command_name: cmd_name.clone(), + failure: err, + request_id, + connection: connection_info.clone(), + service_id, + }) + }) + .await; - match command_result { - Err(mut err) => { - self.emit_command_event(|| { - let mut err = err.clone(); - if should_redact { - err.redact(); + if let Some(ref mut session) = session { + if err.is_network_error() { + session.mark_dirty(); + } } - CommandEvent::Failed(CommandFailedEvent { - duration, - command_name: cmd_name.clone(), - failure: err, - request_id, - connection: connection_info.clone(), - service_id, - }) - }) - .await; + err.add_labels_and_update_pin( + Some(connection.stream_description()?), + session, + Some(retryability), + ); - if let Some(ref mut session) = session { - if err.is_network_error() { - session.mark_dirty(); - } + op.handle_error(err) } - - err.add_labels_and_update_pin(Some(connection), session, Some(retryability))?; - op.handle_error(err) - } - Ok(response) => { - self.emit_command_event(|| { - let reply = if should_redact { - Document::new() - } else { - response - .body() - .unwrap_or_else(|e| doc! { "deserialization error": e.to_string() }) + Ok(response) => { + self.emit_command_event(|| { + let reply = if should_redact { + Document::new() + } else { + response + .body() + .unwrap_or_else(|e| doc! 
{ "deserialization error": e.to_string() }) + }; + + CommandEvent::Succeeded(CommandSucceededEvent { + duration, + reply, + command_name: cmd_name.clone(), + request_id, + connection: connection_info.clone(), + service_id, + }) + }) + .await; + + #[cfg(feature = "in-use-encryption")] + let response = { + let guard = self.inner.csfle.read().await; + if let Some(ref csfle) = *guard { + let new_body = self.auto_decrypt(csfle, response.raw_body()).await?; + RawCommandResponse::new_raw(response.source, new_body) + } else { + response + } }; - CommandEvent::Succeeded(CommandSucceededEvent { - duration, - reply, - command_name: cmd_name.clone(), - request_id, - connection: connection_info.clone(), - service_id, - }) - }) - .await; + let context = ExecutionContext { + connection, + session: session.as_deref_mut(), + }; - #[cfg(feature = "in-use-encryption")] - let response = { - let guard = self.inner.csfle.read().await; - if let Some(ref csfle) = *guard { - let new_body = self.auto_decrypt(csfle, response.raw_body()).await?; - RawCommandResponse::new_raw(response.source, new_body) - } else { - response + match op.handle_response(response, context).await { + Ok(response) => Ok(response), + Err(mut err) => { + err.add_labels_and_update_pin( + Some(connection.stream_description()?), + session, + Some(retryability), + ); + Err(err) + } } - }; + } + }; - let context = ExecutionContext { + if result + .as_ref() + .err() + .is_some_and(|e| e.is_reauthentication_required()) + { + // If the error is a reauthentication error, reauthenticate the connection + // and, if successful, retry the operation regardless of whether is is + // read/write retryable. + self.reauthenticate_connection(connection).await?; + self.execute_operation_on_connection( + op, connection, - session: session.as_deref_mut(), - }; - - match op.handle_response(response, context).await { - Ok(response) => Ok(response), - Err(mut err) => { - err.add_labels_and_update_pin( - Some(connection), - session, - Some(retryability), - )?; - Err(err) - } - } + session, + txn_number, + retryability, + ) + .await + } else { + result } - } + }) + .await } #[cfg(feature = "in-use-encryption")] @@ -841,6 +774,98 @@ impl Client { }) } + async fn reauthenticate_connection(&self, connection: &mut PooledConnection) -> Result<()> { + let credential = + self.inner + .options + .credential + .as_ref() + .ok_or_else(|| ErrorKind::Authentication { + message: "the connection requires reauthentication but no credential was set" + .to_string(), + })?; + let server_api = self.inner.options.server_api.as_ref(); + + credential + .mechanism + .as_ref() + .ok_or(ErrorKind::Authentication { + message: "the connection requires reauthentication but no authentication \ + mechanism was set" + .to_string(), + })? + .reauthenticate_stream(connection, credential, server_api) + .await + } + + async fn handle_response( + &self, + op: &T, + session: &mut Option<&mut ClientSession>, + is_sharded: bool, + response: RawCommandResponse, + ) -> Result { + let raw_doc = RawDocument::from_bytes(response.as_bytes())?; + + let ok = match raw_doc.get("ok")? { + Some(b) => { + crate::bson_util::get_int_raw(b).ok_or_else(|| ErrorKind::InvalidResponse { + message: format!("expected ok value to be a number, instead got {:?}", b), + })? + } + None => { + return Err(ErrorKind::InvalidResponse { + message: "missing 'ok' value in response".to_string(), + } + .into()) + } + }; + + let cluster_time: Option = raw_doc + .get("$clusterTime")? 
+ .and_then(RawBsonRef::as_document) + .map(|d| bson::from_slice(d.as_bytes())) + .transpose()?; + + let at_cluster_time = op.extract_at_cluster_time(raw_doc)?; + + self.update_cluster_time(cluster_time, at_cluster_time, session) + .await; + + if let (Some(session), Some(ts)) = ( + session.as_mut(), + raw_doc + .get("operationTime")? + .and_then(RawBsonRef::as_timestamp), + ) { + session.advance_operation_time(ts); + } + + if ok == 1 { + if let Some(ref mut session) = session { + if is_sharded && session.in_transaction() { + let recovery_token = raw_doc + .get("recoveryToken")? + .and_then(RawBsonRef::as_document) + .map(|d| bson::from_slice(d.as_bytes())) + .transpose()?; + session.transaction.recovery_token = recovery_token; + } + } + + Ok(response) + } else { + Err(response + .body::() + .map(|error_response| error_response.into()) + .unwrap_or_else(|e| { + Error::from(ErrorKind::InvalidResponse { + message: format!("error deserializing command error: {}", e), + }) + })) + } + } + async fn select_data_bearing_server(&self, operation_name: &str) -> Result<()> { let topology_type = self.inner.topology.topology_type(); let criteria = SelectionCriteria::Predicate(Arc::new(move |server_info| { @@ -977,23 +1002,16 @@ impl Error { /// ClientSession should be unpinned. fn add_labels_and_update_pin( &mut self, - conn: Option<&PooledConnection>, + stream_description: Option<&StreamDescription>, session: &mut Option<&mut ClientSession>, retryability: Option, - ) -> Result<()> { + ) { let transaction_state = session.as_ref().map_or(&TransactionState::None, |session| { &session.transaction.state }); - let max_wire_version = if let Some(conn) = conn { - conn.stream_description()?.max_wire_version - } else { - None - }; + let max_wire_version = stream_description.and_then(|sd| sd.max_wire_version); + let server_type = stream_description.map(|sd| sd.initial_server_type); - let server_type = match conn { - Some(c) => Some(c.stream_description()?.initial_server_type), - None => None, - }; match transaction_state { TransactionState::Starting | TransactionState::InProgress => { if self.is_network_error() || self.is_server_selection_error() { @@ -1035,8 +1053,6 @@ impl Error { session.unpin(); } } - - Ok(()) } } From d8dc1fd64576927979042ebfbeddc07e276a0434 Mon Sep 17 00:00:00 2001 From: Isabel Atkinson Date: Fri, 11 Apr 2025 09:41:06 -0600 Subject: [PATCH 6/8] remove recursion --- src/client/executor.rs | 473 +++++++++++++++++++++-------------------- 1 file changed, 239 insertions(+), 234 deletions(-) diff --git a/src/client/executor.rs b/src/client/executor.rs index 99256cd4c..7a0dd7c16 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -472,278 +472,271 @@ impl Client { txn_number: Option, retryability: Retryability, ) -> Result { - Box::pin(async move { - let stream_description = connection.stream_description()?; - let is_sharded = stream_description.initial_server_type == ServerType::Mongos; - let mut cmd = op.build(stream_description)?; - self.inner.topology.update_command_with_read_pref( - connection.address(), - &mut cmd, - op.selection_criteria(), - ); - - match session { - Some(ref mut session) if op.supports_sessions() && op.is_acknowledged() => { - cmd.set_session(session); - if let Some(txn_number) = txn_number { - cmd.set_txn_number(txn_number); - } - if session - .options() - .and_then(|opts| opts.snapshot) - .unwrap_or(false) - { - if connection - .stream_description()? 
- .max_wire_version - .unwrap_or(0) - < 13 - { - let labels: Option> = None; - return Err(Error::new( - ErrorKind::IncompatibleServer { - message: "Snapshot reads require MongoDB 5.0 or later".into(), - }, - labels, - )); - } - cmd.set_snapshot_read_concern(session); - } - // If this is a causally consistent session, set `readConcern.afterClusterTime`. - // Causal consistency defaults to true, unless snapshot is true. - else if session.causal_consistency() - && matches!( - session.transaction.state, - TransactionState::None | TransactionState::Starting - ) - && op.supports_read_concern(stream_description) + let stream_description = connection.stream_description()?; + let is_sharded = stream_description.initial_server_type == ServerType::Mongos; + let mut cmd = op.build(stream_description)?; + self.inner.topology.update_command_with_read_pref( + connection.address(), + &mut cmd, + op.selection_criteria(), + ); + + match session { + Some(ref mut session) if op.supports_sessions() && op.is_acknowledged() => { + cmd.set_session(session); + if let Some(txn_number) = txn_number { + cmd.set_txn_number(txn_number); + } + if session + .options() + .and_then(|opts| opts.snapshot) + .unwrap_or(false) + { + if connection + .stream_description()? + .max_wire_version + .unwrap_or(0) + < 13 { - cmd.set_after_cluster_time(session); + let labels: Option> = None; + return Err(Error::new( + ErrorKind::IncompatibleServer { + message: "Snapshot reads require MongoDB 5.0 or later".into(), + }, + labels, + )); } + cmd.set_snapshot_read_concern(session); + } + // If this is a causally consistent session, set `readConcern.afterClusterTime`. + // Causal consistency defaults to true, unless snapshot is true. + else if session.causal_consistency() + && matches!( + session.transaction.state, + TransactionState::None | TransactionState::Starting + ) + && op.supports_read_concern(stream_description) + { + cmd.set_after_cluster_time(session); + } - match session.transaction.state { - TransactionState::Starting => { - cmd.set_start_transaction(); - cmd.set_autocommit(); - if session.causal_consistency() { - cmd.set_after_cluster_time(session); - } + match session.transaction.state { + TransactionState::Starting => { + cmd.set_start_transaction(); + cmd.set_autocommit(); + if session.causal_consistency() { + cmd.set_after_cluster_time(session); + } - if let Some(ref options) = session.transaction.options { - if let Some(ref read_concern) = options.read_concern { - cmd.set_read_concern_level(read_concern.level.clone()); - } - } - if self.is_load_balanced() { - session.pin_connection(connection.pin()?); - } else if is_sharded { - session.pin_mongos(connection.address().clone()); + if let Some(ref options) = session.transaction.options { + if let Some(ref read_concern) = options.read_concern { + cmd.set_read_concern_level(read_concern.level.clone()); } - session.transaction.state = TransactionState::InProgress; } - TransactionState::InProgress => cmd.set_autocommit(), - TransactionState::Committed { .. } | TransactionState::Aborted => { - cmd.set_autocommit(); - - // Append the recovery token to the command if we are committing or - // aborting on a sharded transaction. 
- if is_sharded { - if let Some(ref recovery_token) = session.transaction.recovery_token - { - cmd.set_recovery_token(recovery_token); - } + if self.is_load_balanced() { + session.pin_connection(connection.pin()?); + } else if is_sharded { + session.pin_mongos(connection.address().clone()); + } + session.transaction.state = TransactionState::InProgress; + } + TransactionState::InProgress => cmd.set_autocommit(), + TransactionState::Committed { .. } | TransactionState::Aborted => { + cmd.set_autocommit(); + + // Append the recovery token to the command if we are committing or + // aborting on a sharded transaction. + if is_sharded { + if let Some(ref recovery_token) = session.transaction.recovery_token { + cmd.set_recovery_token(recovery_token); } } - _ => {} } - session.update_last_use(); + _ => {} } - Some(ref session) if !op.supports_sessions() && !session.is_implicit() => { - return Err(ErrorKind::InvalidArgument { - message: format!("{} does not support sessions", cmd.name), - } - .into()); + session.update_last_use(); + } + Some(ref session) if !op.supports_sessions() && !session.is_implicit() => { + return Err(ErrorKind::InvalidArgument { + message: format!("{} does not support sessions", cmd.name), } - Some(ref session) if !op.is_acknowledged() && !session.is_implicit() => { - return Err(ErrorKind::InvalidArgument { - message: "Cannot use ClientSessions with unacknowledged write concern" - .to_string(), - } - .into()); + .into()); + } + Some(ref session) if !op.is_acknowledged() && !session.is_implicit() => { + return Err(ErrorKind::InvalidArgument { + message: "Cannot use ClientSessions with unacknowledged write concern" + .to_string(), } - _ => {} + .into()); } + _ => {} + } - let session_cluster_time = session.as_ref().and_then(|session| session.cluster_time()); - let client_cluster_time = self.inner.topology.cluster_time(); - let max_cluster_time = - std::cmp::max(session_cluster_time, client_cluster_time.as_ref()); - if let Some(cluster_time) = max_cluster_time { - cmd.set_cluster_time(cluster_time); - } + let session_cluster_time = session.as_ref().and_then(|session| session.cluster_time()); + let client_cluster_time = self.inner.topology.cluster_time(); + let max_cluster_time = std::cmp::max(session_cluster_time, client_cluster_time.as_ref()); + if let Some(cluster_time) = max_cluster_time { + cmd.set_cluster_time(cluster_time); + } - let connection_info = connection.info(); - let service_id = connection.service_id(); - let request_id = next_request_id(); + let connection_info = connection.info(); + let service_id = connection.service_id(); + let request_id = next_request_id(); - if let Some(ref server_api) = self.inner.options.server_api { - cmd.set_server_api(server_api); - } + if let Some(ref server_api) = self.inner.options.server_api { + cmd.set_server_api(server_api); + } - let should_redact = cmd.should_redact(); + let should_redact = cmd.should_redact(); - let cmd_name = cmd.name.clone(); - let target_db = cmd.target_db.clone(); + let cmd_name = cmd.name.clone(); + let target_db = cmd.target_db.clone(); - let mut message = Message::try_from(cmd)?; - message.request_id = Some(request_id); - #[cfg(feature = "in-use-encryption")] - { - let guard = self.inner.csfle.read().await; - if let Some(ref csfle) = *guard { - if csfle.opts().bypass_auto_encryption != Some(true) { - let encrypted_payload = self - .auto_encrypt(csfle, &message.document_payload, &target_db) - .await?; - message.document_payload = encrypted_payload; - } + let mut message = Message::try_from(cmd)?; + 
message.request_id = Some(request_id); + #[cfg(feature = "in-use-encryption")] + { + let guard = self.inner.csfle.read().await; + if let Some(ref csfle) = *guard { + if csfle.opts().bypass_auto_encryption != Some(true) { + let encrypted_payload = self + .auto_encrypt(csfle, &message.document_payload, &target_db) + .await?; + message.document_payload = encrypted_payload; } } + } - self.emit_command_event(|| { - let command_body = if should_redact { - Document::new() - } else { - message.get_command_document() - }; - CommandEvent::Started(CommandStartedEvent { - command: command_body, - db: target_db.clone(), - command_name: cmd_name.clone(), - request_id, - connection: connection_info.clone(), - service_id, - }) + self.emit_command_event(|| { + let command_body = if should_redact { + Document::new() + } else { + message.get_command_document() + }; + CommandEvent::Started(CommandStartedEvent { + command: command_body, + db: target_db.clone(), + command_name: cmd_name.clone(), + request_id, + connection: connection_info.clone(), + service_id, }) - .await; + }) + .await; - let start_time = Instant::now(); - let command_result = match connection.send_message(message).await { - Ok(response) => { - self.handle_response(op, session, is_sharded, response) - .await - } - Err(err) => Err(err), - }; + let start_time = Instant::now(); + let command_result = match connection.send_message(message).await { + Ok(response) => { + self.handle_response(op, session, is_sharded, response) + .await + } + Err(err) => Err(err), + }; - let duration = start_time.elapsed(); + let duration = start_time.elapsed(); - let result = match command_result { - Err(mut err) => { - self.emit_command_event(|| { - let mut err = err.clone(); - if should_redact { - err.redact(); - } + let result = match command_result { + Err(mut err) => { + self.emit_command_event(|| { + let mut err = err.clone(); + if should_redact { + err.redact(); + } - CommandEvent::Failed(CommandFailedEvent { - duration, - command_name: cmd_name.clone(), - failure: err, - request_id, - connection: connection_info.clone(), - service_id, - }) + CommandEvent::Failed(CommandFailedEvent { + duration, + command_name: cmd_name.clone(), + failure: err, + request_id, + connection: connection_info.clone(), + service_id, }) - .await; + }) + .await; - if let Some(ref mut session) = session { - if err.is_network_error() { - session.mark_dirty(); - } + if let Some(ref mut session) = session { + if err.is_network_error() { + session.mark_dirty(); } + } - err.add_labels_and_update_pin( - Some(connection.stream_description()?), - session, - Some(retryability), - ); + err.add_labels_and_update_pin( + Some(connection.stream_description()?), + session, + Some(retryability), + ); - op.handle_error(err) - } - Ok(response) => { - self.emit_command_event(|| { - let reply = if should_redact { - Document::new() - } else { - response - .body() - .unwrap_or_else(|e| doc! 
{ "deserialization error": e.to_string() }) - }; - - CommandEvent::Succeeded(CommandSucceededEvent { - duration, - reply, - command_name: cmd_name.clone(), - request_id, - connection: connection_info.clone(), - service_id, - }) - }) - .await; - - #[cfg(feature = "in-use-encryption")] - let response = { - let guard = self.inner.csfle.read().await; - if let Some(ref csfle) = *guard { - let new_body = self.auto_decrypt(csfle, response.raw_body()).await?; - RawCommandResponse::new_raw(response.source, new_body) - } else { - response - } + op.handle_error(err) + } + Ok(response) => { + self.emit_command_event(|| { + let reply = if should_redact { + Document::new() + } else { + response + .body() + .unwrap_or_else(|e| doc! { "deserialization error": e.to_string() }) }; - let context = ExecutionContext { - connection, - session: session.as_deref_mut(), - }; + CommandEvent::Succeeded(CommandSucceededEvent { + duration, + reply, + command_name: cmd_name.clone(), + request_id, + connection: connection_info.clone(), + service_id, + }) + }) + .await; - match op.handle_response(response, context).await { - Ok(response) => Ok(response), - Err(mut err) => { - err.add_labels_and_update_pin( - Some(connection.stream_description()?), - session, - Some(retryability), - ); - Err(err) - } + #[cfg(feature = "in-use-encryption")] + let response = { + let guard = self.inner.csfle.read().await; + if let Some(ref csfle) = *guard { + let new_body = self.auto_decrypt(csfle, response.raw_body()).await?; + RawCommandResponse::new_raw(response.source, new_body) + } else { + response } - } - }; + }; - if result - .as_ref() - .err() - .is_some_and(|e| e.is_reauthentication_required()) - { - // If the error is a reauthentication error, reauthenticate the connection - // and, if successful, retry the operation regardless of whether is is - // read/write retryable. - self.reauthenticate_connection(connection).await?; - self.execute_operation_on_connection( - op, + let context = ExecutionContext { connection, - session, - txn_number, - retryability, - ) - .await - } else { - result + session: session.as_deref_mut(), + }; + + match op.handle_response(response, context).await { + Ok(response) => Ok(response), + Err(mut err) => { + err.add_labels_and_update_pin( + Some(connection.stream_description()?), + session, + Some(retryability), + ); + Err(err) + } + } } - }) - .await + }; + + if result + .as_ref() + .err() + .is_some_and(|e| e.is_reauthentication_required()) + { + // This retry is done outside of the normal retry loop because all operations, + // regardless of retryability, should be retried after reauthentication. + self.reauthenticate_connection_and_retry_operation( + op, + connection, + session, + txn_number, + retryability, + ) + .await + } else { + result + } } #[cfg(feature = "in-use-encryption")] @@ -774,7 +767,16 @@ impl Client { }) } - async fn reauthenticate_connection(&self, connection: &mut PooledConnection) -> Result<()> { + // Reauthenticates a connection and retries the operation that received a reauthentication + // required error. + async fn reauthenticate_connection_and_retry_operation( + &self, + op: &mut T, + connection: &mut PooledConnection, + session: &mut Option<&mut ClientSession>, + txn_number: Option, + retryability: Retryability, + ) -> Result { let credential = self.inner .options @@ -795,6 +797,9 @@ impl Client { .to_string(), })? 
.reauthenticate_stream(connection, credential, server_api) + .await?; + + self.execute_operation_on_connection(op, connection, session, txn_number, retryability) .await } From 5377145c2618e12c35f82bdf8f4f4f5ce0bf559a Mon Sep 17 00:00:00 2001 From: Isabel Atkinson Date: Fri, 11 Apr 2025 09:48:24 -0600 Subject: [PATCH 7/8] loop for reauthentication --- src/client/executor.rs | 464 ++++++++++++++++++++--------------------- 1 file changed, 225 insertions(+), 239 deletions(-) diff --git a/src/client/executor.rs b/src/client/executor.rs index 7a0dd7c16..17b5debd2 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -472,270 +472,268 @@ impl Client { txn_number: Option, retryability: Retryability, ) -> Result { - let stream_description = connection.stream_description()?; - let is_sharded = stream_description.initial_server_type == ServerType::Mongos; - let mut cmd = op.build(stream_description)?; - self.inner.topology.update_command_with_read_pref( - connection.address(), - &mut cmd, - op.selection_criteria(), - ); - - match session { - Some(ref mut session) if op.supports_sessions() && op.is_acknowledged() => { - cmd.set_session(session); - if let Some(txn_number) = txn_number { - cmd.set_txn_number(txn_number); - } - if session - .options() - .and_then(|opts| opts.snapshot) - .unwrap_or(false) - { - if connection - .stream_description()? - .max_wire_version - .unwrap_or(0) - < 13 + loop { + let stream_description = connection.stream_description()?; + let is_sharded = stream_description.initial_server_type == ServerType::Mongos; + let mut cmd = op.build(stream_description)?; + self.inner.topology.update_command_with_read_pref( + connection.address(), + &mut cmd, + op.selection_criteria(), + ); + + match session { + Some(ref mut session) if op.supports_sessions() && op.is_acknowledged() => { + cmd.set_session(session); + if let Some(txn_number) = txn_number { + cmd.set_txn_number(txn_number); + } + if session + .options() + .and_then(|opts| opts.snapshot) + .unwrap_or(false) { - let labels: Option> = None; - return Err(Error::new( - ErrorKind::IncompatibleServer { - message: "Snapshot reads require MongoDB 5.0 or later".into(), - }, - labels, - )); + if connection + .stream_description()? + .max_wire_version + .unwrap_or(0) + < 13 + { + let labels: Option> = None; + return Err(Error::new( + ErrorKind::IncompatibleServer { + message: "Snapshot reads require MongoDB 5.0 or later".into(), + }, + labels, + )); + } + cmd.set_snapshot_read_concern(session); + } + // If this is a causally consistent session, set `readConcern.afterClusterTime`. + // Causal consistency defaults to true, unless snapshot is true. + else if session.causal_consistency() + && matches!( + session.transaction.state, + TransactionState::None | TransactionState::Starting + ) + && op.supports_read_concern(stream_description) + { + cmd.set_after_cluster_time(session); } - cmd.set_snapshot_read_concern(session); - } - // If this is a causally consistent session, set `readConcern.afterClusterTime`. - // Causal consistency defaults to true, unless snapshot is true. 
- else if session.causal_consistency() - && matches!( - session.transaction.state, - TransactionState::None | TransactionState::Starting - ) - && op.supports_read_concern(stream_description) - { - cmd.set_after_cluster_time(session); - } - match session.transaction.state { - TransactionState::Starting => { - cmd.set_start_transaction(); - cmd.set_autocommit(); - if session.causal_consistency() { - cmd.set_after_cluster_time(session); - } + match session.transaction.state { + TransactionState::Starting => { + cmd.set_start_transaction(); + cmd.set_autocommit(); + if session.causal_consistency() { + cmd.set_after_cluster_time(session); + } - if let Some(ref options) = session.transaction.options { - if let Some(ref read_concern) = options.read_concern { - cmd.set_read_concern_level(read_concern.level.clone()); + if let Some(ref options) = session.transaction.options { + if let Some(ref read_concern) = options.read_concern { + cmd.set_read_concern_level(read_concern.level.clone()); + } } + if self.is_load_balanced() { + session.pin_connection(connection.pin()?); + } else if is_sharded { + session.pin_mongos(connection.address().clone()); + } + session.transaction.state = TransactionState::InProgress; } - if self.is_load_balanced() { - session.pin_connection(connection.pin()?); - } else if is_sharded { - session.pin_mongos(connection.address().clone()); - } - session.transaction.state = TransactionState::InProgress; - } - TransactionState::InProgress => cmd.set_autocommit(), - TransactionState::Committed { .. } | TransactionState::Aborted => { - cmd.set_autocommit(); - - // Append the recovery token to the command if we are committing or - // aborting on a sharded transaction. - if is_sharded { - if let Some(ref recovery_token) = session.transaction.recovery_token { - cmd.set_recovery_token(recovery_token); + TransactionState::InProgress => cmd.set_autocommit(), + TransactionState::Committed { .. } | TransactionState::Aborted => { + cmd.set_autocommit(); + + // Append the recovery token to the command if we are committing or + // aborting on a sharded transaction. 
+ if is_sharded { + if let Some(ref recovery_token) = session.transaction.recovery_token + { + cmd.set_recovery_token(recovery_token); + } } } + _ => {} } - _ => {} + session.update_last_use(); } - session.update_last_use(); - } - Some(ref session) if !op.supports_sessions() && !session.is_implicit() => { - return Err(ErrorKind::InvalidArgument { - message: format!("{} does not support sessions", cmd.name), + Some(ref session) if !op.supports_sessions() && !session.is_implicit() => { + return Err(ErrorKind::InvalidArgument { + message: format!("{} does not support sessions", cmd.name), + } + .into()); } - .into()); - } - Some(ref session) if !op.is_acknowledged() && !session.is_implicit() => { - return Err(ErrorKind::InvalidArgument { - message: "Cannot use ClientSessions with unacknowledged write concern" - .to_string(), + Some(ref session) if !op.is_acknowledged() && !session.is_implicit() => { + return Err(ErrorKind::InvalidArgument { + message: "Cannot use ClientSessions with unacknowledged write concern" + .to_string(), + } + .into()); } - .into()); + _ => {} } - _ => {} - } - let session_cluster_time = session.as_ref().and_then(|session| session.cluster_time()); - let client_cluster_time = self.inner.topology.cluster_time(); - let max_cluster_time = std::cmp::max(session_cluster_time, client_cluster_time.as_ref()); - if let Some(cluster_time) = max_cluster_time { - cmd.set_cluster_time(cluster_time); - } + let session_cluster_time = session.as_ref().and_then(|session| session.cluster_time()); + let client_cluster_time = self.inner.topology.cluster_time(); + let max_cluster_time = + std::cmp::max(session_cluster_time, client_cluster_time.as_ref()); + if let Some(cluster_time) = max_cluster_time { + cmd.set_cluster_time(cluster_time); + } - let connection_info = connection.info(); - let service_id = connection.service_id(); - let request_id = next_request_id(); + let connection_info = connection.info(); + let service_id = connection.service_id(); + let request_id = next_request_id(); - if let Some(ref server_api) = self.inner.options.server_api { - cmd.set_server_api(server_api); - } + if let Some(ref server_api) = self.inner.options.server_api { + cmd.set_server_api(server_api); + } - let should_redact = cmd.should_redact(); + let should_redact = cmd.should_redact(); - let cmd_name = cmd.name.clone(); - let target_db = cmd.target_db.clone(); + let cmd_name = cmd.name.clone(); + let target_db = cmd.target_db.clone(); - let mut message = Message::try_from(cmd)?; - message.request_id = Some(request_id); - #[cfg(feature = "in-use-encryption")] - { - let guard = self.inner.csfle.read().await; - if let Some(ref csfle) = *guard { - if csfle.opts().bypass_auto_encryption != Some(true) { - let encrypted_payload = self - .auto_encrypt(csfle, &message.document_payload, &target_db) - .await?; - message.document_payload = encrypted_payload; + let mut message = Message::try_from(cmd)?; + message.request_id = Some(request_id); + #[cfg(feature = "in-use-encryption")] + { + let guard = self.inner.csfle.read().await; + if let Some(ref csfle) = *guard { + if csfle.opts().bypass_auto_encryption != Some(true) { + let encrypted_payload = self + .auto_encrypt(csfle, &message.document_payload, &target_db) + .await?; + message.document_payload = encrypted_payload; + } } } - } - self.emit_command_event(|| { - let command_body = if should_redact { - Document::new() - } else { - message.get_command_document() - }; - CommandEvent::Started(CommandStartedEvent { - command: command_body, - db: target_db.clone(), - 
command_name: cmd_name.clone(), - request_id, - connection: connection_info.clone(), - service_id, + self.emit_command_event(|| { + let command_body = if should_redact { + Document::new() + } else { + message.get_command_document() + }; + CommandEvent::Started(CommandStartedEvent { + command: command_body, + db: target_db.clone(), + command_name: cmd_name.clone(), + request_id, + connection: connection_info.clone(), + service_id, + }) }) - }) - .await; + .await; - let start_time = Instant::now(); - let command_result = match connection.send_message(message).await { - Ok(response) => { - self.handle_response(op, session, is_sharded, response) - .await - } - Err(err) => Err(err), - }; + let start_time = Instant::now(); + let command_result = match connection.send_message(message).await { + Ok(response) => { + self.handle_response(op, session, is_sharded, response) + .await + } + Err(err) => Err(err), + }; - let duration = start_time.elapsed(); + let duration = start_time.elapsed(); - let result = match command_result { - Err(mut err) => { - self.emit_command_event(|| { - let mut err = err.clone(); - if should_redact { - err.redact(); - } + let result = match command_result { + Err(mut err) => { + self.emit_command_event(|| { + let mut err = err.clone(); + if should_redact { + err.redact(); + } - CommandEvent::Failed(CommandFailedEvent { - duration, - command_name: cmd_name.clone(), - failure: err, - request_id, - connection: connection_info.clone(), - service_id, + CommandEvent::Failed(CommandFailedEvent { + duration, + command_name: cmd_name.clone(), + failure: err, + request_id, + connection: connection_info.clone(), + service_id, + }) }) - }) - .await; + .await; - if let Some(ref mut session) = session { - if err.is_network_error() { - session.mark_dirty(); + if let Some(ref mut session) = session { + if err.is_network_error() { + session.mark_dirty(); + } } - } - err.add_labels_and_update_pin( - Some(connection.stream_description()?), - session, - Some(retryability), - ); + err.add_labels_and_update_pin( + Some(connection.stream_description()?), + session, + Some(retryability), + ); - op.handle_error(err) - } - Ok(response) => { - self.emit_command_event(|| { - let reply = if should_redact { - Document::new() - } else { - response - .body() - .unwrap_or_else(|e| doc! { "deserialization error": e.to_string() }) - }; - - CommandEvent::Succeeded(CommandSucceededEvent { - duration, - reply, - command_name: cmd_name.clone(), - request_id, - connection: connection_info.clone(), - service_id, + op.handle_error(err) + } + Ok(response) => { + self.emit_command_event(|| { + let reply = if should_redact { + Document::new() + } else { + response + .body() + .unwrap_or_else(|e| doc! 
{ "deserialization error": e.to_string() }) + }; + + CommandEvent::Succeeded(CommandSucceededEvent { + duration, + reply, + command_name: cmd_name.clone(), + request_id, + connection: connection_info.clone(), + service_id, + }) }) - }) - .await; - - #[cfg(feature = "in-use-encryption")] - let response = { - let guard = self.inner.csfle.read().await; - if let Some(ref csfle) = *guard { - let new_body = self.auto_decrypt(csfle, response.raw_body()).await?; - RawCommandResponse::new_raw(response.source, new_body) - } else { - response - } - }; + .await; + + #[cfg(feature = "in-use-encryption")] + let response = { + let guard = self.inner.csfle.read().await; + if let Some(ref csfle) = *guard { + let new_body = self.auto_decrypt(csfle, response.raw_body()).await?; + RawCommandResponse::new_raw(response.source, new_body) + } else { + response + } + }; - let context = ExecutionContext { - connection, - session: session.as_deref_mut(), - }; + let context = ExecutionContext { + connection, + session: session.as_deref_mut(), + }; - match op.handle_response(response, context).await { - Ok(response) => Ok(response), - Err(mut err) => { - err.add_labels_and_update_pin( - Some(connection.stream_description()?), - session, - Some(retryability), - ); - Err(err) + match op.handle_response(response, context).await { + Ok(response) => Ok(response), + Err(mut err) => { + err.add_labels_and_update_pin( + Some(connection.stream_description()?), + session, + Some(retryability), + ); + Err(err) + } } } - } - }; + }; - if result - .as_ref() - .err() - .is_some_and(|e| e.is_reauthentication_required()) - { - // This retry is done outside of the normal retry loop because all operations, - // regardless of retryability, should be retried after reauthentication. - self.reauthenticate_connection_and_retry_operation( - op, - connection, - session, - txn_number, - retryability, - ) - .await - } else { - result + if result + .as_ref() + .err() + .is_some_and(|e| e.is_reauthentication_required()) + { + // This retry is done outside of the normal retry loop because all operations, + // regardless of retryability, should be retried after reauthentication. + self.reauthenticate_connection(connection).await?; + continue; + } else { + return result; + } } } @@ -767,16 +765,7 @@ impl Client { }) } - // Reauthenticates a connection and retries the operation that received a reauthentication - // required error. - async fn reauthenticate_connection_and_retry_operation( - &self, - op: &mut T, - connection: &mut PooledConnection, - session: &mut Option<&mut ClientSession>, - txn_number: Option, - retryability: Retryability, - ) -> Result { + async fn reauthenticate_connection(&self, connection: &mut PooledConnection) -> Result<()> { let credential = self.inner .options @@ -797,9 +786,6 @@ impl Client { .to_string(), })? 
.reauthenticate_stream(connection, credential, server_api) - .await?; - - self.execute_operation_on_connection(op, connection, session, txn_number, retryability) .await } From a9892af4bbfd95888684f67dc3c6cd3505cdd91d Mon Sep 17 00:00:00 2001 From: Isabel Atkinson Date: Thu, 17 Apr 2025 13:09:06 -0600 Subject: [PATCH 8/8] remove assert --- src/operation/bulk_write.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/operation/bulk_write.rs b/src/operation/bulk_write.rs index cabf1a096..2ad01b26d 100644 --- a/src/operation/bulk_write.rs +++ b/src/operation/bulk_write.rs @@ -127,7 +127,7 @@ where }; let mut run_command = RunCommand::new(namespace.db.clone(), kill_cursors, None, None)?; - let result = self + let _ = self .client .execute_operation_on_connection( &mut run_command, @@ -137,7 +137,6 @@ where Retryability::None, ) .await; - debug_assert!(result.is_ok()); } return Err(error); }
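
Note on the reauthentication change in PATCH 7/8: rather than recursing back into execute_operation_on_connection after a reauthentication-required error (as the preceding patch did via reauthenticate_connection_and_retry_operation), the function body is wrapped in a loop, and the error path calls reauthenticate_connection and then continues to the next iteration. The standalone sketch below illustrates that loop-and-retry shape only; it is not driver code, and Conn, run_command, and OpError are hypothetical stand-ins (the real code checks is_reauthentication_required() on the returned error).

#[derive(Debug)]
enum OpError {
    ReauthenticationRequired,
    Other(String),
}

struct Conn {
    authenticated: bool,
}

impl Conn {
    fn reauthenticate(&mut self) -> Result<(), OpError> {
        // Stand-in for re-running the authentication handshake on this connection.
        self.authenticated = true;
        Ok(())
    }

    fn run_command(&mut self, cmd: &str) -> Result<String, OpError> {
        if cmd.is_empty() {
            return Err(OpError::Other("empty command".to_string()));
        }
        if self.authenticated {
            Ok(format!("ok: {cmd}"))
        } else {
            Err(OpError::ReauthenticationRequired)
        }
    }
}

/// Runs `cmd`, reauthenticating and retrying whenever the server reports that
/// reauthentication is required. Any other outcome is returned unchanged.
fn execute_with_reauth(conn: &mut Conn, cmd: &str) -> Result<String, OpError> {
    loop {
        match conn.run_command(cmd) {
            Err(OpError::ReauthenticationRequired) => {
                // A failed reauthentication is fatal; otherwise retry the command.
                conn.reauthenticate()?;
                continue;
            }
            other => return other,
        }
    }
}

fn main() {
    let mut conn = Conn { authenticated: false };
    println!("{:?}", execute_with_reauth(&mut conn, "ping"));
}

As in the patch, this retry happens regardless of the operation's normal retryability, since the command itself never reached the point of executing with valid credentials.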
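
Note on PATCH 8/8: the killCursors issued while surfacing a bulkWrite error is best-effort cleanup, so its result is now discarded (let _ = ...) rather than checked with debug_assert!. A minimal sketch of that pattern, using hypothetical names (kill_cursors, CommandError) rather than the driver's actual types:

#[derive(Debug)]
struct CommandError(String);

fn kill_cursors(cursor_ids: &[i64]) -> Result<(), CommandError> {
    // Stand-in for sending a cleanup command to the server.
    if cursor_ids.is_empty() {
        Err(CommandError("nothing to kill".into()))
    } else {
        Ok(())
    }
}

fn fail_with_cleanup(cursor_ids: &[i64], original: CommandError) -> CommandError {
    // The cleanup is advisory: ignore its outcome rather than asserting on it,
    // so a failure here neither panics debug builds nor masks `original`.
    let _ = kill_cursors(cursor_ids);
    original
}

fn main() {
    let err = fail_with_cleanup(&[], CommandError("bulkWrite failed".into()));
    println!("{err:?}");
}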