From 3d3ed57c591099e185c1753c06513f78d24fd4c2 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Fri, 9 Jul 2021 23:18:53 -0400 Subject: [PATCH 01/10] serialize insert to raw bson --- benchmarks/src/main.rs | 2 +- src/bson_util/mod.rs | 333 +++++++++++----------- src/client/executor.rs | 34 ++- src/client/mod.rs | 2 + src/client/options/mod.rs | 8 +- src/cmap/conn/command.rs | 91 +++--- src/cmap/conn/mod.rs | 40 ++- src/cmap/conn/wire/message.rs | 18 +- src/cmap/conn/wire/test.rs | 4 +- src/cmap/mod.rs | 2 +- src/operation/abort_transaction/mod.rs | 5 +- src/operation/aggregate/mod.rs | 4 +- src/operation/commit_transaction/mod.rs | 5 +- src/operation/count/mod.rs | 3 +- src/operation/count_documents/mod.rs | 3 +- src/operation/create/mod.rs | 5 +- src/operation/delete/mod.rs | 3 +- src/operation/distinct/mod.rs | 3 +- src/operation/drop_collection/mod.rs | 5 +- src/operation/drop_database/mod.rs | 5 +- src/operation/find/mod.rs | 1 + src/operation/find_and_modify/mod.rs | 1 + src/operation/get_more/mod.rs | 2 + src/operation/insert/mod.rs | 161 ++++++++--- src/operation/insert/test.rs | 185 ++++++++++-- src/operation/list_collections/mod.rs | 1 + src/operation/list_databases/mod.rs | 1 + src/operation/mod.rs | 30 +- src/operation/run_command/mod.rs | 1 + src/operation/update/mod.rs | 1 + src/sdam/description/topology/mod.rs | 8 +- src/sdam/state/mod.rs | 8 +- src/selection_criteria.rs | 122 ++++---- src/test/spec/crud_v1/insert_many.rs | 9 +- src/test/spec/unified_runner/test_file.rs | 27 +- src/test/spec/v2_runner/test_event.rs | 1 + 36 files changed, 740 insertions(+), 394 deletions(-) diff --git a/benchmarks/src/main.rs b/benchmarks/src/main.rs index ad034bb29..0f50bda8b 100644 --- a/benchmarks/src/main.rs +++ b/benchmarks/src/main.rs @@ -444,7 +444,7 @@ fn parse_ids(matches: ArgMatches) -> Vec { ids } -#[cfg_attr(feature = "tokio-runtime", tokio::main)] +#[cfg_attr(feature = "tokio-runtime", tokio::main(flavor = "current_thread"))] #[cfg_attr(feature = "async-std-runtime", async_std::main)] async fn main() { let matches = App::new("RustDriverBenchmark") diff --git a/src/bson_util/mod.rs b/src/bson_util/mod.rs index 40d78b20e..a1d7cd68d 100644 --- a/src/bson_util/mod.rs +++ b/src/bson_util/mod.rs @@ -1,10 +1,15 @@ -use std::{convert::TryFrom, io::Read, time::Duration}; +use std::{ + convert::{TryFrom, TryInto}, + io::{Read, Write}, + time::Duration, +}; -use serde::{de::Error, ser, Deserialize, Deserializer, Serialize, Serializer}; +use bson::spec::ElementType; +use serde::{de::Error as SerdeDeError, ser, Deserialize, Deserializer, Serialize, Serializer}; use crate::{ - bson::{doc, Binary, Bson, Document, JavaScriptCodeWithScope, Regex}, - error::{ErrorKind, Result}, + bson::{doc, Bson, Document}, + error::{Error, ErrorKind, Result}, runtime::{SyncLittleEndianRead, SyncLittleEndianWrite}, }; @@ -164,128 +169,30 @@ where .ok_or_else(|| D::Error::custom(format!("could not deserialize u64 from {:?}", bson))) } -pub fn doc_size_bytes(doc: &Document) -> u64 { - // - // * i32 length prefix (4 bytes) - // * for each element: - // * type (1 byte) - // * number of UTF-8 bytes in key - // * null terminator for the key (1 byte) - // * size of the value - // * null terminator (1 byte) - 4 + doc - .into_iter() - .map(|(key, val)| 1 + key.len() as u64 + 1 + size_bytes(val)) - .sum::() - + 1 -} - -pub fn size_bytes(val: &Bson) -> u64 { - match val { - Bson::Double(_) => 8, - // - // * length prefix (4 bytes) - // * number of UTF-8 bytes - // * null terminator (1 byte) - Bson::String(s) => 4 + s.len() as u64 + 1, - // An array is serialized as a document with the keys "0", "1", "2", etc., so the size of - // an array is: - // - // * length prefix (4 bytes) - // * for each element: - // * type (1 byte) - // * number of decimal digits in key - // * null terminator for the key (1 byte) - // * size of value - // * null terminator (1 byte) - Bson::Array(arr) => { - 4 + arr - .iter() - .enumerate() - .map(|(i, val)| 1 + num_decimal_digits(i) + 1 + size_bytes(val)) - .sum::() - + 1 - } - Bson::Document(doc) => doc_size_bytes(doc), - Bson::Boolean(_) => 1, - Bson::Null => 0, - // for $pattern and $opts: - // * number of UTF-8 bytes - // * null terminator (1 byte) - Bson::RegularExpression(Regex { pattern, options }) => { - pattern.len() as u64 + 1 + options.len() as u64 + 1 - } - // - // * length prefix (4 bytes) - // * number of UTF-8 bytes - // * null terminator (1 byte) - Bson::JavaScriptCode(code) => 4 + code.len() as u64 + 1, - // - // * i32 length prefix (4 bytes) - // * i32 length prefix for code (4 bytes) - // * number of UTF-8 bytes in code - // * null terminator for code (1 byte) - // * length of document - Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope { code, scope }) => { - 4 + 4 + code.len() as u64 + 1 + doc_size_bytes(scope) - } - Bson::Int32(_) => 4, - Bson::Int64(_) => 8, - Bson::Timestamp(_) => 8, - // - // * i32 length prefix (4 bytes) - // * subtype (1 byte) - // * number of bytes - Bson::Binary(Binary { bytes, .. }) => 4 + 1 + bytes.len() as u64, - Bson::ObjectId(_) => 12, - Bson::DateTime(_) => 8, - // - // * i32 length prefix (4 bytes) - // * subtype (1 byte) - // * number of UTF-8 bytes - Bson::Symbol(s) => 4 + 1 + s.len() as u64, - Bson::Decimal128(..) => 128 / 8, - Bson::Undefined | Bson::MaxKey | Bson::MinKey => 0, - // DbPointer doesn't have public details exposed by the BSON library, but it comprises of a - // namespace and an ObjectId. Since our methods to calculate the size of BSON values are - // only used to estimate the cutoff for batches when making a large insert, we can just - // assume the largest possible size for a namespace, which is 120 bytes. Therefore, the size - // is: - // - // * i32 length prefix (4 bytes) - // * namespace (120 bytes) - // * null terminator (1 byte) - // * objectid (12 bytes) - Bson::DbPointer(..) => 4 + 120 + 1 + 12, - } -} - /// The size in bytes of the provided document's entry in a BSON array at the given index. -pub(crate) fn array_entry_size_bytes(index: usize, doc: &Document) -> u64 { +pub(crate) fn array_entry_size_bytes(index: usize, doc_len: usize) -> u64 { // // * type (1 byte) // * number of decimal digits in key // * null terminator for the key (1 byte) // * size of value - 1 + num_decimal_digits(index) + 1 + doc_size_bytes(doc) + + 1 + num_decimal_digits(index) + 1 + doc_len as u64 } /// The number of digits in `n` in base 10. /// Useful for calculating the size of an array entry in BSON. -fn num_decimal_digits(n: usize) -> u64 { - let mut digits = 1; - let mut curr = 10; - - while curr < n { - curr = match curr.checked_mul(10) { - Some(val) => val, - None => break, - }; +fn num_decimal_digits(mut n: usize) -> u64 { + let mut digits = 0; + loop { + n /= 10; digits += 1; - } - digits + if n == 0 { + return digits; + } + } } /// Read a document's raw BSON bytes from the provided reader. @@ -300,63 +207,161 @@ pub(crate) fn read_document_bytes(mut reader: R) -> Result> { Ok(bytes) } -/// Serialize the document to raw BSON and return a vec containing the bytes. -#[cfg(test)] -pub(crate) fn document_to_vec(doc: Document) -> Result> { - let mut v = Vec::new(); - doc.to_writer(&mut v)?; - Ok(v) +/// Get the value for the provided key from a buffer containing a BSON document. +/// If the key is not present, None will be returned. +/// If the BSON is not properly formatted, an internal error would be returned. +/// +/// TODO: RUST-924 replace this with raw document API usage. +pub(crate) fn raw_get(doc: &[u8], key: &str) -> Result> { + fn read_i32(reader: &mut std::io::Cursor<&[u8]>) -> Result { + reader.read_i32().map_err(deserialize_error) + } + + fn read_u8(reader: &mut std::io::Cursor<&[u8]>) -> Result { + reader.read_u8().map_err(deserialize_error) + } + + fn deserialize_error(_e: T) -> Error { + deserialize_error_no_arg() + } + + fn deserialize_error_no_arg() -> Error { + Error::from(ErrorKind::Internal { + message: "failed to read from serialized document".to_string(), + }) + } + + let mut reader = std::io::Cursor::new(doc); + let len: u64 = read_i32(&mut reader)? + .try_into() + .map_err(deserialize_error)?; + + while reader.position() < len { + let element_start: usize = reader.position().try_into().map_err(deserialize_error)?; + + // read the element type + let tag = read_u8(&mut reader)?; + + // check if we reached the end of the document + if tag == 0 && reader.position() == len { + return Ok(None); + } + + let element_type = ElementType::from(tag).ok_or_else(deserialize_error_no_arg)?; + + // walk through the document until a null byte is encountered + while read_u8(&mut reader)? != 0 { + if reader.position() >= len { + return Err(deserialize_error_no_arg()); + } + } + + // parse the key + let string_end: usize = reader + .position() + .checked_sub(1) // back from null byte + .and_then(|u| usize::try_from(u).ok()) + .ok_or_else(deserialize_error_no_arg)?; + let slice = &reader.get_ref()[(element_start + 1)..string_end]; + let k = std::str::from_utf8(slice).map_err(deserialize_error)?; + + // move to the end of the element + let skip_len = match element_type { + ElementType::Array + | ElementType::EmbeddedDocument + | ElementType::JavaScriptCodeWithScope => { + let l = read_i32(&mut reader)?; + // length includes the 4 bytes for the length, so subtrack them out + l.checked_sub(4).ok_or_else(deserialize_error_no_arg)? + } + ElementType::Binary => read_i32(&mut reader)? + .checked_add(1) // add one for subtype + .ok_or_else(deserialize_error_no_arg)?, + ElementType::Int32 => 4, + ElementType::Int64 => 8, + ElementType::String | ElementType::Symbol | ElementType::JavaScriptCode => { + read_i32(&mut reader)? + } + ElementType::Boolean => 1, + ElementType::Double => 8, + ElementType::Timestamp => 8, + ElementType::Decimal128 => 16, + ElementType::MinKey + | ElementType::MaxKey + | ElementType::Null + | ElementType::Undefined => 0, + ElementType::DateTime => 8, + ElementType::ObjectId => 12, + ElementType::DbPointer => read_i32(&mut reader)? + .checked_add(12) // add 12 for objectid + .ok_or_else(deserialize_error_no_arg)?, + ElementType::RegularExpression => { + // read two cstr's + for _i in 0..2 { + while read_u8(&mut reader)? != 0 { + if reader.position() >= len { + return Err(deserialize_error_no_arg()); + } + } + } + + 0 // don't need to skip anymore since we already read the whole value + } + }; + let skip_len: u64 = skip_len.try_into().map_err(deserialize_error)?; + reader.set_position( + reader + .position() + .checked_add(skip_len) + .ok_or_else(deserialize_error_no_arg)?, + ); + + if k == key { + // if this is the element we're looking for, extract it. + let element_end: usize = reader.position().try_into().map_err(deserialize_error)?; + let element_slice = &reader.get_ref()[element_start..element_end]; + let element_length: i32 = element_slice.len().try_into().map_err(deserialize_error)?; + + // create a new temporary document which just has the element we want and grab the value + let mut temp_doc = Vec::new(); + + // write the document length + let temp_len: i32 = element_length + .checked_add(4 + 1) + .ok_or_else(deserialize_error_no_arg)?; + temp_doc + .write_all(&temp_len.to_le_bytes()) + .map_err(deserialize_error)?; + + // add in the element + temp_doc.extend(element_slice); + + // write the null byte + temp_doc.push(0); + + let d = Document::from_reader(temp_doc.as_slice()).map_err(deserialize_error)?; + return Ok(Some( + d.get("_id").ok_or_else(deserialize_error_no_arg)?.clone(), + )); + } + } + + // read all bytes but didn't reach null byte + Err(deserialize_error_no_arg()) } #[cfg(test)] mod test { - use crate::bson::{ - doc, - oid::ObjectId, - spec::BinarySubtype, - Binary, - Bson, - DateTime, - JavaScriptCodeWithScope, - Regex, - Timestamp, - }; - - use super::doc_size_bytes; + use crate::bson_util::num_decimal_digits; #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] - async fn doc_size_bytes_eq_serialized_size_bytes() { - let doc = doc! { - "double": -12.3, - "string": "foo", - "array": ["foobar", -7, Bson::Null, Bson::Timestamp(Timestamp { time: 12345, increment: 67890 }), false], - "document": { - "x": 1, - "yyz": "Rush is one of the greatest bands of all time", - }, - "bool": true, - "null": Bson::Null, - "regex": Bson::RegularExpression(Regex { pattern: "foobar".into(), options: "i".into() }), - "code": Bson::JavaScriptCode("foo(x) { return x + 1; }".into()), - "code with scope": Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope { - code: "foo(x) { return x + y; }".into(), - scope: doc! { "y": -17 }, - }), - "i32": 12i32, - "i64": -126i64, - "timestamp": Bson::Timestamp(Timestamp { time: 12233, increment: 34444 }), - "binary": Bson::Binary(Binary{ subtype: BinarySubtype::Generic, bytes: vec![3, 222, 11] }), - "objectid": ObjectId::from_bytes([1; 12]), - "datetime": DateTime::from_millis(4444333221), - "symbol": Bson::Symbol("foobar".into()), - }; - - let size_bytes = doc_size_bytes(&doc); - - let mut serialized_bytes = Vec::new(); - doc.to_writer(&mut serialized_bytes).unwrap(); - - assert_eq!(size_bytes, serialized_bytes.len() as u64); + async fn num_digits() { + assert_eq!(num_decimal_digits(0), 1); + assert_eq!(num_decimal_digits(1), 1); + assert_eq!(num_decimal_digits(10), 2); + assert_eq!(num_decimal_digits(15), 2); + assert_eq!(num_decimal_digits(100), 3); + assert_eq!(num_decimal_digits(125), 3); } } diff --git a/src/client/executor.rs b/src/client/executor.rs index 9ffa36a65..ba1c8963a 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -6,7 +6,7 @@ use std::{collections::HashSet, sync::Arc, time::Instant}; use super::{session::TransactionState, Client, ClientSession}; use crate::{ bson::Document, - cmap::{Connection, RawCommandResponse}, + cmap::{Connection, RawCommand, RawCommandResponse}, error::{ Error, ErrorKind, @@ -50,7 +50,7 @@ lazy_static! { hash_set.insert("copydb"); hash_set }; - static ref HELLO_COMMAND_NAMES: HashSet<&'static str> = { + pub(crate) static ref HELLO_COMMAND_NAMES: HashSet<&'static str> = { let mut hash_set = HashSet::new(); hash_set.insert("hello"); hash_set.insert("ismaster"); @@ -388,23 +388,30 @@ impl Client { cmd.set_server_api(server_api); } - let should_redact = { - let name = cmd.name.to_lowercase(); - REDACTED_COMMANDS.contains(name.as_str()) - || HELLO_COMMAND_NAMES.contains(name.as_str()) - && cmd.body.contains_key("speculativeAuthenticate") + let should_redact = cmd.should_redact(); + + let cmd_name = cmd.name.clone(); + let target_db = cmd.target_db.clone(); + + let serialized = op.serialize_command(cmd)?; + let raw_cmd = RawCommand { + name: cmd_name.clone(), + target_db, + bytes: serialized, }; + let start_time = Instant::now(); self.emit_command_event(|handler| { let command_body = if should_redact { Document::new() } else { - cmd.body.clone() + Document::from_reader(raw_cmd.bytes.as_slice()) + .unwrap_or_else(|e| doc! { "serialization error": e.to_string() }) }; let command_started_event = CommandStartedEvent { command: command_body, - db: cmd.target_db.clone(), - command_name: cmd.name.clone(), + db: raw_cmd.target_db.clone(), + command_name: raw_cmd.name.clone(), request_id, connection: connection_info.clone(), }; @@ -412,10 +419,7 @@ impl Client { handler.handle_command_started_event(command_started_event); }); - let start_time = Instant::now(); - let cmd_name = cmd.name.clone(); - - let command_result = match connection.send_command(cmd, request_id).await { + let command_result = match connection.send_raw_command(raw_cmd, request_id).await { Ok(response) => { match T::Response::deserialize_response(&response) { Ok(r) => { @@ -522,7 +526,7 @@ impl Client { response .raw .body() - .unwrap_or_else(|_| doc! { "error": "failed to deserialize" }) + .unwrap_or_else(|e| doc! { "deserialization error": e.to_string() }) }; let command_succeeded_event = CommandSucceededEvent { diff --git a/src/client/mod.rs b/src/client/mod.rs index a477c75d1..c29facee6 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -32,7 +32,9 @@ use crate::{ sdam::{SelectedServer, SessionSupportStatus, Topology}, ClientSession, }; +pub(crate) use executor::{HELLO_COMMAND_NAMES, REDACTED_COMMANDS}; pub(crate) use session::{ClusterTime, SESSIONS_UNSUPPORTED_COMMANDS}; + use session::{ServerSession, ServerSessionPool}; const DEFAULT_SERVER_SELECTION_TIMEOUT: Duration = Duration::from_secs(30); diff --git a/src/client/options/mod.rs b/src/client/options/mod.rs index 9db39c822..326cf6c68 100644 --- a/src/client/options/mod.rs +++ b/src/client/options/mod.rs @@ -292,10 +292,11 @@ impl fmt::Display for ServerAddress { } /// Specifies the server API version to declare -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[non_exhaustive] pub enum ServerApiVersion { /// Use API version 1. + #[serde(rename = "1")] V1, } @@ -335,17 +336,20 @@ impl<'de> Deserialize<'de> for ServerApiVersion { /// Options used to declare a versioned server API. For more information, see the [Versioned API]( /// https://docs.mongodb.com/v5.0/reference/versioned-api/) manual page. -#[derive(Clone, Debug, Deserialize, PartialEq, TypedBuilder)] +#[serde_with::skip_serializing_none] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, TypedBuilder)] #[builder(field_defaults(setter(into)))] #[serde(rename_all = "camelCase")] #[non_exhaustive] pub struct ServerApi { /// The declared API version. + #[serde(rename = "apiVersion")] pub version: ServerApiVersion, /// Whether the MongoDB server should reject all commands that are not part of the /// declared API version. This includes command options and aggregation pipeline stages. #[builder(default)] + #[serde(rename = "apiStrict")] pub strict: Option, /// Whether the MongoDB server should return command failures when functionality that is diff --git a/src/cmap/conn/command.rs b/src/cmap/conn/command.rs index 1aaa37a5a..711b33c3a 100644 --- a/src/cmap/conn/command.rs +++ b/src/cmap/conn/command.rs @@ -1,4 +1,4 @@ -pub(crate) use serde::de::DeserializeOwned; +use serde::{de::DeserializeOwned, Serialize}; use super::wire::Message; use crate::{ @@ -11,34 +11,71 @@ use crate::{ ClientSession, }; -/// `Command` is a driver side abstraction of a server command containing all the information -/// necessary to serialize it to a wire message. -#[derive(Debug, Clone)] -pub(crate) struct Command { +/// A command that has been serialized to BSON. +#[derive(Debug)] +pub(crate) struct RawCommand { + pub(crate) name: String, + pub(crate) target_db: String, + pub(crate) bytes: Vec, +} + +/// Driver-side model of a database command. +#[serde_with::skip_serializing_none] +#[derive(Clone, Debug, Serialize, Default)] +#[serde(rename_all = "camelCase")] +pub(crate) struct Command { + #[serde(skip)] pub(crate) name: String, + + #[serde(flatten)] + pub(crate) body: T, + + #[serde(rename = "$db")] pub(crate) target_db: String, - pub(crate) body: Document, + + lsid: Option, + + #[serde(rename = "$clusterTime")] + cluster_time: Option, + + #[serde(flatten)] + server_api: Option, + + #[serde(rename = "$readPreference")] + read_preference: Option, + + txn_number: Option, + + start_transaction: Option, + + autocommit: Option, + + read_concern: Option, } -impl Command { - /// Constructs a new command. - pub(crate) fn new(name: String, target_db: String, body: Document) -> Self { +impl Command { + pub(crate) fn new(name: String, target_db: String, body: T) -> Self { Self { name, target_db, body, + lsid: None, + cluster_time: None, + server_api: None, + read_preference: None, + txn_number: None, + start_transaction: None, + autocommit: None, + read_concern: None, } } pub(crate) fn set_session(&mut self, session: &ClientSession) { - self.body.insert("lsid", session.id()); + self.lsid = Some(session.id().clone()) } pub(crate) fn set_cluster_time(&mut self, cluster_time: &ClusterTime) { - // this should never fail. - if let Ok(doc) = bson::to_bson(cluster_time) { - self.body.insert("$clusterTime", doc); - } + self.cluster_time = Some(cluster_time.clone()); } pub(crate) fn set_recovery_token(&mut self, recovery_token: &Document) { @@ -46,41 +83,30 @@ impl Command { } pub(crate) fn set_txn_number(&mut self, txn_number: i64) { - self.body.insert("txnNumber", txn_number); + self.txn_number = Some(txn_number); } pub(crate) fn set_server_api(&mut self, server_api: &ServerApi) { - self.body - .insert("apiVersion", format!("{}", server_api.version)); - - if let Some(strict) = server_api.strict { - self.body.insert("apiStrict", strict); - } - - if let Some(deprecation_errors) = server_api.deprecation_errors { - self.body.insert("apiDeprecationErrors", deprecation_errors); - } + self.server_api = Some(server_api.clone()); } pub(crate) fn set_read_preference(&mut self, read_preference: ReadPreference) -> Result<()> { - self.body - .insert("$readPreference", read_preference.into_document()?); + self.read_preference = Some(read_preference); Ok(()) } pub(crate) fn set_start_transaction(&mut self) { - self.body.insert("startTransaction", true); + self.start_transaction = Some(true); } pub(crate) fn set_autocommit(&mut self) { - self.body.insert("autocommit", false); + self.autocommit = Some(false); } pub(crate) fn set_txn_read_concern(&mut self, session: &ClientSession) -> Result<()> { if let Some(ref options) = session.transaction.options { if let Some(ref read_concern) = options.read_concern { - self.body - .insert("readConcern", bson::to_document(read_concern)?); + self.read_concern = Some(read_concern.clone()); } } Ok(()) @@ -89,8 +115,7 @@ impl Command { pub(crate) fn set_snapshot_read_concern(&mut self, session: &ClientSession) -> Result<()> { let mut concern = ReadConcern::snapshot(); concern.at_cluster_time = session.snapshot_time; - self.body - .insert("readConcern", bson::to_document(&concern)?); + self.read_concern = Some(concern); Ok(()) } } diff --git a/src/cmap/conn/mod.rs b/src/cmap/conn/mod.rs index f3bfd6aa4..794d8650e 100644 --- a/src/cmap/conn/mod.rs +++ b/src/cmap/conn/mod.rs @@ -26,7 +26,7 @@ use crate::{ options::{ServerAddress, TlsOptions}, runtime::AsyncStream, }; -pub(crate) use command::{Command, RawCommandResponse}; +pub(crate) use command::{Command, RawCommand, RawCommandResponse}; pub(crate) use stream_description::StreamDescription; pub(crate) use wire::next_request_id; @@ -229,6 +229,19 @@ impl Connection { } } + async fn send_message(&mut self, message: Message) -> Result { + self.command_executing = true; + let write_result = message.write_to(&mut self.stream).await; + self.error = write_result.is_err(); + write_result?; + + let response_message_result = Message::read_from(&mut self.stream).await; + self.command_executing = false; + self.error = response_message_result.is_err(); + + RawCommandResponse::new(self.address.clone(), response_message_result?) + } + /// Executes a `Command` and returns a `CommandResponse` containing the result from the server. /// /// An `Ok(...)` result simply means the server received the command and that the driver @@ -240,17 +253,22 @@ impl Connection { request_id: impl Into>, ) -> Result { let message = Message::with_command(command, request_id.into())?; + self.send_message(message).await + } - self.command_executing = true; - let write_result = message.write_to(&mut self.stream).await; - self.error = write_result.is_err(); - write_result?; - - let response_message_result = Message::read_from(&mut self.stream).await; - self.command_executing = false; - self.error = response_message_result.is_err(); - - RawCommandResponse::new(self.address.clone(), response_message_result?) + /// Executes a `RawCommand` and returns a `CommandResponse` containing the result from the + /// server. + /// + /// An `Ok(...)` result simply means the server received the command and that the driver + /// driver received the response; it does not imply anything about the success of the command + /// itself. + pub(crate) async fn send_raw_command( + &mut self, + command: RawCommand, + request_id: impl Into>, + ) -> Result { + let message = Message::with_raw_command(command, request_id.into())?; + self.send_message(message).await } /// Gets the connection's StreamDescription. diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index 830c56dbe..bf0ceeb24 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -11,7 +11,10 @@ use futures_util::{ use super::header::{Header, OpCode}; use crate::{ bson_util, - cmap::conn::{command::Command, wire::util::SyncCountReader}, + cmap::{ + conn::{command::RawCommand, wire::util::SyncCountReader}, + Command, + }, error::{Error, ErrorKind, Result}, runtime::{AsyncLittleEndianWrite, AsyncStream, SyncLittleEndianRead}, }; @@ -44,6 +47,19 @@ impl Message { }) } + /// Creates a `Message` from a given `Command`. + /// + /// Note that `response_to` will need to be set manually. + pub(crate) fn with_raw_command(command: RawCommand, request_id: Option) -> Result { + Ok(Self { + response_to: 0, + flags: MessageFlags::empty(), + sections: vec![MessageSection::Document(command.bytes)], + checksum: None, + request_id, + }) + } + /// Gets the first document contained in this Message. pub(crate) fn single_document_response(self) -> Result> { let section = self.sections.into_iter().next().ok_or_else(|| { diff --git a/src/cmap/conn/wire/test.rs b/src/cmap/conn/wire/test.rs index c9b4eccb6..78fb497f9 100644 --- a/src/cmap/conn/wire/test.rs +++ b/src/cmap/conn/wire/test.rs @@ -4,7 +4,6 @@ use tokio::sync::RwLockReadGuard; use super::message::{Message, MessageFlags, MessageSection}; use crate::{ bson::{doc, Bson}, - bson_util, cmap::options::StreamOptions, runtime::AsyncStream, test::{CLIENT_OPTIONS, LOCK}, @@ -23,8 +22,7 @@ async fn basic() { response_to: 0, flags: MessageFlags::empty(), sections: vec![MessageSection::Document( - bson_util::document_to_vec(doc! { "isMaster": 1, "$db": "admin", "apiVersion": "1" }) - .unwrap(), + bson::to_vec(&doc! { "isMaster": 1, "$db": "admin", "apiVersion": "1" }).unwrap(), )], checksum: None, request_id: None, diff --git a/src/cmap/mod.rs b/src/cmap/mod.rs index 6613ce124..b4ba5439a 100644 --- a/src/cmap/mod.rs +++ b/src/cmap/mod.rs @@ -15,7 +15,7 @@ use derivative::Derivative; pub use self::conn::ConnectionInfo; pub(crate) use self::{ - conn::{Command, Connection, RawCommandResponse, StreamDescription}, + conn::{Command, Connection, RawCommand, RawCommandResponse, StreamDescription}, establish::handshake::Handshaker, status::PoolGenerationSubscriber, }; diff --git a/src/operation/abort_transaction/mod.rs b/src/operation/abort_transaction/mod.rs index 60510f301..bad53b609 100644 --- a/src/operation/abort_transaction/mod.rs +++ b/src/operation/abort_transaction/mod.rs @@ -1,3 +1,5 @@ +use bson::Document; + use crate::{ bson::doc, cmap::{Command, StreamDescription}, @@ -28,11 +30,12 @@ impl AbortTransaction { impl Operation for AbortTransaction { type O = (); + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "abortTransaction"; - fn build(&mut self, _description: &StreamDescription) -> Result { + fn build(&mut self, _description: &StreamDescription) -> Result> { let mut body = doc! { Self::NAME: 1, }; diff --git a/src/operation/aggregate/mod.rs b/src/operation/aggregate/mod.rs index 66af7025b..3e79d30e4 100644 --- a/src/operation/aggregate/mod.rs +++ b/src/operation/aggregate/mod.rs @@ -42,10 +42,12 @@ impl Aggregate { impl Operation for Aggregate { type O = CursorSpecification; + type Command = Document; type Response = CursorResponse; + const NAME: &'static str = "aggregate"; - fn build(&mut self, _description: &StreamDescription) -> Result { + fn build(&mut self, _description: &StreamDescription) -> Result> { let mut body = doc! { Self::NAME: self.target.to_bson(), "pipeline": bson_util::to_bson_array(&self.pipeline), diff --git a/src/operation/commit_transaction/mod.rs b/src/operation/commit_transaction/mod.rs index ca91c7eb7..908189791 100644 --- a/src/operation/commit_transaction/mod.rs +++ b/src/operation/commit_transaction/mod.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use bson::doc; +use bson::{doc, Document}; use crate::{ cmap::{Command, StreamDescription}, @@ -23,11 +23,12 @@ impl CommitTransaction { impl Operation for CommitTransaction { type O = (); + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "commitTransaction"; - fn build(&mut self, _description: &StreamDescription) -> Result { + fn build(&mut self, _description: &StreamDescription) -> Result> { let mut body = doc! { Self::NAME: 1, }; diff --git a/src/operation/count/mod.rs b/src/operation/count/mod.rs index 42c4a4993..799bae20a 100644 --- a/src/operation/count/mod.rs +++ b/src/operation/count/mod.rs @@ -41,11 +41,12 @@ impl Count { impl Operation for Count { type O = u64; + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "count"; - fn build(&mut self, description: &StreamDescription) -> Result { + fn build(&mut self, description: &StreamDescription) -> Result> { let mut body = match description.max_wire_version { Some(v) if v >= SERVER_4_9_0_WIRE_VERSION => { doc! { diff --git a/src/operation/count_documents/mod.rs b/src/operation/count_documents/mod.rs index c7a9dc540..ec94d3667 100644 --- a/src/operation/count_documents/mod.rs +++ b/src/operation/count_documents/mod.rs @@ -77,11 +77,12 @@ impl CountDocuments { impl Operation for CountDocuments { type O = u64; + type Command = Document; type Response = CursorResponse; const NAME: &'static str = Aggregate::NAME; - fn build(&mut self, description: &StreamDescription) -> Result { + fn build(&mut self, description: &StreamDescription) -> Result> { self.aggregate.build(description) } diff --git a/src/operation/create/mod.rs b/src/operation/create/mod.rs index 97530b900..820dce6e9 100644 --- a/src/operation/create/mod.rs +++ b/src/operation/create/mod.rs @@ -1,6 +1,8 @@ #[cfg(test)] mod test; +use bson::Document; + use crate::{ bson::doc, cmap::{Command, StreamDescription}, @@ -37,11 +39,12 @@ impl Create { impl Operation for Create { type O = (); + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "create"; - fn build(&mut self, _description: &StreamDescription) -> Result { + fn build(&mut self, _description: &StreamDescription) -> Result> { let mut body = doc! { Self::NAME: self.ns.coll.clone(), }; diff --git a/src/operation/delete/mod.rs b/src/operation/delete/mod.rs index 6c4806f52..8e42570f1 100644 --- a/src/operation/delete/mod.rs +++ b/src/operation/delete/mod.rs @@ -57,11 +57,12 @@ impl Delete { impl Operation for Delete { type O = DeleteResult; + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "delete"; - fn build(&mut self, _description: &StreamDescription) -> Result { + fn build(&mut self, _description: &StreamDescription) -> Result> { let mut delete = doc! { "q": self.filter.clone(), "limit": self.limit, diff --git a/src/operation/distinct/mod.rs b/src/operation/distinct/mod.rs index 2291297b6..4ba56a768 100644 --- a/src/operation/distinct/mod.rs +++ b/src/operation/distinct/mod.rs @@ -52,11 +52,12 @@ impl Distinct { impl Operation for Distinct { type O = Vec; + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "distinct"; - fn build(&mut self, _description: &StreamDescription) -> Result { + fn build(&mut self, _description: &StreamDescription) -> Result> { let mut body: Document = doc! { Self::NAME: self.ns.coll.clone(), "key": self.field_name.clone(), diff --git a/src/operation/drop_collection/mod.rs b/src/operation/drop_collection/mod.rs index 96261727a..ff22c693b 100644 --- a/src/operation/drop_collection/mod.rs +++ b/src/operation/drop_collection/mod.rs @@ -1,6 +1,8 @@ #[cfg(test)] mod test; +use bson::Document; + use crate::{ bson::doc, cmap::{Command, StreamDescription}, @@ -37,11 +39,12 @@ impl DropCollection { impl Operation for DropCollection { type O = (); + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "drop"; - fn build(&mut self, _description: &StreamDescription) -> Result { + fn build(&mut self, _description: &StreamDescription) -> Result> { let mut body = doc! { Self::NAME: self.ns.coll.clone(), }; diff --git a/src/operation/drop_database/mod.rs b/src/operation/drop_database/mod.rs index e0418ee62..a44dc9727 100644 --- a/src/operation/drop_database/mod.rs +++ b/src/operation/drop_database/mod.rs @@ -1,6 +1,8 @@ #[cfg(test)] mod test; +use bson::Document; + use crate::{ bson::doc, cmap::{Command, StreamDescription}, @@ -30,11 +32,12 @@ impl DropDatabase { impl Operation for DropDatabase { type O = (); + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "dropDatabase"; - fn build(&mut self, _description: &StreamDescription) -> Result { + fn build(&mut self, _description: &StreamDescription) -> Result> { let mut body = doc! { Self::NAME: 1, }; diff --git a/src/operation/find/mod.rs b/src/operation/find/mod.rs index e4cc64782..0090e6aae 100644 --- a/src/operation/find/mod.rs +++ b/src/operation/find/mod.rs @@ -54,6 +54,7 @@ impl Find { impl Operation for Find { type O = CursorSpecification; + type Command = Document; type Response = CursorResponse; const NAME: &'static str = "find"; diff --git a/src/operation/find_and_modify/mod.rs b/src/operation/find_and_modify/mod.rs index 4bea1816c..94ea9d5a4 100644 --- a/src/operation/find_and_modify/mod.rs +++ b/src/operation/find_and_modify/mod.rs @@ -102,6 +102,7 @@ where T: DeserializeOwned, { type O = Option; + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "findAndModify"; diff --git a/src/operation/get_more/mod.rs b/src/operation/get_more/mod.rs index 53455dec7..a6894498e 100644 --- a/src/operation/get_more/mod.rs +++ b/src/operation/get_more/mod.rs @@ -3,6 +3,7 @@ mod test; use std::{collections::VecDeque, marker::PhantomData, time::Duration}; +use bson::Document; use serde::{de::DeserializeOwned, Deserialize}; use crate::{ @@ -43,6 +44,7 @@ impl GetMore { impl Operation for GetMore { type O = GetMoreResult; + type Command = Document; type Response = CommandResponse>; const NAME: &'static str = "getMore"; diff --git a/src/operation/insert/mod.rs b/src/operation/insert/mod.rs index 1828b3f1d..d0d796849 100644 --- a/src/operation/insert/mod.rs +++ b/src/operation/insert/mod.rs @@ -1,38 +1,42 @@ #[cfg(test)] mod test; -use std::collections::HashMap; +use std::{collections::HashMap, io::Write}; -use bson::{oid::ObjectId, Bson}; +use bson::{oid::ObjectId, spec::ElementType, Bson}; use serde::Serialize; use crate::{ - bson::{doc, Document}, + bson::doc, bson_util, cmap::{Command, StreamDescription}, error::{BulkWriteFailure, Error, ErrorKind, Result}, - operation::{append_options, Operation, Retryability, WriteResponseBody}, + operation::{Operation, Retryability, WriteResponseBody}, options::{InsertManyOptions, WriteConcern}, results::InsertManyResult, + runtime::SyncLittleEndianWrite, Namespace, }; -use super::CommandResponse; +use super::{CommandBody, CommandResponse}; #[derive(Debug)] -pub(crate) struct Insert { +pub(crate) struct Insert<'a, T> { ns: Namespace, - documents: Vec, + documents: Vec<&'a T>, inserted_ids: Vec, - options: Option, + options: InsertManyOptions, } -impl Insert { +impl<'a, T> Insert<'a, T> { pub(crate) fn new( ns: Namespace, - documents: Vec, + documents: Vec<&'a T>, options: Option, ) -> Self { + let mut options = + options.unwrap_or_else(|| InsertManyOptions::builder().ordered(true).build()); + options.ordered = Some(options.ordered.unwrap_or(true)); Self { ns, options, @@ -42,21 +46,26 @@ impl Insert { } fn is_ordered(&self) -> bool { - self.options - .as_ref() - .and_then(|options| options.ordered) - .unwrap_or(true) + self.options.ordered.unwrap_or(true) } } -impl Operation for Insert { +impl<'a, T: Serialize> Operation for Insert<'a, T> { type O = InsertManyResult; + type Command = InsertCommand; type Response = CommandResponse; const NAME: &'static str = "insert"; - fn build(&mut self, description: &StreamDescription) -> Result { - let mut docs: Vec = vec![]; + fn build(&mut self, description: &StreamDescription) -> Result> { + if self.documents.is_empty() { + return Err(ErrorKind::InvalidArgument { + message: "must specify at least one document to insert".to_string(), + } + .into()); + } + + let mut docs: Vec> = Vec::new(); let mut size = 0; for (i, d) in self @@ -65,25 +74,32 @@ impl Operation for Insert { .take(description.max_write_batch_size as usize) .enumerate() { - let mut doc = bson::to_document(d)?; - let id = doc - .entry("_id".to_string()) - .or_insert_with(|| { - self.inserted_ids - .get(i) - .cloned() - .unwrap_or_else(|| Bson::ObjectId(ObjectId::new())) - }) - .clone(); - - let doc_size = bson_util::array_entry_size_bytes(i, &doc); + let mut doc = bson::to_vec(d)?; + let id = match bson_util::raw_get(doc.as_slice(), "_id")? { + Some(b) => b, + None => { + let oid = ObjectId::new(); + let new_len = doc.len() as i32 + 1 + 4 + 12; + doc.splice(0..4, new_len.to_le_bytes().iter().cloned()); + + let mut new_doc = Vec::new(); + new_doc.write_u8(ElementType::ObjectId as u8)?; + new_doc.write_all(b"_id\0")?; + new_doc.extend(oid.bytes().iter()); + doc.splice(4..4, new_doc.into_iter()); + + Bson::ObjectId(oid) + } + }; + + let doc_size = bson_util::array_entry_size_bytes(i, doc.len()); if (size + doc_size) <= description.max_bson_object_size as u64 { if self.inserted_ids.len() <= i { self.inserted_ids.push(id); } docs.push(doc); - size += doc_size + size += doc_size; } else { break; } @@ -96,20 +112,58 @@ impl Operation for Insert { .into()); } - let mut body = doc! { - Self::NAME: self.ns.coll.clone(), - "documents": docs, + let body = InsertCommand { + insert: self.ns.coll.clone(), + documents: DocumentArraySpec { + documents: docs, + length: size as i32, + }, + options: self.options.clone(), }; - append_options(&mut body, self.options.as_ref())?; + Ok(Command::new("insert".to_string(), self.ns.db.clone(), body)) + } + + fn serialize_command(&mut self, cmd: Command) -> Result> { + // TODO: RUST-924 Use raw document API here instead. + let mut serialized = bson::to_vec(&cmd)?; + + serialized.pop(); // drop null byte + + // write element type + serialized.push(ElementType::Array as u8); + + // write key cstring + serialized.write_all("documents".as_bytes())?; + serialized.push(0); + + // write length of array + let array_length = 4 + cmd.body.documents.length + 1; // add in 4 for length of array, 1 for null byte + serialized.write_all(&array_length.to_le_bytes())?; + + for (i, doc) in cmd.body.documents.documents.into_iter().enumerate() { + // write type of document + serialized.push(ElementType::EmbeddedDocument as u8); - body.insert("ordered", self.is_ordered()); + // write array index + serialized.write_all(i.to_string().as_bytes())?; + serialized.push(0); - Ok(Command::new( - Self::NAME.to_string(), - self.ns.db.clone(), - body, - )) + // write document + serialized.extend(doc); + } + + // write null byte for array + serialized.push(0); + + // write null byte for containing document + serialized.push(0); + + // update length of original doc + let final_length = serialized.len() as i32; + (&mut serialized[0..4]).write_all(&final_length.to_le_bytes())?; + + Ok(serialized) } fn handle_response( @@ -157,12 +211,33 @@ impl Operation for Insert { } fn write_concern(&self) -> Option<&WriteConcern> { - self.options - .as_ref() - .and_then(|opts| opts.write_concern.as_ref()) + self.options.write_concern.as_ref() } fn retryability(&self) -> Retryability { Retryability::Write } } + +/// Data used for creating a BSON array. +struct DocumentArraySpec { + /// The sum of the lengths of all the documents. + length: i32, + + /// The serialized documents to be inserted. + documents: Vec>, +} + +#[derive(Serialize)] +pub(crate) struct InsertCommand { + insert: String, + + /// will be serialized in `serialize_command` + #[serde(skip)] + documents: DocumentArraySpec, + + #[serde(flatten)] + options: InsertManyOptions, +} + +impl CommandBody for InsertCommand {} diff --git a/src/operation/insert/test.rs b/src/operation/insert/test.rs index c16f982af..3cd7a8c36 100644 --- a/src/operation/insert/test.rs +++ b/src/operation/insert/test.rs @@ -1,3 +1,15 @@ +use bson::{ + oid::ObjectId, + spec::BinarySubtype, + Binary, + DateTime, + JavaScriptCodeWithScope, + Regex, + Timestamp, +}; +use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; + use crate::{ bson::{doc, Bson, Document}, cmap::StreamDescription, @@ -9,18 +21,20 @@ use crate::{ }; struct TestFixtures { - op: Insert, + op: Insert<'static, Document>, documents: Vec, options: InsertManyOptions, } /// Get an Insert operation and the documents/options used to construct it. fn fixtures() -> TestFixtures { - let documents = vec![ - Document::new(), - doc! {"_id": 1234, "a": 1}, - doc! {"a": 123, "b": "hello world" }, - ]; + lazy_static! { + static ref DOCUMENTS: Vec = vec![ + Document::new(), + doc! {"_id": 1234, "a": 1}, + doc! {"a": 123, "b": "hello world" }, + ]; + } let options = InsertManyOptions { ordered: Some(true), @@ -33,13 +47,13 @@ fn fixtures() -> TestFixtures { db: "test_db".to_string(), coll: "test_coll".to_string(), }, - documents.clone(), + DOCUMENTS.iter().collect(), Some(options.clone()), ); TestFixtures { op, - documents, + documents: DOCUMENTS.clone(), options, } } @@ -55,19 +69,14 @@ async fn build() { assert_eq!(cmd.name.as_str(), "insert"); assert_eq!(cmd.target_db.as_str(), "test_db"); - assert_eq!( - cmd.body.get("insert").unwrap(), - &Bson::String("test_coll".to_string()) - ); + assert_eq!(cmd.body.insert, "test_coll".to_string()); let mut cmd_docs: Vec = cmd .body - .get("documents") - .unwrap() - .as_array() - .unwrap() + .documents + .documents .iter() - .map(|b| b.as_document().unwrap().clone()) + .map(|b| Document::from_reader(b.as_slice()).unwrap()) .collect(); assert_eq!(cmd_docs.len(), fixtures.documents.len()); @@ -79,12 +88,15 @@ async fn build() { assert_eq!(original_doc, cmd_doc); } + let serialized = fixtures.op.serialize_command(cmd).unwrap(); + let cmd_doc = Document::from_reader(serialized.as_slice()).unwrap(); + assert_eq!( - cmd.body.get("ordered"), + cmd_doc.get("ordered"), fixtures.options.ordered.map(Bson::Boolean).as_ref() ); assert_eq!( - cmd.body.get("bypassDocumentValidation"), + cmd_doc.get("bypassDocumentValidation"), fixtures .options .bypass_document_validation @@ -92,7 +104,7 @@ async fn build() { .as_ref() ); assert_eq!( - cmd.body.get("writeConcern"), + cmd_doc.get("writeConcern"), fixtures .options .write_concern @@ -105,31 +117,152 @@ async fn build() { #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn build_ordered() { - let mut insert = Insert::new(Namespace::empty(), vec![Document::new()], None); + let docs = vec![Document::new()]; + let mut insert = Insert::new(Namespace::empty(), docs.iter().collect(), None); let cmd = insert .build(&StreamDescription::new_testing()) .expect("should succeed"); - assert_eq!(cmd.body.get("ordered"), Some(&Bson::Boolean(true))); + let serialized = insert.serialize_command(cmd).unwrap(); + let cmd_doc = Document::from_reader(serialized.as_slice()).unwrap(); + assert_eq!(cmd_doc.get("ordered"), Some(&Bson::Boolean(true))); let mut insert = Insert::new( Namespace::empty(), - vec![Document::new()], + docs.iter().collect(), Some(InsertManyOptions::builder().ordered(false).build()), ); let cmd = insert .build(&StreamDescription::new_testing()) .expect("should succeed"); - assert_eq!(cmd.body.get("ordered"), Some(&Bson::Boolean(false))); + let serialized = insert.serialize_command(cmd).unwrap(); + let cmd_doc = Document::from_reader(serialized.as_slice()).unwrap(); + assert_eq!(cmd_doc.get("ordered"), Some(&Bson::Boolean(false))); let mut insert = Insert::new( Namespace::empty(), - vec![Document::new()], + docs.iter().collect(), Some(InsertManyOptions::builder().ordered(true).build()), ); let cmd = insert .build(&StreamDescription::new_testing()) .expect("should succeed"); - assert_eq!(cmd.body.get("ordered"), Some(&Bson::Boolean(true))); + let serialized = insert.serialize_command(cmd).unwrap(); + let cmd_doc = Document::from_reader(serialized.as_slice()).unwrap(); + assert_eq!(cmd_doc.get("ordered"), Some(&Bson::Boolean(true))); + + let mut insert = Insert::new( + Namespace::empty(), + docs.iter().collect(), + Some(InsertManyOptions::builder().build()), + ); + let cmd = insert + .build(&StreamDescription::new_testing()) + .expect("should succeed"); + let serialized = insert.serialize_command(cmd).unwrap(); + let cmd_doc = Document::from_reader(serialized.as_slice()).unwrap(); + assert_eq!(cmd_doc.get("ordered"), Some(&Bson::Boolean(true))); +} + +#[derive(Debug, Serialize, Deserialize)] +struct Documents { + documents: Vec, +} + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn generate_ids() { + let docs = vec![doc! { "x": 1 }, doc! { "_id": 1_i32, "x": 2 }]; + + let mut insert = Insert::new(Namespace::empty(), docs.iter().collect(), None); + let cmd = insert.build(&StreamDescription::new_testing()).unwrap(); + let serialized = insert.serialize_command(cmd).unwrap(); + + #[derive(Debug, Serialize, Deserialize)] + struct D { + x: i32, + + #[serde(rename = "_id")] + id: Bson, + } + + let docs: Documents = bson::from_slice(serialized.as_slice()).unwrap(); + + assert_eq!(docs.documents.len(), 2); + let docs = docs.documents; + + docs[0].id.as_object_id().unwrap(); + assert_eq!(docs[0].x, 1); + + assert_eq!(docs[1].id, Bson::Int32(1)); + assert_eq!(docs[1].x, 2); + + // ensure the _id was prepended to the document + let docs: Documents = bson::from_slice(serialized.as_slice()).unwrap(); + assert_eq!(docs.documents[0].iter().next().unwrap().0, "_id") +} + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn serialize_all_types() { + let binary = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::Generic, + }; + let date = DateTime::now(); + let regex = Regex { + pattern: "hello".to_string(), + options: "x".to_string(), + }; + let timestamp = Timestamp { + time: 123, + increment: 456, + }; + let code = Bson::JavaScriptCode("console.log(1)".to_string()); + let code_w_scope = JavaScriptCodeWithScope { + code: "console.log(a)".to_string(), + scope: doc! { "a": 1 }, + }; + let oid = ObjectId::new(); + let subdoc = doc! { "k": true, "b": { "hello": "world" } }; + + let decimal = { + let bytes = hex::decode("18000000136400D0070000000000000000000000003A3000").unwrap(); + let d = Document::from_reader(bytes.as_slice()).unwrap(); + d.get("d").unwrap().clone() + }; + + let docs = vec![doc! { + "x": 1_i32, + "y": 2_i64, + "s": "oke", + "array": [ true, "oke", { "12": 24 } ], + "bson": 1234.5, + "oid": oid, + "null": Bson::Null, + "subdoc": subdoc.clone(), + "b": true, + "d": 12.5, + "binary": binary.clone(), + "date": date, + "regex": regex.clone(), + "ts": timestamp, + "i": { "a": 300, "b": 12345 }, + "undefined": Bson::Undefined, + "code": code.clone(), + "code_w_scope": code_w_scope.clone(), + "decimal": decimal.clone(), + "symbol": Bson::Symbol("ok".to_string()), + "min_key": Bson::MinKey, + "max_key": Bson::MaxKey, + "_id": ObjectId::new(), + }]; + + let mut insert = Insert::new(Namespace::empty(), docs.iter().collect(), None); + let cmd = insert.build(&StreamDescription::new_testing()).unwrap(); + let serialized = insert.serialize_command(cmd).unwrap(); + let cmd: Documents = bson::from_slice(serialized.as_slice()).unwrap(); + + assert_eq!(cmd.documents, docs); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] diff --git a/src/operation/list_collections/mod.rs b/src/operation/list_collections/mod.rs index 3e2141a68..21cd2e440 100644 --- a/src/operation/list_collections/mod.rs +++ b/src/operation/list_collections/mod.rs @@ -52,6 +52,7 @@ where T: DeserializeOwned + Unpin + Send + Sync, { type O = CursorSpecification; + type Command = Document; type Response = CursorResponse; const NAME: &'static str = "listCollections"; diff --git a/src/operation/list_databases/mod.rs b/src/operation/list_databases/mod.rs index 356a59e11..e52a1733f 100644 --- a/src/operation/list_databases/mod.rs +++ b/src/operation/list_databases/mod.rs @@ -46,6 +46,7 @@ impl ListDatabases { impl Operation for ListDatabases { type O = Vec; + type Command = Document; type Response = CommandResponse; const NAME: &'static str = "listDatabases"; diff --git a/src/operation/mod.rs b/src/operation/mod.rs index b7f0fa237..22fe1bd8f 100644 --- a/src/operation/mod.rs +++ b/src/operation/mod.rs @@ -28,7 +28,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::{ bson::{self, Bson, Document}, bson_util, - client::ClusterTime, + client::{ClusterTime, HELLO_COMMAND_NAMES, REDACTED_COMMANDS}, cmap::{Command, RawCommandResponse, StreamDescription}, error::{ BulkWriteError, @@ -69,6 +69,8 @@ pub(crate) trait Operation { /// The output type of this operation. type O; + type Command: CommandBody; + /// The format of the command response from the server. type Response: Response; @@ -77,7 +79,11 @@ pub(crate) trait Operation { /// Returns the command that should be sent to the server as part of this operation. /// The operation may store some additional state that is required for handling the response. - fn build(&mut self, description: &StreamDescription) -> Result; + fn build(&mut self, description: &StreamDescription) -> Result>; + + fn serialize_command(&mut self, cmd: Command) -> Result> { + Ok(bson::to_vec(&cmd)?) + } /// Interprets the server response to the command. fn handle_response( @@ -127,6 +133,26 @@ pub(crate) trait Operation { } } +pub(crate) trait CommandBody: Serialize { + fn should_redact(&self) -> bool { + false + } +} + +impl CommandBody for Document { + fn should_redact(&self) -> bool { + self.contains_key("speculativeAuthenticate") + } +} + +impl Command { + pub(crate) fn should_redact(&self) -> bool { + let name = self.name.to_lowercase(); + REDACTED_COMMANDS.contains(name.as_str()) + || HELLO_COMMAND_NAMES.contains(name.as_str()) && self.body.should_redact() + } +} + /// Trait modeling the behavior of a command response to a server operation. pub(crate) trait Response: Sized { /// The command-specific portion of a command response. diff --git a/src/operation/run_command/mod.rs b/src/operation/run_command/mod.rs index a540ccc86..573a91c36 100644 --- a/src/operation/run_command/mod.rs +++ b/src/operation/run_command/mod.rs @@ -47,6 +47,7 @@ impl RunCommand { impl Operation for RunCommand { type O = Document; + type Command = Document; type Response = Response; // Since we can't actually specify a string statically here, we just put a descriptive string diff --git a/src/operation/update/mod.rs b/src/operation/update/mod.rs index ebe0bd961..a64473e23 100644 --- a/src/operation/update/mod.rs +++ b/src/operation/update/mod.rs @@ -59,6 +59,7 @@ impl Update { impl Operation for Update { type O = UpdateResult; + type Command = Document; type Response = CommandResponse>; const NAME: &'static str = "update"; diff --git a/src/sdam/description/topology/mod.rs b/src/sdam/description/topology/mod.rs index 3870aef47..bfbf2b08b 100644 --- a/src/sdam/description/topology/mod.rs +++ b/src/sdam/description/topology/mod.rs @@ -153,10 +153,10 @@ impl TopologyDescription { self.servers.get(address) } - pub(crate) fn update_command_with_read_pref( + pub(crate) fn update_command_with_read_pref( &self, server_type: ServerType, - command: &mut Command, + command: &mut Command, criteria: Option<&SelectionCriteria>, ) -> crate::error::Result<()> { match (self.topology_type, server_type) { @@ -192,9 +192,9 @@ impl TopologyDescription { } } - fn update_command_read_pref_for_mongos( + fn update_command_read_pref_for_mongos( &self, - command: &mut Command, + command: &mut Command, criteria: Option<&SelectionCriteria>, ) -> crate::error::Result<()> { match criteria { diff --git a/src/sdam/state/mod.rs b/src/sdam/state/mod.rs index b07961d47..ed2d67f99 100644 --- a/src/sdam/state/mod.rs +++ b/src/sdam/state/mod.rs @@ -385,10 +385,10 @@ impl Topology { } /// Updates the given `command` as needed based on the `critiera`. - pub(crate) async fn update_command_with_read_pref( + pub(crate) async fn update_command_with_read_pref( &self, server_address: &ServerAddress, - command: &mut Command, + command: &mut Command, criteria: Option<&SelectionCriteria>, ) -> Result<()> { self.state @@ -489,10 +489,10 @@ impl TopologyState { } /// Updates the given `command` as needed based on the `criteria`. - pub(crate) fn update_command_with_read_pref( + pub(crate) fn update_command_with_read_pref( &self, server_address: &ServerAddress, - command: &mut Command, + command: &mut Command, criteria: Option<&SelectionCriteria>, ) -> Result<()> { let server_type = self diff --git a/src/selection_criteria.rs b/src/selection_criteria.rs index bcf35cf8c..106ce84ae 100644 --- a/src/selection_criteria.rs +++ b/src/selection_criteria.rs @@ -1,13 +1,13 @@ -use std::{collections::HashMap, convert::TryInto, sync::Arc, time::Duration}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use derivative::Derivative; -use serde::{de::Error as SerdeError, Deserialize, Deserializer}; +use serde::{de::Error as SerdeError, Deserialize, Deserializer, Serialize}; use typed_builder::TypedBuilder; use crate::{ - bson::{doc, Bson, Document}, + bson::{doc}, bson_util, - error::{Error, ErrorKind, Result}, + error::{ErrorKind, Result}, options::ServerAddress, sdam::public::ServerInfo, }; @@ -133,12 +133,13 @@ pub enum ReadPreference { Nearest { options: ReadPreferenceOptions }, } + impl<'de> Deserialize<'de> for ReadPreference { fn deserialize(deserializer: D) -> std::result::Result where D: Deserializer<'de>, { - #[derive(Deserialize)] + #[derive(Serialize, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] struct ReadPreferenceHelper { mode: String, @@ -169,8 +170,50 @@ impl<'de> Deserialize<'de> for ReadPreference { } } +impl Serialize for ReadPreference { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + #[serde_with::skip_serializing_none] + #[derive(Serialize)] + #[serde(rename_all = "camelCase", deny_unknown_fields)] + struct ReadPreferenceHelper<'a> { + mode: &'static str, + #[serde(flatten)] + options: Option<&'a ReadPreferenceOptions>, + } + + let helper = match self { + ReadPreference::Primary => ReadPreferenceHelper { + mode: "primary", + options: None, + }, + ReadPreference::PrimaryPreferred { options } => ReadPreferenceHelper { + mode: "primaryPreferred", + options: Some(options), + }, + ReadPreference::Secondary { options } => ReadPreferenceHelper { + mode: "secondary", + options: Some(options), + }, + ReadPreference::SecondaryPreferred { options } => ReadPreferenceHelper { + mode: "secondaryPreferred", + options: Some(options), + }, + ReadPreference::Nearest { options } => ReadPreferenceHelper { + mode: "nearest", + options: Some(options), + }, + }; + + helper.serialize(serializer) + } +} + /// Specifies read preference options for non-primary read preferences. -#[derive(Clone, Debug, Default, Deserialize, PartialEq, TypedBuilder)] +#[serde_with::skip_serializing_none] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, TypedBuilder)] #[builder(field_defaults(default, setter(into)))] #[serde(rename_all = "camelCase")] #[non_exhaustive] @@ -188,7 +231,8 @@ pub struct ReadPreferenceOptions { #[serde( rename = "maxStalenessSeconds", default, - deserialize_with = "bson_util::deserialize_duration_option_from_u64_seconds" + deserialize_with = "bson_util::deserialize_duration_option_from_u64_seconds", + serialize_with = "bson_util::serialize_duration_option_as_int_secs", )] pub max_staleness: Option, @@ -203,7 +247,7 @@ pub struct ReadPreferenceOptions { /// Specifies hedging behavior for reads. /// /// See the [MongoDB docs](https://docs.mongodb.com/manual/core/read-preference-hedge-option/) for more details. -#[derive(Clone, Debug, Deserialize, PartialEq, TypedBuilder)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, TypedBuilder)] #[non_exhaustive] pub struct HedgedReadOptions { /// Whether or not to allow reads from a sharded cluster to be "hedged" across two replica @@ -272,66 +316,6 @@ impl ReadPreference { Ok(self) } - pub(crate) fn into_document(self) -> Result { - let (mode, tag_sets, max_staleness, hedge) = match self { - ReadPreference::Primary => ("primary", None, None, None), - ReadPreference::PrimaryPreferred { options } => ( - "primaryPreferred", - options.tag_sets, - options.max_staleness, - options.hedge, - ), - ReadPreference::Secondary { options } => ( - "secondary", - options.tag_sets, - options.max_staleness, - options.hedge, - ), - ReadPreference::SecondaryPreferred { options } => ( - "secondaryPreferred", - options.tag_sets, - options.max_staleness, - options.hedge, - ), - ReadPreference::Nearest { options } => ( - "nearest", - options.tag_sets, - options.max_staleness, - options.hedge, - ), - }; - - let mut doc = doc! { "mode": mode }; - - if let Some(max_stale) = max_staleness { - let s: i64 = max_stale.as_secs().try_into().map_err(|_| { - Error::from(ErrorKind::InvalidArgument { - message: format!( - "provided maxStalenessSeconds {:?} exceeds the range of i64 seconds", - max_stale - ), - }) - })?; - doc.insert("maxStalenessSeconds", s); - } - - if let Some(tag_sets) = tag_sets { - let tags: Vec = tag_sets - .into_iter() - .map(|tag_set| { - Bson::Document(tag_set.into_iter().map(|(k, v)| (k, v.into())).collect()) - }) - .collect(); - doc.insert("tags", tags); - } - - if let Some(hedge) = hedge { - doc.insert("hedge", doc! { "enabled": hedge.enabled }); - } - - Ok(doc) - } - #[cfg(test)] pub(crate) fn serialize_for_client_options( read_preference: &ReadPreference, @@ -399,7 +383,7 @@ mod test { .build(); let read_pref = ReadPreference::Secondary { options }; - let doc = read_pref.into_document().unwrap(); + let doc = bson::to_document(&read_pref).unwrap(); assert_eq!( doc, diff --git a/src/test/spec/crud_v1/insert_many.rs b/src/test/spec/crud_v1/insert_many.rs index ea3cad1a8..b7dd7bb11 100644 --- a/src/test/spec/crud_v1/insert_many.rs +++ b/src/test/spec/crud_v1/insert_many.rs @@ -58,8 +58,13 @@ async fn run_insert_many_test(test_file: TestFile) { assert_ne!(outcome.error, Some(true), "{}", test_case.description); result.inserted_ids } - Err(_) => { - assert!(outcome.error.unwrap_or(false), "{}", test_case.description); + Err(e) => { + assert!( + outcome.error.unwrap_or(false), + "{}: expected no error, got {:?}", + test_case.description, + e + ); Default::default() } }; diff --git a/src/test/spec/unified_runner/test_file.rs b/src/test/spec/unified_runner/test_file.rs index b0f7bf42c..d5350a488 100644 --- a/src/test/spec/unified_runner/test_file.rs +++ b/src/test/spec/unified_runner/test_file.rs @@ -7,7 +7,7 @@ use super::{Operation, TestEvent}; use crate::{ bson::{doc, Bson, Deserializer as BsonDeserializer, Document}, - client::options::{ServerApi, SessionOptions}, + client::options::{ServerApi, ServerApiVersion, SessionOptions}, concern::{Acknowledgment, ReadConcernLevel}, error::Error, options::{ @@ -145,8 +145,9 @@ pub struct Client { pub use_multiple_mongoses: Option, pub observe_events: Option>, pub ignore_command_monitoring_events: Option>, - pub observe_sensitive_commands: Option, #[serde(default)] + pub observe_sensitive_commands: Option, + #[serde(default, deserialize_with = "deserialize_server_api")] pub server_api: Option, } @@ -154,6 +155,28 @@ fn default_uri() -> String { DEFAULT_URI.clone() } +pub fn deserialize_server_api<'de, D>( + deserializer: D, +) -> std::result::Result, D::Error> +where + D: Deserializer<'de>, +{ + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase", deny_unknown_fields)] + struct ApiHelper { + version: ServerApiVersion, + strict: Option, + deprecation_errors: Option, + } + + let h = ApiHelper::deserialize(deserializer)?; + Ok(Some(ServerApi { + version: h.version, + strict: h.strict, + deprecation_errors: h.deprecation_errors, + })) +} + pub fn deserialize_uri_options_to_uri_string<'de, D>( deserializer: D, ) -> std::result::Result diff --git a/src/test/spec/v2_runner/test_event.rs b/src/test/spec/v2_runner/test_event.rs index f410a7f78..59729cd19 100644 --- a/src/test/spec/v2_runner/test_event.rs +++ b/src/test/spec/v2_runner/test_event.rs @@ -35,6 +35,7 @@ impl CommandStartedEvent { other => panic!("unknown session name: {}", other), } } + self.command.content_matches(&expected) } } From b2afcb1966680c00b37ed567f12622e346bfff78 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 29 Jul 2021 17:11:23 -0400 Subject: [PATCH 02/10] fix lint --- src/client/auth/mod.rs | 6 +++--- src/client/mod.rs | 2 -- src/operation/insert/test.rs | 12 ++++++------ 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/client/auth/mod.rs b/src/client/auth/mod.rs index b464f9963..5d9713b61 100644 --- a/src/client/auth/mod.rs +++ b/src/client/auth/mod.rs @@ -229,9 +229,9 @@ impl AuthMechanism { Ok(Some(ClientFirst::Scram(ScramVersion::Sha256, client_first))) } - Self::MongoDbX509 => Ok(Some(ClientFirst::X509( + Self::MongoDbX509 => Ok(Some(ClientFirst::X509(Box::new( x509::build_speculative_client_first(credential), - ))), + )))), Self::Plain => Ok(None), #[cfg(feature = "tokio-runtime")] AuthMechanism::MongoDbAws => Ok(None), @@ -469,7 +469,7 @@ impl Debug for Credential { /// Contains the first client message sent as part of the authentication handshake. pub(crate) enum ClientFirst { Scram(ScramVersion, scram::ClientFirst), - X509(Command), + X509(Box), } impl ClientFirst { diff --git a/src/client/mod.rs b/src/client/mod.rs index c29facee6..e8b8f1067 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -9,8 +9,6 @@ use bson::Bson; use derivative::Derivative; use std::time::Instant; -#[cfg(test)] -pub(crate) use self::executor::REDACTED_COMMANDS; #[cfg(test)] use crate::options::ServerAddress; use crate::{ diff --git a/src/operation/insert/test.rs b/src/operation/insert/test.rs index 3cd7a8c36..b56ac0d69 100644 --- a/src/operation/insert/test.rs +++ b/src/operation/insert/test.rs @@ -239,18 +239,18 @@ async fn serialize_all_types() { "bson": 1234.5, "oid": oid, "null": Bson::Null, - "subdoc": subdoc.clone(), + "subdoc": subdoc, "b": true, "d": 12.5, - "binary": binary.clone(), + "binary": binary, "date": date, - "regex": regex.clone(), + "regex": regex, "ts": timestamp, "i": { "a": 300, "b": 12345 }, "undefined": Bson::Undefined, - "code": code.clone(), - "code_w_scope": code_w_scope.clone(), - "decimal": decimal.clone(), + "code": code, + "code_w_scope": code_w_scope, + "decimal": decimal, "symbol": Bson::Symbol("ok".to_string()), "min_key": Bson::MinKey, "max_key": Bson::MaxKey, From 845a5d08481ee5cd86bf4253f8c07e92bc1ece8b Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 29 Jul 2021 18:17:44 -0400 Subject: [PATCH 03/10] various cleanup --- src/client/executor.rs | 2 +- src/operation/insert/mod.rs | 45 +++++++++++++++++++------------------ src/operation/mod.rs | 13 ++++++++--- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/client/executor.rs b/src/client/executor.rs index ba1c8963a..c48ccce8a 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -400,7 +400,6 @@ impl Client { bytes: serialized, }; - let start_time = Instant::now(); self.emit_command_event(|handler| { let command_body = if should_redact { Document::new() @@ -419,6 +418,7 @@ impl Client { handler.handle_command_started_event(command_started_event); }); + let start_time = Instant::now(); let command_result = match connection.send_raw_command(raw_cmd, request_id).await { Ok(response) => { match T::Response::deserialize_response(&response) { diff --git a/src/operation/insert/mod.rs b/src/operation/insert/mod.rs index d0d796849..49473d60f 100644 --- a/src/operation/insert/mod.rs +++ b/src/operation/insert/mod.rs @@ -25,7 +25,7 @@ pub(crate) struct Insert<'a, T> { ns: Namespace, documents: Vec<&'a T>, inserted_ids: Vec, - options: InsertManyOptions, + options: Option, } impl<'a, T> Insert<'a, T> { @@ -34,9 +34,6 @@ impl<'a, T> Insert<'a, T> { documents: Vec<&'a T>, options: Option, ) -> Self { - let mut options = - options.unwrap_or_else(|| InsertManyOptions::builder().ordered(true).build()); - options.ordered = Some(options.ordered.unwrap_or(true)); Self { ns, options, @@ -46,7 +43,10 @@ impl<'a, T> Insert<'a, T> { } fn is_ordered(&self) -> bool { - self.options.ordered.unwrap_or(true) + self.options + .as_ref() + .and_then(|o| o.ordered) + .unwrap_or(true) } } @@ -58,13 +58,6 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> { const NAME: &'static str = "insert"; fn build(&mut self, description: &StreamDescription) -> Result> { - if self.documents.is_empty() { - return Err(ErrorKind::InvalidArgument { - message: "must specify at least one document to insert".to_string(), - } - .into()); - } - let mut docs: Vec> = Vec::new(); let mut size = 0; @@ -79,14 +72,19 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> { Some(b) => b, None => { let oid = ObjectId::new(); - let new_len = doc.len() as i32 + 1 + 4 + 12; - doc.splice(0..4, new_len.to_le_bytes().iter().cloned()); - let mut new_doc = Vec::new(); - new_doc.write_u8(ElementType::ObjectId as u8)?; - new_doc.write_all(b"_id\0")?; - new_doc.extend(oid.bytes().iter()); - doc.splice(4..4, new_doc.into_iter()); + // write element to temporary buffer + let mut new_id = Vec::new(); + new_id.write_u8(ElementType::ObjectId as u8)?; + new_id.write_all(b"_id\0")?; + new_id.extend(oid.bytes().iter()); + + // insert element to beginning of existing doc after length + doc.splice(4..4, new_id.into_iter()); + + // update length of doc + let new_len = doc.len() as i32; + doc.splice(0..4, new_len.to_le_bytes().iter().cloned()); Bson::ObjectId(oid) } @@ -112,13 +110,16 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> { .into()); } + let mut options = self.options.clone().unwrap_or_default(); + options.ordered = Some(self.is_ordered()); + let body = InsertCommand { insert: self.ns.coll.clone(), documents: DocumentArraySpec { documents: docs, length: size as i32, }, - options: self.options.clone(), + options, }; Ok(Command::new("insert".to_string(), self.ns.db.clone(), body)) @@ -161,7 +162,7 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> { // update length of original doc let final_length = serialized.len() as i32; - (&mut serialized[0..4]).write_all(&final_length.to_le_bytes())?; + serialized.splice(0..4, final_length.to_le_bytes().iter().cloned()); Ok(serialized) } @@ -211,7 +212,7 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> { } fn write_concern(&self) -> Option<&WriteConcern> { - self.options.write_concern.as_ref() + self.options.as_ref().and_then(|o| o.write_concern.as_ref()) } fn retryability(&self) -> Retryability { diff --git a/src/operation/mod.rs b/src/operation/mod.rs index 22fe1bd8f..fd383dbdf 100644 --- a/src/operation/mod.rs +++ b/src/operation/mod.rs @@ -69,6 +69,7 @@ pub(crate) trait Operation { /// The output type of this operation. type O; + /// The format of the command body constructed in `build`. type Command: CommandBody; /// The format of the command response from the server. @@ -81,6 +82,8 @@ pub(crate) trait Operation { /// The operation may store some additional state that is required for handling the response. fn build(&mut self, description: &StreamDescription) -> Result>; + /// Perform custom serialization of the built command. + /// By default, this will just call through to the `Serialize` implementation of the command. fn serialize_command(&mut self, cmd: Command) -> Result> { Ok(bson::to_vec(&cmd)?) } @@ -141,15 +144,19 @@ pub(crate) trait CommandBody: Serialize { impl CommandBody for Document { fn should_redact(&self) -> bool { - self.contains_key("speculativeAuthenticate") + if let Some(command_name) = bson_util::first_key(self) { + HELLO_COMMAND_NAMES.contains(command_name.to_lowercase().as_str()) + && self.contains_key("speculativeAuthenticate") + } else { + false + } } } impl Command { pub(crate) fn should_redact(&self) -> bool { let name = self.name.to_lowercase(); - REDACTED_COMMANDS.contains(name.as_str()) - || HELLO_COMMAND_NAMES.contains(name.as_str()) && self.body.should_redact() + REDACTED_COMMANDS.contains(name.as_str()) || self.body.should_redact() } } From e2db5d78b88de5bfc2ba00d5ca2d940ba204441c Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Mon, 2 Aug 2021 16:20:33 -0400 Subject: [PATCH 04/10] support setting recovery token in `Command` --- src/cmap/conn/command.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cmap/conn/command.rs b/src/cmap/conn/command.rs index 711b33c3a..0deb6de93 100644 --- a/src/cmap/conn/command.rs +++ b/src/cmap/conn/command.rs @@ -51,6 +51,8 @@ pub(crate) struct Command { autocommit: Option, read_concern: Option, + + recovery_token: Option, } impl Command { @@ -67,6 +69,7 @@ impl Command { start_transaction: None, autocommit: None, read_concern: None, + recovery_token: None, } } @@ -79,7 +82,7 @@ impl Command { } pub(crate) fn set_recovery_token(&mut self, recovery_token: &Document) { - self.body.insert("recoveryToken", recovery_token); + self.recovery_token = Some(recovery_token.clone()); } pub(crate) fn set_txn_number(&mut self, txn_number: i64) { From d7dab28fb5042675a05d75788309839e5e14d383 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 29 Jul 2021 23:07:31 -0400 Subject: [PATCH 05/10] fix Message::with_command, rename deprecation_errors --- src/client/options/mod.rs | 2 +- src/cmap/conn/mod.rs | 2 +- src/cmap/conn/wire/message.rs | 26 ++++++++++++-------------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/client/options/mod.rs b/src/client/options/mod.rs index 326cf6c68..f60d707d1 100644 --- a/src/client/options/mod.rs +++ b/src/client/options/mod.rs @@ -339,7 +339,6 @@ impl<'de> Deserialize<'de> for ServerApiVersion { #[serde_with::skip_serializing_none] #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, TypedBuilder)] #[builder(field_defaults(setter(into)))] -#[serde(rename_all = "camelCase")] #[non_exhaustive] pub struct ServerApi { /// The declared API version. @@ -356,6 +355,7 @@ pub struct ServerApi { /// deprecated from the declared API version is used. /// Note that at the time of this writing, no deprecations in version 1 exist. #[builder(default)] + #[serde(rename = "apiDeprecationErrors")] pub deprecation_errors: Option, } diff --git a/src/cmap/conn/mod.rs b/src/cmap/conn/mod.rs index 794d8650e..4b8413dfc 100644 --- a/src/cmap/conn/mod.rs +++ b/src/cmap/conn/mod.rs @@ -267,7 +267,7 @@ impl Connection { command: RawCommand, request_id: impl Into>, ) -> Result { - let message = Message::with_raw_command(command, request_id.into())?; + let message = Message::with_raw_command(command, request_id.into()); self.send_message(message).await } diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index bf0ceeb24..4e2d9edc3 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -33,31 +33,29 @@ impl Message { /// Creates a `Message` from a given `Command`. /// /// Note that `response_to` will need to be set manually. - pub(crate) fn with_command(mut command: Command, request_id: Option) -> Result { - command.body.insert("$db", command.target_db); - - let mut bytes = Vec::new(); - command.body.to_writer(&mut bytes)?; - Ok(Self { - response_to: 0, - flags: MessageFlags::empty(), - sections: vec![MessageSection::Document(bytes)], - checksum: None, + pub(crate) fn with_command(command: Command, request_id: Option) -> Result { + let bytes = bson::to_vec(&command)?; + Ok(Self::with_raw_command( + RawCommand { + bytes, + target_db: command.target_db, + name: command.name, + }, request_id, - }) + )) } /// Creates a `Message` from a given `Command`. /// /// Note that `response_to` will need to be set manually. - pub(crate) fn with_raw_command(command: RawCommand, request_id: Option) -> Result { - Ok(Self { + pub(crate) fn with_raw_command(command: RawCommand, request_id: Option) -> Self { + Self { response_to: 0, flags: MessageFlags::empty(), sections: vec![MessageSection::Document(command.bytes)], checksum: None, request_id, - }) + } } /// Gets the first document contained in this Message. From fd13cec015d8e53abd3e86035f43376d45c4ed41 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Tue, 3 Aug 2021 15:31:30 -0400 Subject: [PATCH 06/10] fix typo --- src/cmap/conn/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cmap/conn/mod.rs b/src/cmap/conn/mod.rs index 4b8413dfc..1fbfeac93 100644 --- a/src/cmap/conn/mod.rs +++ b/src/cmap/conn/mod.rs @@ -260,7 +260,7 @@ impl Connection { /// server. /// /// An `Ok(...)` result simply means the server received the command and that the driver - /// driver received the response; it does not imply anything about the success of the command + /// received the response; it does not imply anything about the success of the command /// itself. pub(crate) async fn send_raw_command( &mut self, From fd8a2f07c941cfd018141329783ea1aec3ca0762 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Tue, 3 Aug 2021 16:07:53 -0400 Subject: [PATCH 07/10] stop returning result from command setters --- src/client/executor.rs | 6 +++--- src/cmap/conn/command.rs | 9 +++------ src/cmap/establish/test.rs | 8 +++----- src/cmap/test/integration.rs | 2 +- src/sdam/description/topology/mod.rs | 8 ++++---- src/sdam/state/mod.rs | 4 ++-- 6 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/client/executor.rs b/src/client/executor.rs index c48ccce8a..f6b8c16e4 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -303,7 +303,7 @@ impl Client { self.inner .topology .update_command_with_read_pref(connection.address(), &mut cmd, op.selection_criteria()) - .await?; + .await; match session { Some(ref mut session) if op.supports_sessions() && op.is_acknowledged() => { @@ -330,13 +330,13 @@ impl Client { labels, )); } - cmd.set_snapshot_read_concern(session)?; + cmd.set_snapshot_read_concern(session); } match session.transaction.state { TransactionState::Starting => { cmd.set_start_transaction(); cmd.set_autocommit(); - cmd.set_txn_read_concern(*session)?; + cmd.set_txn_read_concern(*session); if is_sharded { session.pin_mongos(connection.address().clone()); } diff --git a/src/cmap/conn/command.rs b/src/cmap/conn/command.rs index 0deb6de93..c6f4c83d3 100644 --- a/src/cmap/conn/command.rs +++ b/src/cmap/conn/command.rs @@ -93,9 +93,8 @@ impl Command { self.server_api = Some(server_api.clone()); } - pub(crate) fn set_read_preference(&mut self, read_preference: ReadPreference) -> Result<()> { + pub(crate) fn set_read_preference(&mut self, read_preference: ReadPreference) { self.read_preference = Some(read_preference); - Ok(()) } pub(crate) fn set_start_transaction(&mut self) { @@ -106,20 +105,18 @@ impl Command { self.autocommit = Some(false); } - pub(crate) fn set_txn_read_concern(&mut self, session: &ClientSession) -> Result<()> { + pub(crate) fn set_txn_read_concern(&mut self, session: &ClientSession) { if let Some(ref options) = session.transaction.options { if let Some(ref read_concern) = options.read_concern { self.read_concern = Some(read_concern.clone()); } } - Ok(()) } - pub(crate) fn set_snapshot_read_concern(&mut self, session: &ClientSession) -> Result<()> { + pub(crate) fn set_snapshot_read_concern(&mut self, session: &ClientSession) { let mut concern = ReadConcern::snapshot(); concern.at_cluster_time = session.snapshot_time; self.read_concern = Some(concern); - Ok(()) } } diff --git a/src/cmap/establish/test.rs b/src/cmap/establish/test.rs index 544caae34..cd209ab89 100644 --- a/src/cmap/establish/test.rs +++ b/src/cmap/establish/test.rs @@ -65,11 +65,9 @@ async fn speculative_auth_test( authorized_db_name.into(), doc! { "find": "foo", "limit": 1 }, ); - command - .set_read_preference(ReadPreference::PrimaryPreferred { - options: Default::default(), - }) - .unwrap(); + command.set_read_preference(ReadPreference::PrimaryPreferred { + options: Default::default(), + }); let response = conn.send_command(command, None).await.unwrap(); let doc_response = response.into_document_response().unwrap(); diff --git a/src/cmap/test/integration.rs b/src/cmap/test/integration.rs index 680aff1ac..339ac1ac8 100644 --- a/src/cmap/test/integration.rs +++ b/src/cmap/test/integration.rs @@ -49,7 +49,7 @@ async fn acquire_connection_and_send_command() { options: Default::default(), }; let mut cmd = Command::new("listDatabases".to_string(), "admin".to_string(), body); - cmd.set_read_preference(read_pref).unwrap(); + cmd.set_read_preference(read_pref); if let Some(server_api) = client_options.server_api.as_ref() { cmd.set_server_api(server_api); } diff --git a/src/sdam/description/topology/mod.rs b/src/sdam/description/topology/mod.rs index bfbf2b08b..aaf90e2de 100644 --- a/src/sdam/description/topology/mod.rs +++ b/src/sdam/description/topology/mod.rs @@ -158,13 +158,13 @@ impl TopologyDescription { server_type: ServerType, command: &mut Command, criteria: Option<&SelectionCriteria>, - ) -> crate::error::Result<()> { + ) { match (self.topology_type, server_type) { (TopologyType::Sharded, ServerType::Mongos) | (TopologyType::Single, ServerType::Mongos) => { self.update_command_read_pref_for_mongos(command, criteria) } - (TopologyType::Single, ServerType::Standalone) => Ok(()), + (TopologyType::Single, ServerType::Standalone) => {} (TopologyType::Single, _) => { let specified_read_pref = criteria .and_then(SelectionCriteria::as_read_pref) @@ -196,7 +196,7 @@ impl TopologyDescription { &self, command: &mut Command, criteria: Option<&SelectionCriteria>, - ) -> crate::error::Result<()> { + ) { match criteria { Some(SelectionCriteria::ReadPreference(ReadPreference::Secondary { ref options })) => { command.set_read_preference(ReadPreference::Secondary { @@ -219,7 +219,7 @@ impl TopologyDescription { options: options.clone(), }) } - _ => Ok(()), + _ => {} } } diff --git a/src/sdam/state/mod.rs b/src/sdam/state/mod.rs index ed2d67f99..c35af1ed4 100644 --- a/src/sdam/state/mod.rs +++ b/src/sdam/state/mod.rs @@ -390,7 +390,7 @@ impl Topology { server_address: &ServerAddress, command: &mut Command, criteria: Option<&SelectionCriteria>, - ) -> Result<()> { + ) { self.state .read() .await @@ -494,7 +494,7 @@ impl TopologyState { server_address: &ServerAddress, command: &mut Command, criteria: Option<&SelectionCriteria>, - ) -> Result<()> { + ) { let server_type = self .description .get_server_description(server_address) From a3c17c8149391a2299c41f8f3d0627129f66f087 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Wed, 4 Aug 2021 18:30:07 -0400 Subject: [PATCH 08/10] use `Command` instead of `Command` where possible --- src/operation/abort_transaction/mod.rs | 2 +- src/operation/aggregate/mod.rs | 2 +- src/operation/commit_transaction/mod.rs | 2 +- src/operation/count/mod.rs | 2 +- src/operation/count_documents/mod.rs | 2 +- src/operation/create/mod.rs | 2 +- src/operation/delete/mod.rs | 2 +- src/operation/distinct/mod.rs | 2 +- src/operation/drop_collection/mod.rs | 2 +- src/operation/drop_database/mod.rs | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/operation/abort_transaction/mod.rs b/src/operation/abort_transaction/mod.rs index bad53b609..79be50040 100644 --- a/src/operation/abort_transaction/mod.rs +++ b/src/operation/abort_transaction/mod.rs @@ -35,7 +35,7 @@ impl Operation for AbortTransaction { const NAME: &'static str = "abortTransaction"; - fn build(&mut self, _description: &StreamDescription) -> Result> { + fn build(&mut self, _description: &StreamDescription) -> Result { let mut body = doc! { Self::NAME: 1, }; diff --git a/src/operation/aggregate/mod.rs b/src/operation/aggregate/mod.rs index 3e79d30e4..a5da15d57 100644 --- a/src/operation/aggregate/mod.rs +++ b/src/operation/aggregate/mod.rs @@ -47,7 +47,7 @@ impl Operation for Aggregate { const NAME: &'static str = "aggregate"; - fn build(&mut self, _description: &StreamDescription) -> Result> { + fn build(&mut self, _description: &StreamDescription) -> Result { let mut body = doc! { Self::NAME: self.target.to_bson(), "pipeline": bson_util::to_bson_array(&self.pipeline), diff --git a/src/operation/commit_transaction/mod.rs b/src/operation/commit_transaction/mod.rs index 908189791..a5bd1c659 100644 --- a/src/operation/commit_transaction/mod.rs +++ b/src/operation/commit_transaction/mod.rs @@ -28,7 +28,7 @@ impl Operation for CommitTransaction { const NAME: &'static str = "commitTransaction"; - fn build(&mut self, _description: &StreamDescription) -> Result> { + fn build(&mut self, _description: &StreamDescription) -> Result { let mut body = doc! { Self::NAME: 1, }; diff --git a/src/operation/count/mod.rs b/src/operation/count/mod.rs index 799bae20a..f1a0e6ca2 100644 --- a/src/operation/count/mod.rs +++ b/src/operation/count/mod.rs @@ -46,7 +46,7 @@ impl Operation for Count { const NAME: &'static str = "count"; - fn build(&mut self, description: &StreamDescription) -> Result> { + fn build(&mut self, description: &StreamDescription) -> Result { let mut body = match description.max_wire_version { Some(v) if v >= SERVER_4_9_0_WIRE_VERSION => { doc! { diff --git a/src/operation/count_documents/mod.rs b/src/operation/count_documents/mod.rs index ec94d3667..63b98a985 100644 --- a/src/operation/count_documents/mod.rs +++ b/src/operation/count_documents/mod.rs @@ -82,7 +82,7 @@ impl Operation for CountDocuments { const NAME: &'static str = Aggregate::NAME; - fn build(&mut self, description: &StreamDescription) -> Result> { + fn build(&mut self, description: &StreamDescription) -> Result { self.aggregate.build(description) } diff --git a/src/operation/create/mod.rs b/src/operation/create/mod.rs index 820dce6e9..c17928505 100644 --- a/src/operation/create/mod.rs +++ b/src/operation/create/mod.rs @@ -44,7 +44,7 @@ impl Operation for Create { const NAME: &'static str = "create"; - fn build(&mut self, _description: &StreamDescription) -> Result> { + fn build(&mut self, _description: &StreamDescription) -> Result { let mut body = doc! { Self::NAME: self.ns.coll.clone(), }; diff --git a/src/operation/delete/mod.rs b/src/operation/delete/mod.rs index 8e42570f1..4408f3098 100644 --- a/src/operation/delete/mod.rs +++ b/src/operation/delete/mod.rs @@ -62,7 +62,7 @@ impl Operation for Delete { const NAME: &'static str = "delete"; - fn build(&mut self, _description: &StreamDescription) -> Result> { + fn build(&mut self, _description: &StreamDescription) -> Result { let mut delete = doc! { "q": self.filter.clone(), "limit": self.limit, diff --git a/src/operation/distinct/mod.rs b/src/operation/distinct/mod.rs index 4ba56a768..a48398ab9 100644 --- a/src/operation/distinct/mod.rs +++ b/src/operation/distinct/mod.rs @@ -57,7 +57,7 @@ impl Operation for Distinct { const NAME: &'static str = "distinct"; - fn build(&mut self, _description: &StreamDescription) -> Result> { + fn build(&mut self, _description: &StreamDescription) -> Result { let mut body: Document = doc! { Self::NAME: self.ns.coll.clone(), "key": self.field_name.clone(), diff --git a/src/operation/drop_collection/mod.rs b/src/operation/drop_collection/mod.rs index ff22c693b..3d734da11 100644 --- a/src/operation/drop_collection/mod.rs +++ b/src/operation/drop_collection/mod.rs @@ -44,7 +44,7 @@ impl Operation for DropCollection { const NAME: &'static str = "drop"; - fn build(&mut self, _description: &StreamDescription) -> Result> { + fn build(&mut self, _description: &StreamDescription) -> Result { let mut body = doc! { Self::NAME: self.ns.coll.clone(), }; diff --git a/src/operation/drop_database/mod.rs b/src/operation/drop_database/mod.rs index a44dc9727..6341d1711 100644 --- a/src/operation/drop_database/mod.rs +++ b/src/operation/drop_database/mod.rs @@ -37,7 +37,7 @@ impl Operation for DropDatabase { const NAME: &'static str = "dropDatabase"; - fn build(&mut self, _description: &StreamDescription) -> Result> { + fn build(&mut self, _description: &StreamDescription) -> Result { let mut body = doc! { Self::NAME: 1, }; From 8a1a93a8a6077ef14d6571c81e887c883b281be7 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Wed, 4 Aug 2021 18:33:01 -0400 Subject: [PATCH 09/10] add TODO --- src/operation/insert/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operation/insert/mod.rs b/src/operation/insert/mod.rs index 49473d60f..76eb5cdc5 100644 --- a/src/operation/insert/mod.rs +++ b/src/operation/insert/mod.rs @@ -71,6 +71,7 @@ impl<'a, T: Serialize> Operation for Insert<'a, T> { let id = match bson_util::raw_get(doc.as_slice(), "_id")? { Some(b) => b, None => { + // TODO: RUST-924 Use raw document API here instead. let oid = ObjectId::new(); // write element to temporary buffer From 15db10dcb2505a6783713173c10f1774653fa1cc Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 5 Aug 2021 14:14:44 -0400 Subject: [PATCH 10/10] clearer name of deserialize function --- src/test/spec/unified_runner/test_file.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/spec/unified_runner/test_file.rs b/src/test/spec/unified_runner/test_file.rs index d5350a488..1795f63b9 100644 --- a/src/test/spec/unified_runner/test_file.rs +++ b/src/test/spec/unified_runner/test_file.rs @@ -147,7 +147,7 @@ pub struct Client { pub ignore_command_monitoring_events: Option>, #[serde(default)] pub observe_sensitive_commands: Option, - #[serde(default, deserialize_with = "deserialize_server_api")] + #[serde(default, deserialize_with = "deserialize_server_api_test_format")] pub server_api: Option, } @@ -155,7 +155,7 @@ fn default_uri() -> String { DEFAULT_URI.clone() } -pub fn deserialize_server_api<'de, D>( +pub fn deserialize_server_api_test_format<'de, D>( deserializer: D, ) -> std::result::Result, D::Error> where