diff --git a/.evergreen/config.yml b/.evergreen/config.yml index c90472dde..c653b8cc7 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -1182,9 +1182,9 @@ axes: - id: "extra-rust-versions" values: - id: "min" - display_name: "1.47 (minimum supported version)" + display_name: "1.48 (minimum supported version)" variables: - RUST_VERSION: "1.47.0" + RUST_VERSION: "1.48.0" - id: "nightly" display_name: "nightly" variables: diff --git a/Cargo.toml b/Cargo.toml index da31288ab..bdba14823 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ bson-uuid-0_8 = ["bson/uuid-0_8"] async-trait = "0.1.42" base64 = "0.13.0" bitflags = "1.1.0" -bson = "2.0.0-beta.2" +bson = { git = "https://github.com/mongodb/bson-rust" } chrono = "0.4.7" derivative = "2.1.1" futures-core = "0.3.14" diff --git a/README.md b/README.md index f1b3f408e..a5dd14125 100644 --- a/README.md +++ b/README.md @@ -27,13 +27,9 @@ This repository contains the officially supported MongoDB Rust driver, a client ## Installation ### Requirements -- Rust 1.47+ +- Rust 1.48+ - MongoDB 3.6+ -**Note**: A bug affecting Rust 1.46-1.47 may cause out-of-memory errors when compiling an application that uses the 1.1 -version of the driver with a framework like actix-web. Upgrading Rust to 1.48+ or the driver to 1.2+ fixes this -issue. For more information, see https://github.com/rust-lang/rust/issues/75992. - ### Importing The driver is available on [crates.io](https://crates.io/crates/mongodb). To use the driver in your application, simply add it to your project's `Cargo.toml`. ```toml diff --git a/src/bson_util/async_encoding.rs b/src/bson_util/async_encoding.rs deleted file mode 100644 index 61b3def2b..000000000 --- a/src/bson_util/async_encoding.rs +++ /dev/null @@ -1,38 +0,0 @@ -use futures_io::{AsyncRead, AsyncWrite}; -use futures_util::{AsyncReadExt, AsyncWriteExt}; - -use crate::{ - bson::Document, - error::Result, - runtime::{AsyncLittleEndianRead, AsyncLittleEndianWrite}, -}; - -pub(crate) async fn decode_document( - reader: &mut R, -) -> Result { - let length = reader.read_i32().await?; - - let mut bytes = Vec::new(); - bytes.write_i32(length).await?; - - reader - .take(length as u64 - 4) - .read_to_end(&mut bytes) - .await?; - - let document = Document::from_reader(&mut bytes.as_slice())?; - Ok(document) -} - -pub(crate) async fn encode_document( - writer: &mut W, - document: &Document, -) -> Result<()> { - let mut bytes = Vec::new(); - - document.to_writer(&mut bytes)?; - - writer.write_all(&bytes).await?; - - Ok(()) -} diff --git a/src/bson_util/mod.rs b/src/bson_util/mod.rs index e0c371faf..705db73b0 100644 --- a/src/bson_util/mod.rs +++ b/src/bson_util/mod.rs @@ -1,12 +1,11 @@ -pub(crate) mod async_encoding; - -use std::{convert::TryFrom, time::Duration}; +use std::{convert::TryFrom, io::Read, time::Duration}; use serde::{de::Error, ser, Deserialize, Deserializer, Serialize, Serializer}; use crate::{ bson::{doc, Binary, Bson, Document, JavaScriptCodeWithScope, Regex}, error::{ErrorKind, Result}, + runtime::{SyncLittleEndianRead, SyncLittleEndianWrite}, }; /// Coerce numeric types into an `i64` if it would be lossless to do so. If this Bson is not numeric @@ -289,6 +288,26 @@ fn num_decimal_digits(n: usize) -> u64 { digits } +/// Read a document's raw BSON bytes from the provided reader. +pub(crate) fn read_document_bytes(mut reader: R) -> Result> { + let length = reader.read_i32()?; + + let mut bytes = Vec::with_capacity(length as usize); + bytes.write_i32(length)?; + + reader.take(length as u64 - 4).read_to_end(&mut bytes)?; + + Ok(bytes) +} + +/// Serialize the document to raw BSON and return a vec containing the bytes. +#[cfg(test)] +pub(crate) fn document_to_vec(doc: Document) -> Result> { + let mut v = Vec::new(); + doc.to_writer(&mut v)?; + Ok(v) +} + #[cfg(test)] mod test { use crate::bson::{ diff --git a/src/client/auth/aws.rs b/src/client/auth/aws.rs index f3d427e62..5ade910ce 100644 --- a/src/client/auth/aws.rs +++ b/src/client/auth/aws.rs @@ -61,7 +61,8 @@ pub(super) async fn authenticate_stream( let server_first_response = conn.send_command(client_first, None).await?; - let server_first = ServerFirst::parse(server_first_response.raw_response)?; + let server_first = + ServerFirst::parse(server_first_response.auth_response_body("MONGODB-AWS")?)?; server_first.validate(&nonce)?; let aws_credential = AwsCredential::get(credential, http_client).await?; @@ -96,7 +97,10 @@ pub(super) async fn authenticate_stream( let client_second = sasl_continue.into_command(); let server_second_response = conn.send_command(client_second, None).await?; - let server_second = SaslResponse::parse("MONGODB-AWS", server_second_response.raw_response)?; + let server_second = SaslResponse::parse( + "MONGODB-AWS", + server_second_response.auth_response_body("MONGODB-AWS")?, + )?; if server_second.conversation_id != server_first.conversation_id { return Err(Error::invalid_authentication_response("MONGODB-AWS")); @@ -373,12 +377,10 @@ impl ServerFirst { done, } = SaslResponse::parse("MONGODB-AWS", response)?; - let payload_document = Document::from_reader(&mut payload.as_slice())?; - let ServerFirstPayload { server_nonce, sts_host, - } = bson::from_bson(Bson::Document(payload_document)) + } = bson::from_slice(payload.as_slice()) .map_err(|_| Error::invalid_authentication_response("MONGODB-AWS"))?; Ok(Self { diff --git a/src/client/auth/plain.rs b/src/client/auth/plain.rs index 3dd52503c..081659bbd 100644 --- a/src/client/auth/plain.rs +++ b/src/client/auth/plain.rs @@ -36,7 +36,7 @@ pub(crate) async fn authenticate_stream( .into_command(); let response = conn.send_command(sasl_start, None).await?; - let sasl_response = SaslResponse::parse("PLAIN", response.raw_response)?; + let sasl_response = SaslResponse::parse("PLAIN", response.auth_response_body("PLAIN")?)?; if !sasl_response.done { return Err(Error::invalid_authentication_response("PLAIN")); diff --git a/src/client/auth/scram.rs b/src/client/auth/scram.rs index f2cefa50f..709f2d9b1 100644 --- a/src/client/auth/scram.rs +++ b/src/client/auth/scram.rs @@ -154,7 +154,7 @@ impl ScramVersion { Ok(FirstRound { client_first, - server_first: server_first.raw_response, + server_first: server_first.auth_response_body("SCRAM")?, }) } @@ -215,7 +215,7 @@ impl ScramVersion { let command = client_final.to_command(); let server_final_response = conn.send_command(command, None).await?; - let server_final = ServerFinal::parse(server_final_response.raw_response)?; + let server_final = ServerFinal::parse(server_final_response.auth_response_body("SCRAM")?)?; server_final.validate(salted_password.as_slice(), &client_final, self)?; if !server_final.done { @@ -231,9 +231,10 @@ impl ScramVersion { let command = noop.into_command(); let server_noop_response = conn.send_command(command, None).await?; + let server_noop_response_document: Document = + server_noop_response.auth_response_body("SCRAM")?; - if server_noop_response - .raw_response + if server_noop_response_document .get("conversationId") .map(|id| id == server_final.conversation_id()) != Some(true) @@ -244,8 +245,7 @@ impl ScramVersion { )); }; - if !server_noop_response - .raw_response + if !server_noop_response_document .get_bool("done") .unwrap_or(false) { diff --git a/src/client/auth/x509.rs b/src/client/auth/x509.rs index c4df52804..69ee13d70 100644 --- a/src/client/auth/x509.rs +++ b/src/client/auth/x509.rs @@ -1,7 +1,7 @@ use crate::{ bson::{doc, Document}, client::options::ServerApi, - cmap::{Command, CommandResponse, Connection}, + cmap::{Command, Connection, RawCommandResponse}, error::{Error, Result}, options::Credential, }; @@ -38,7 +38,7 @@ pub(crate) async fn send_client_first( conn: &mut Connection, credential: &Credential, server_api: Option<&ServerApi>, -) -> Result { +) -> Result { let command = build_client_first(credential, server_api); conn.send_command(command, None).await @@ -53,11 +53,9 @@ pub(super) async fn authenticate_stream( ) -> Result<()> { let server_response = match server_first.into() { Some(server_first) => server_first, - None => { - send_client_first(conn, credential, server_api) - .await? - .raw_response - } + None => send_client_first(conn, credential, server_api) + .await? + .auth_response_body("MONGODB-X509")?, }; if server_response.get_str("dbname") != Ok("$external") { diff --git a/src/client/executor.rs b/src/client/executor.rs index a9abc230e..0ba0d6e4b 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -1,13 +1,12 @@ -use super::{session::TransactionState, Client, ClientSession}; - -use std::{collections::HashSet, sync::Arc}; - +use bson::doc; use lazy_static::lazy_static; -use std::time::Instant; +use std::{collections::HashSet, sync::Arc, time::Instant}; + +use super::{session::TransactionState, Client, ClientSession}; use crate::{ bson::Document, - cmap::Connection, + cmap::{Connection, RawCommandResponse}, error::{ Error, ErrorKind, @@ -17,7 +16,15 @@ use crate::{ UNKNOWN_TRANSACTION_COMMIT_RESULT, }, event::command::{CommandFailedEvent, CommandStartedEvent, CommandSucceededEvent}, - operation::{AbortTransaction, CommitTransaction, Operation, Retryability}, + operation::{ + AbortTransaction, + CommandErrorBody, + CommandResponse, + CommitTransaction, + Operation, + Response, + Retryability, + }, options::SelectionCriteria, sdam::{HandshakePhase, SelectedServer, SessionSupportStatus, TransactionSupportStatus}, selection_criteria::ReadPreference, @@ -374,27 +381,72 @@ impl Client { let start_time = Instant::now(); let cmd_name = cmd.name.clone(); - let response_result = match connection.send_command(cmd, request_id).await { + let command_result = match connection.send_command(cmd, request_id).await { Ok(response) => { - if let Some(cluster_time) = response.cluster_time() { - self.inner.topology.advance_cluster_time(cluster_time).await; - if let Some(ref mut session) = session { - session.advance_cluster_time(cluster_time) + match T::Response::deserialize_response(&response) { + Ok(r) => { + self.update_cluster_time(&r, session).await; + if r.is_success() { + Ok(CommandResult { + raw: response, + deserialized: r.into_body(), + }) + } else { + // if command was ok: 0, try to deserialize the command error. + // if that fails, return a generic error. + Err(response + .body::() + .map(|error_response| error_response.into()) + .unwrap_or_else(|e| { + Error::from(ErrorKind::InvalidResponse { + message: format!( + "error deserializing command error: {}", + e + ), + }) + })) + } + } + Err(deserialize_error) => { + // if we failed to deserialize the whole response, try deserializing + // a generic command response without the operation's body. + match response.body::>>() { + Ok(error_response) => { + self.update_cluster_time(&error_response, session).await; + match error_response.body { + // if the response was ok: 0, return the command error. + Some(command_error_response) + if !error_response.is_success() => + { + Err(command_error_response.into()) + } + // if the response was ok: 0 but we couldnt deserialize the + // command error, + // return a generic error indicating so. + None if !error_response.is_success() => { + Err(Error::from(ErrorKind::InvalidResponse { + message: "got command error but failed to deserialize \ + response" + .to_string(), + })) + } + // for ok: 1 just return the original deserialization error. + _ => Err(deserialize_error), + } + } + // We failed to deserialize even that, so just return the original + // deserialization error. + Err(_) => Err(deserialize_error), + } } } - if let (Some(timestamp), Some(session)) = - (response.snapshot_time(), session.as_mut()) - { - session.snapshot_time = Some(*timestamp); - } - response.validate().map(|_| response) } - err => err, + Err(err) => Err(err), }; let duration = start_time.elapsed(); - match response_result { + match command_result { Err(mut err) => { self.emit_command_event(|handler| { let command_failed_event = CommandFailedEvent { @@ -424,7 +476,10 @@ impl Client { let reply = if should_redact { Document::new() } else { - response.raw_response.clone() + response + .raw + .body() + .unwrap_or_else(|_| doc! { "error": "failed to deserialize" }) }; let command_succeeded_event = CommandSucceededEvent { @@ -437,7 +492,7 @@ impl Client { handler.handle_command_succeeded_event(command_succeeded_event); }); - match op.handle_response(response, connection.stream_description()?) { + match op.handle_response(response.deserialized, connection.stream_description()?) { Ok(response) => Ok(response), Err(mut err) => { err.add_labels(Some(connection), session, Some(retryability))?; @@ -532,6 +587,25 @@ impl Client { } Ok(Retryability::None) } + + async fn update_cluster_time( + &self, + command_response: &T, + session: &mut Option<&mut ClientSession>, + ) { + if let Some(cluster_time) = command_response.cluster_time() { + self.inner.topology.advance_cluster_time(cluster_time).await; + if let Some(ref mut session) = session { + session.advance_cluster_time(cluster_time) + } + } + + if let Some(timestamp) = command_response.at_cluster_time() { + if let Some(ref mut session) = session { + session.snapshot_time = Some(timestamp); + } + } + } } impl Error { @@ -595,3 +669,8 @@ impl Error { Ok(()) } } + +struct CommandResult { + raw: RawCommandResponse, + deserialized: T, +} diff --git a/src/client/options/mod.rs b/src/client/options/mod.rs index 4a592a834..d7e54e1be 100644 --- a/src/client/options/mod.rs +++ b/src/client/options/mod.rs @@ -390,6 +390,8 @@ pub struct ClientOptions { /// The handler that should process all command-related events. See the CommandEventHandler /// type documentation for more details. + /// + /// Note that monitoring command events may incur a performance penalty. #[derivative(Debug = "ignore", PartialEq = "ignore")] #[builder(default)] #[serde(skip)] diff --git a/src/client/session/cluster_time.rs b/src/client/session/cluster_time.rs index 50878f7e6..1a1856df9 100644 --- a/src/client/session/cluster_time.rs +++ b/src/client/session/cluster_time.rs @@ -11,10 +11,10 @@ use crate::bson::{Document, Timestamp}; #[derivative(PartialEq, Eq)] #[serde(rename_all = "camelCase")] pub struct ClusterTime { - cluster_time: Timestamp, + pub(crate) cluster_time: Timestamp, #[derivative(PartialEq = "ignore")] - signature: Document, + pub(crate) signature: Document, } impl std::cmp::Ord for ClusterTime { diff --git a/src/cmap/conn/command.rs b/src/cmap/conn/command.rs index 9901d442c..31b300e3f 100644 --- a/src/cmap/conn/command.rs +++ b/src/cmap/conn/command.rs @@ -1,13 +1,12 @@ -use serde::{de::DeserializeOwned, Deserialize}; +pub(crate) use serde::de::DeserializeOwned; use super::wire::Message; use crate::{ - bson::{Bson, Document, Timestamp}, - bson_util, + bson::Document, client::{options::ServerApi, ClusterTime}, - concern::ReadConcern, - error::{CommandError, Error, ErrorKind, Result}, - options::ServerAddress, + error::{Error, ErrorKind, Result}, + operation::{CommandErrorBody, CommandResponse}, + options::{ReadConcern, ServerAddress}, selection_criteria::ReadPreference, ClientSession, }; @@ -96,27 +95,22 @@ impl Command { } #[derive(Debug, Clone)] -pub(crate) struct CommandResponse { - source: ServerAddress, - pub(crate) raw_response: Document, - cluster_time: Option, - snapshot_time: Option, +pub(crate) struct RawCommandResponse { + pub(crate) source: ServerAddress, + raw: Vec, } -impl CommandResponse { +impl RawCommandResponse { #[cfg(test)] - pub(crate) fn with_document_and_address(source: ServerAddress, doc: Document) -> Self { - Self { - source, - raw_response: doc, - cluster_time: None, - snapshot_time: None, - } + pub(crate) fn with_document_and_address(source: ServerAddress, doc: Document) -> Result { + let mut raw = Vec::new(); + doc.to_writer(&mut raw)?; + Ok(Self { source, raw }) } /// Initialize a response from a document. #[cfg(test)] - pub(crate) fn with_document(doc: Document) -> Self { + pub(crate) fn with_document(doc: Document) -> Result { Self::with_document_and_address( ServerAddress::Tcp { host: "localhost".to_string(), @@ -127,44 +121,52 @@ impl CommandResponse { } pub(crate) fn new(source: ServerAddress, message: Message) -> Result { - let raw_response = message.single_document_response()?; - let cluster_time = raw_response - .get("$clusterTime") - .and_then(|subdoc| bson::from_bson(subdoc.clone()).ok()); - let snapshot_time = raw_response - .get("atClusterTime") - .or_else(|| { - raw_response - .get("cursor") - .and_then(|b| b.as_document()) - .and_then(|subdoc| subdoc.get("atClusterTime")) - }) - .and_then(|subdoc| bson::from_bson(subdoc.clone()).ok()); + let raw = message.single_document_response()?; + Ok(Self { source, raw }) + } - Ok(Self { - source, - raw_response, - cluster_time, - snapshot_time, + pub(crate) fn body(&self) -> Result { + bson::from_slice(self.raw.as_slice()).map_err(|e| { + Error::from(ErrorKind::InvalidResponse { + message: format!("{}", e), + }) }) } - /// Returns whether this response indicates a success or not (i.e. if "ok: 1") - pub(crate) fn is_success(&self) -> bool { - match self.raw_response.get("ok") { - Some(b) => bson_util::get_int(b) == Some(1), - _ => false, - } + /// Deserialize the body of this response, returning an authentication error if it fails. + pub(crate) fn auth_response_body( + &self, + mechanism_name: &str, + ) -> Result { + self.body() + .map_err(|_| Error::invalid_authentication_response(mechanism_name)) } + /// Deserialize the raw bytes into a response backed by a `Document` for further processing. + pub(crate) fn into_document_response(self) -> Result { + let response: CommandResponse = self.body()?; + Ok(DocumentCommandResponse { response }) + } + + /// The address of the server that sent this response. + pub(crate) fn source_address(&self) -> &ServerAddress { + &self.source + } +} + +/// A command response backed by a `Document` rather than raw bytes. +/// Use this for simple command responses where deserialization performance is not a high priority. +pub(crate) struct DocumentCommandResponse { + response: CommandResponse, +} + +impl DocumentCommandResponse { /// Returns a result indicating whether this response corresponds to a command failure. pub(crate) fn validate(&self) -> Result<()> { - if !self.is_success() { - let error_response: CommandErrorResponse = - bson::from_bson(Bson::Document(self.raw_response.clone())).map_err(|_| { - ErrorKind::InvalidResponse { - message: "invalid server response".to_string(), - } + if !self.response.is_success() { + let error_response: CommandErrorBody = bson::from_document(self.response.body.clone()) + .map_err(|_| ErrorKind::InvalidResponse { + message: "invalid server response".to_string(), })?; Err(Error::new( ErrorKind::Command(error_response.command_error), @@ -177,7 +179,7 @@ impl CommandResponse { /// Deserialize the body of the response. pub(crate) fn body(self) -> Result { - match bson::from_document(self.raw_response) { + match bson::from_document(self.response.body) { Ok(body) => Ok(body), Err(e) => Err(ErrorKind::InvalidResponse { message: format!("{}", e), @@ -186,27 +188,7 @@ impl CommandResponse { } } - /// Gets the cluster time from the response, if any. pub(crate) fn cluster_time(&self) -> Option<&ClusterTime> { - self.cluster_time.as_ref() + self.response.cluster_time.as_ref() } - - /// Gets the snapshot time from the response, if any. - pub(crate) fn snapshot_time(&self) -> Option<&Timestamp> { - self.snapshot_time.as_ref() - } - - /// The address of the server that sent this response. - pub(crate) fn source_address(&self) -> &ServerAddress { - &self.source - } -} - -#[derive(Deserialize, Debug)] -struct CommandErrorResponse { - #[serde(rename = "errorLabels")] - error_labels: Option>, - - #[serde(flatten)] - command_error: CommandError, } diff --git a/src/cmap/conn/mod.rs b/src/cmap/conn/mod.rs index aa48645e1..f3bfd6aa4 100644 --- a/src/cmap/conn/mod.rs +++ b/src/cmap/conn/mod.rs @@ -26,7 +26,7 @@ use crate::{ options::{ServerAddress, TlsOptions}, runtime::AsyncStream, }; -pub(crate) use command::{Command, CommandResponse}; +pub(crate) use command::{Command, RawCommandResponse}; pub(crate) use stream_description::StreamDescription; pub(crate) use wire::next_request_id; @@ -238,8 +238,8 @@ impl Connection { &mut self, command: Command, request_id: impl Into>, - ) -> Result { - let message = Message::with_command(command, request_id.into()); + ) -> Result { + let message = Message::with_command(command, request_id.into())?; self.command_executing = true; let write_result = message.write_to(&mut self.stream).await; @@ -250,7 +250,7 @@ impl Connection { self.command_executing = false; self.error = response_message_result.is_err(); - CommandResponse::new(self.address.clone(), response_message_result?) + RawCommandResponse::new(self.address.clone(), response_message_result?) } /// Gets the connection's StreamDescription. diff --git a/src/cmap/conn/stream_description.rs b/src/cmap/conn/stream_description.rs index 7b006892e..14b7e06a1 100644 --- a/src/cmap/conn/stream_description.rs +++ b/src/cmap/conn/stream_description.rs @@ -1,10 +1,13 @@ use std::time::Duration; -use crate::{is_master::IsMasterReply, sdam::ServerType}; +use crate::{client::options::ServerAddress, is_master::IsMasterReply, sdam::ServerType}; /// Contains information about a given server in a format digestible by a connection. #[derive(Debug, Default, Clone)] pub(crate) struct StreamDescription { + /// The address of the server. + pub(crate) server_address: ServerAddress, + /// The type of the server when the handshake occurred. pub(crate) initial_server_type: ServerType, @@ -34,6 +37,7 @@ impl StreamDescription { /// Constructs a new StreamDescription from an IsMasterReply. pub(crate) fn from_is_master(reply: IsMasterReply) -> Self { Self { + server_address: reply.server_address, initial_server_type: reply.command_response.server_type(), max_wire_version: reply.command_response.max_wire_version, min_wire_version: reply.command_response.min_wire_version, @@ -58,10 +62,17 @@ impl StreamDescription { /// Gets a description of a stream for a 4.2 connection. #[cfg(test)] pub(crate) fn new_testing() -> Self { + Self::with_wire_version(8) + } + + /// Gets a description of a stream for a 4.2 connection. + #[cfg(test)] + pub(crate) fn with_wire_version(max_wire_version: i32) -> Self { Self { + server_address: Default::default(), initial_server_type: Default::default(), - max_wire_version: Some(8), - min_wire_version: Some(8), + max_wire_version: Some(max_wire_version), + min_wire_version: Some(max_wire_version), sasl_supported_mechs: Default::default(), logical_session_timeout: Some(Duration::from_secs(30 * 60)), max_bson_object_size: 16 * 1024 * 1024, diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index 320b125b5..830c56dbe 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -1,21 +1,19 @@ +use std::io::Read; + use bitflags::bitflags; -use futures_io::{AsyncRead, AsyncWrite}; +use futures_io::AsyncWrite; use futures_util::{ io::{BufReader, BufWriter}, AsyncReadExt, AsyncWriteExt, }; -use super::{ - header::{Header, OpCode}, - util::CountReader, -}; +use super::header::{Header, OpCode}; use crate::{ - bson::Document, - bson_util::async_encoding, - cmap::conn::command::Command, - error::{ErrorKind, Result}, - runtime::{AsyncLittleEndianRead, AsyncLittleEndianWrite, AsyncStream}, + bson_util, + cmap::conn::{command::Command, wire::util::SyncCountReader}, + error::{Error, ErrorKind, Result}, + runtime::{AsyncLittleEndianWrite, AsyncStream, SyncLittleEndianRead}, }; /// Represents an OP_MSG wire protocol operation. @@ -32,64 +30,60 @@ impl Message { /// Creates a `Message` from a given `Command`. /// /// Note that `response_to` will need to be set manually. - pub(crate) fn with_command(mut command: Command, request_id: Option) -> Self { + pub(crate) fn with_command(mut command: Command, request_id: Option) -> Result { command.body.insert("$db", command.target_db); - Self { + let mut bytes = Vec::new(); + command.body.to_writer(&mut bytes)?; + Ok(Self { response_to: 0, flags: MessageFlags::empty(), - sections: vec![MessageSection::Document(command.body)], + sections: vec![MessageSection::Document(bytes)], checksum: None, request_id, - } + }) } /// Gets the first document contained in this Message. - pub(crate) fn single_document_response(self) -> Result { - self.sections - .into_iter() - .next() - .and_then(|section| match section { - MessageSection::Document(doc) => Some(doc), - MessageSection::Sequence { documents, .. } => documents.into_iter().next(), - }) - .ok_or_else(|| { + pub(crate) fn single_document_response(self) -> Result> { + let section = self.sections.into_iter().next().ok_or_else(|| { + Error::new( ErrorKind::InvalidResponse { message: "no response received from server".into(), - } - .into() - }) - } - - /// Gets all documents contained in this Message flattened to a single Vec. - #[allow(dead_code)] - pub(crate) fn documents(self) -> Vec { - self.sections - .into_iter() - .flat_map(|section| match section { - MessageSection::Document(doc) => vec![doc], - MessageSection::Sequence { documents, .. } => documents, + }, + Option::>::None, + ) + })?; + match section { + MessageSection::Document(doc) => Some(doc), + MessageSection::Sequence { documents, .. } => documents.into_iter().next(), + } + .ok_or_else(|| { + Error::from(ErrorKind::InvalidResponse { + message: "no message received from the server".to_string(), }) - .collect() + }) } /// Reads bytes from `reader` and deserializes them into a Message. pub(crate) async fn read_from(reader: &mut AsyncStream) -> Result { let mut reader = BufReader::new(reader); let header = Header::read_from(&mut reader).await?; + + // TODO: RUST-616 ensure length is < maxMessageSizeBytes let mut length_remaining = header.length - Header::LENGTH as i32; let mut buf = vec![0u8; length_remaining as usize]; reader.read_exact(&mut buf).await?; let mut reader = buf.as_slice(); - let flags = MessageFlags::from_bits_truncate(reader.read_u32().await?); + let flags = MessageFlags::from_bits_truncate(reader.read_u32()?); length_remaining -= std::mem::size_of::() as i32; - let mut count_reader = CountReader::new(&mut reader); + let mut count_reader = SyncCountReader::new(&mut reader); let mut sections = Vec::new(); while length_remaining - count_reader.bytes_read() as i32 > 4 { - sections.push(MessageSection::read(&mut count_reader).await?); + sections.push(MessageSection::read(&mut count_reader)?); } length_remaining -= count_reader.bytes_read() as i32; @@ -97,7 +91,7 @@ impl Message { let mut checksum = None; if length_remaining == 4 && flags.contains(MessageFlags::CHECKSUM_PRESENT) { - checksum = Some(reader.read_u32().await?); + checksum = Some(reader.read_u32()?); } else if length_remaining != 0 { return Err(ErrorKind::InvalidResponse { message: format!( @@ -170,36 +164,36 @@ bitflags! { /// Represents a section as defined by the OP_MSG spec. #[derive(Debug)] pub(crate) enum MessageSection { - Document(Document), + Document(Vec), Sequence { size: i32, identifier: String, - documents: Vec, + documents: Vec>, }, } impl MessageSection { /// Reads bytes from `reader` and deserializes them into a MessageSection. - async fn read(reader: &mut R) -> Result { - let payload_type = reader.read_u8().await?; + fn read(reader: &mut R) -> Result { + let payload_type = reader.read_u8()?; if payload_type == 0 { - return Ok(MessageSection::Document( - async_encoding::decode_document(reader).await?, - )); + return Ok(MessageSection::Document(bson_util::read_document_bytes( + reader, + )?)); } - let size = reader.read_i32().await?; + let size = reader.read_i32()?; let mut length_remaining = size - std::mem::size_of::() as i32; let mut identifier = String::new(); - length_remaining -= reader.read_to_string(&mut identifier).await? as i32; + length_remaining -= reader.read_to_string(&mut identifier)? as i32; let mut documents = Vec::new(); - let mut count_reader = CountReader::new(reader); + let mut count_reader = SyncCountReader::new(reader); while length_remaining > count_reader.bytes_read() as i32 { - documents.push(async_encoding::decode_document(&mut count_reader).await?); + documents.push(bson_util::read_document_bytes(&mut count_reader)?); } if length_remaining != count_reader.bytes_read() as i32 { @@ -227,7 +221,7 @@ impl MessageSection { Self::Document(doc) => { // Write payload type. writer.write_u8(0).await?; - async_encoding::encode_document(writer, doc).await?; + writer.write_all(doc.as_slice()).await?; } Self::Sequence { size, @@ -241,7 +235,7 @@ impl MessageSection { super::util::write_cstring(writer, identifier).await?; for doc in documents { - async_encoding::encode_document(writer, doc).await?; + writer.write_all(doc.as_slice()).await?; } } } diff --git a/src/cmap/conn/wire/test.rs b/src/cmap/conn/wire/test.rs index ab51aa85a..904d2e0ab 100644 --- a/src/cmap/conn/wire/test.rs +++ b/src/cmap/conn/wire/test.rs @@ -1,8 +1,10 @@ +use bson::Document; use tokio::sync::RwLockReadGuard; use super::message::{Message, MessageFlags, MessageSection}; use crate::{ bson::{doc, Bson}, + bson_util, cmap::options::StreamOptions, runtime::AsyncStream, test::{CLIENT_OPTIONS, LOCK}, @@ -21,7 +23,8 @@ async fn basic() { response_to: 0, flags: MessageFlags::empty(), sections: vec![MessageSection::Document( - doc! { "isMaster": 1, "$db": "admin", "apiVersion": "1" }, + bson_util::document_to_vec(doc! { "isMaster": 1, "$db": "admin", "apiVersion": "1" }) + .unwrap(), )], checksum: None, request_id: None, @@ -38,10 +41,11 @@ async fn basic() { let reply = Message::read_from(&mut stream).await.unwrap(); - let response_doc = match reply.sections.into_iter().next().unwrap() { + let response_doc_bytes = match reply.sections.into_iter().next().unwrap() { MessageSection::Document(doc) => doc, MessageSection::Sequence { documents, .. } => documents.into_iter().next().unwrap(), }; + let response_doc: Document = bson::from_slice(&response_doc_bytes.as_slice()).unwrap(); assert_eq!(response_doc.get("ok"), Some(&Bson::Double(1.0))); } diff --git a/src/cmap/conn/wire/util.rs b/src/cmap/conn/wire/util.rs index 38b2e40e0..f549ad104 100644 --- a/src/cmap/conn/wire/util.rs +++ b/src/cmap/conn/wire/util.rs @@ -1,10 +1,9 @@ use std::{ - pin::Pin, + io::Read, sync::atomic::{AtomicI32, Ordering}, - task::{Context, Poll}, }; -use futures_io::{self, AsyncRead, AsyncWrite}; +use futures_io::{self, AsyncWrite}; use futures_util::AsyncWriteExt; use lazy_static::lazy_static; @@ -33,16 +32,15 @@ pub(super) async fn write_cstring( Ok(()) } -/// A wrapper around `futures_io::AsyncRead` that keeps track of the number of bytes it has read. -pub(super) struct CountReader<'a, R: AsyncRead + Unpin + Send + 'a> { - reader: &'a mut R, +pub(super) struct SyncCountReader { + reader: R, bytes_read: usize, } -impl<'a, R: AsyncRead + Unpin + Send + 'a> CountReader<'a, R> { +impl SyncCountReader { /// Constructs a new CountReader that wraps `reader`. - pub(super) fn new(reader: &'a mut R) -> Self { - CountReader { + pub(super) fn new(reader: R) -> Self { + SyncCountReader { reader, bytes_read: 0, } @@ -54,18 +52,10 @@ impl<'a, R: AsyncRead + Unpin + Send + 'a> CountReader<'a, R> { } } -impl<'a, R: AsyncRead + Unpin + Send + 'a> AsyncRead for CountReader<'a, R> { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { - let result = Pin::new(&mut self.reader).poll_read(cx, buf); - - if let Poll::Ready(Ok(count)) = result { - self.bytes_read += count; - } - - result +impl Read for SyncCountReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let bytes = self.reader.read(buf)?; + self.bytes_read += bytes; + Ok(bytes) } } diff --git a/src/cmap/establish/test.rs b/src/cmap/establish/test.rs index 6973e305b..cae6ce545 100644 --- a/src/cmap/establish/test.rs +++ b/src/cmap/establish/test.rs @@ -68,8 +68,11 @@ async fn speculative_auth_test( }); let response = conn.send_command(command, None).await.unwrap(); + let doc_response = response.into_document_response().unwrap(); - assert!(response.is_success()); + doc_response + .validate() + .expect("response should be successful"); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] diff --git a/src/cmap/mod.rs b/src/cmap/mod.rs index cdfb30b9c..6613ce124 100644 --- a/src/cmap/mod.rs +++ b/src/cmap/mod.rs @@ -15,7 +15,7 @@ use derivative::Derivative; pub use self::conn::ConnectionInfo; pub(crate) use self::{ - conn::{Command, CommandResponse, Connection, StreamDescription}, + conn::{Command, Connection, RawCommandResponse, StreamDescription}, establish::handshake::Handshaker, status::PoolGenerationSubscriber, }; diff --git a/src/cmap/test/integration.rs b/src/cmap/test/integration.rs index 6d84b1421..8054e901b 100644 --- a/src/cmap/test/integration.rs +++ b/src/cmap/test/integration.rs @@ -55,10 +55,13 @@ async fn acquire_connection_and_send_command() { } let response = connection.send_command(cmd, None).await.unwrap(); + let doc_response = response.into_document_response().unwrap(); - assert!(response.is_success()); + doc_response + .validate() + .expect("response should be successful"); - let response: ListDatabasesResponse = response.body().unwrap(); + let response: ListDatabasesResponse = doc_response.body().unwrap(); let names: Vec<_> = response .databases diff --git a/src/coll/mod.rs b/src/coll/mod.rs index dfdf2bc8d..8806c4207 100644 --- a/src/coll/mod.rs +++ b/src/coll/mod.rs @@ -561,7 +561,7 @@ impl Collection { impl Collection where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { /// Finds the documents in the collection matching `filter`. pub async fn find( @@ -572,7 +572,7 @@ where let mut options = options.into(); resolve_options!(self, options, [read_concern, selection_criteria]); - let find = Find::new(self.namespace(), filter.into(), options); + let find = Find::::new(self.namespace(), filter.into(), options); let client = self.client(); client @@ -592,7 +592,7 @@ where resolve_read_concern_with_session!(self, options, Some(&mut *session))?; resolve_selection_criteria_with_session!(self, options, Some(&mut *session))?; - let find = Find::new(self.namespace(), filter.into(), options); + let find = Find::::new(self.namespace(), filter.into(), options); let client = self.client(); client diff --git a/src/cursor/common.rs b/src/cursor/common.rs index d520d2f6b..63137b461 100644 --- a/src/cursor/common.rs +++ b/src/cursor/common.rs @@ -7,9 +7,9 @@ use std::{ use derivative::Derivative; use futures_core::{Future, Stream}; +use serde::de::DeserializeOwned; use crate::{ - bson::Document, error::{Error, ErrorKind, Result}, operation, options::ServerAddress, @@ -21,17 +21,24 @@ use crate::{ /// An internal cursor that can be used in a variety of contexts depending on its `GetMoreProvider`. #[derive(Derivative)] #[derivative(Debug)] -pub(super) struct GenericCursor { +pub(super) struct GenericCursor +where + P: GetMoreProvider, +{ #[derivative(Debug = "ignore")] - provider: T, + provider: P, client: Client, info: CursorInformation, - buffer: VecDeque, + buffer: VecDeque, exhausted: bool, } -impl GenericCursor { - pub(super) fn new(client: Client, spec: CursorSpecification, get_more_provider: T) -> Self { +impl GenericCursor +where + P: GetMoreProvider, + T: DeserializeOwned, +{ + pub(super) fn new(client: Client, spec: CursorSpecification, get_more_provider: P) -> Self { let exhausted = spec.id() == 0; Self { exhausted, @@ -42,7 +49,7 @@ impl GenericCursor { } } - pub(super) fn take_buffer(&mut self) -> VecDeque { + pub(super) fn take_buffer(&mut self) -> VecDeque { std::mem::take(&mut self.buffer) } @@ -65,8 +72,12 @@ impl GenericCursor { } } -impl Stream for GenericCursor { - type Item = Result; +impl Stream for GenericCursor +where + P: GetMoreProvider, + T: DeserializeOwned + Unpin, +{ + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -105,11 +116,14 @@ impl Stream for GenericCursor { /// A trait implemented by objects that can provide batches of documents to a cursor via the getMore /// command. pub(super) trait GetMoreProvider: Unpin { + /// The type that the invididual documents will be deserialized to. + type DocumentType; + /// The result type that the future running the getMore evaluates to. - type GetMoreResult: GetMoreProviderResult; + type ResultType: GetMoreProviderResult; /// The type of future created by this provider when running a getMore. - type GetMoreFuture: Future + Unpin; + type GetMoreFuture: Future + Unpin; /// Get the future being evaluated, if there is one. fn executing_future(&mut self) -> Option<&mut Self::GetMoreFuture>; @@ -117,7 +131,7 @@ pub(super) trait GetMoreProvider: Unpin { /// Clear out any state remaining from previous getMore executions. fn clear_execution( &mut self, - session: ::Session, + session: ::Session, exhausted: bool, ); @@ -128,10 +142,11 @@ pub(super) trait GetMoreProvider: Unpin { /// Trait describing results returned from a `GetMoreProvider`. pub(super) trait GetMoreProviderResult { type Session; + type DocumentType; - fn as_ref(&self) -> std::result::Result<&GetMoreResult, &Error>; + fn as_ref(&self) -> std::result::Result<&GetMoreResult, &Error>; - fn into_parts(self) -> (Result, Self::Session); + fn into_parts(self) -> (Result>, Self::Session); /// Whether the response from the server indicated the cursor was exhausted or not. fn exhausted(&self) -> bool { @@ -146,14 +161,14 @@ pub(super) trait GetMoreProviderResult { /// Specification used to create a new cursor. #[derive(Debug, Clone)] -pub(crate) struct CursorSpecification { +pub(crate) struct CursorSpecification { pub(crate) info: CursorInformation, - pub(crate) initial_buffer: VecDeque, + pub(crate) initial_buffer: VecDeque, } -impl CursorSpecification { +impl CursorSpecification { pub(crate) fn new( - info: operation::CursorInfo, + info: operation::CursorInfo, address: ServerAddress, batch_size: impl Into>, max_time: impl Into>, diff --git a/src/cursor/mod.rs b/src/cursor/mod.rs index 2f2479a6b..b385ed0fa 100644 --- a/src/cursor/mod.rs +++ b/src/cursor/mod.rs @@ -10,7 +10,7 @@ use futures_core::{future::BoxFuture, Stream}; use serde::de::DeserializeOwned; use crate::{ - bson::{from_document, Document}, + bson::Document, error::{Error, Result}, operation::GetMore, results::GetMoreResult, @@ -81,20 +81,20 @@ use common::{GenericCursor, GetMoreProvider, GetMoreProviderResult}; #[derive(Debug)] pub struct Cursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { client: Client, - wrapped_cursor: ImplicitSessionCursor, + wrapped_cursor: ImplicitSessionCursor, _phantom: std::marker::PhantomData, } impl Cursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { pub(crate) fn new( client: Client, - spec: CursorSpecification, + spec: CursorSpecification, session: Option, ) -> Self { let provider = ImplicitSessionGetMoreProvider::new(&spec, session); @@ -109,24 +109,18 @@ where impl Stream for Cursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = Pin::new(&mut self.wrapped_cursor).poll_next(cx); - match next { - Poll::Ready(opt) => Poll::Ready( - opt.map(|result| result.and_then(|doc| from_document(doc).map_err(Into::into))), - ), - Poll::Pending => Poll::Pending, - } + Pin::new(&mut self.wrapped_cursor).poll_next(cx) } } impl Drop for Cursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { fn drop(&mut self) { if self.wrapped_cursor.is_exhausted() { @@ -145,35 +139,36 @@ where /// A `GenericCursor` that optionally owns its own sessions. /// This is to be used by cursors associated with implicit sessions. -type ImplicitSessionCursor = GenericCursor; +type ImplicitSessionCursor = GenericCursor, T>; -struct ImplicitSessionGetMoreResult { - get_more_result: Result, +struct ImplicitSessionGetMoreResult { + get_more_result: Result>, session: Option>, } -impl GetMoreProviderResult for ImplicitSessionGetMoreResult { +impl GetMoreProviderResult for ImplicitSessionGetMoreResult { type Session = Option>; + type DocumentType = T; - fn as_ref(&self) -> std::result::Result<&GetMoreResult, &Error> { + fn as_ref(&self) -> std::result::Result<&GetMoreResult, &Error> { self.get_more_result.as_ref() } - fn into_parts(self) -> (Result, Self::Session) { + fn into_parts(self) -> (Result>, Self::Session) { (self.get_more_result, self.session) } } /// A `GetMoreProvider` that optionally owns its own session. /// This is to be used with cursors associated with implicit sessions. -enum ImplicitSessionGetMoreProvider { - Executing(BoxFuture<'static, ImplicitSessionGetMoreResult>), +enum ImplicitSessionGetMoreProvider { + Executing(BoxFuture<'static, ImplicitSessionGetMoreResult>), Idle(Option>), Done, } -impl ImplicitSessionGetMoreProvider { - fn new(spec: &CursorSpecification, session: Option) -> Self { +impl ImplicitSessionGetMoreProvider { + fn new(spec: &CursorSpecification, session: Option) -> Self { if spec.id() == 0 { Self::Done } else { @@ -182,9 +177,10 @@ impl ImplicitSessionGetMoreProvider { } } -impl GetMoreProvider for ImplicitSessionGetMoreProvider { - type GetMoreResult = ImplicitSessionGetMoreResult; - type GetMoreFuture = BoxFuture<'static, ImplicitSessionGetMoreResult>; +impl GetMoreProvider for ImplicitSessionGetMoreProvider { + type DocumentType = T; + type ResultType = ImplicitSessionGetMoreResult; + type GetMoreFuture = BoxFuture<'static, ImplicitSessionGetMoreResult>; fn executing_future(&mut self) -> Option<&mut Self::GetMoreFuture> { match self { diff --git a/src/cursor/session.rs b/src/cursor/session.rs index 80c183ac3..e2fa0dcd1 100644 --- a/src/cursor/session.rs +++ b/src/cursor/session.rs @@ -10,7 +10,7 @@ use serde::de::DeserializeOwned; use super::common::{CursorInformation, GenericCursor, GetMoreProvider, GetMoreProviderResult}; use crate::{ - bson::{from_document, Document}, + bson::Document, cursor::CursorSpecification, error::{Error, Result}, operation::GetMore, @@ -55,15 +55,14 @@ where exhausted: bool, client: Client, info: CursorInformation, - buffer: VecDeque, - _phantom: std::marker::PhantomData, + buffer: VecDeque, } impl SessionCursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { - pub(crate) fn new(client: Client, spec: CursorSpecification) -> Self { + pub(crate) fn new(client: Client, spec: CursorSpecification) -> Self { let exhausted = spec.id() == 0; Self { @@ -71,7 +70,6 @@ where client, info: spec.info, buffer: spec.initial_buffer, - _phantom: Default::default(), } } @@ -188,7 +186,8 @@ where /// A `GenericCursor` that borrows its session. /// This is to be used with cursors associated with explicit sessions borrowed from the user. -type ExplicitSessionCursor<'session> = GenericCursor>; +type ExplicitSessionCursor<'session, T> = + GenericCursor, T>; /// A type that implements [`Stream`](https://docs.rs/futures/latest/futures/stream/index.html) which can be used to /// stream the results of a [`SessionCursor`]. Returned from [`SessionCursor::stream`]. @@ -197,32 +196,26 @@ type ExplicitSessionCursor<'session> = GenericCursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { session_cursor: &'cursor mut SessionCursor, - generic_cursor: ExplicitSessionCursor<'session>, + generic_cursor: ExplicitSessionCursor<'session, T>, } impl<'cursor, 'session, T> Stream for SessionCursorStream<'cursor, 'session, T> where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = Pin::new(&mut self.generic_cursor).poll_next(cx); - match next { - Poll::Ready(opt) => Poll::Ready( - opt.map(|result| result.and_then(|doc| from_document(doc).map_err(Into::into))), - ), - Poll::Pending => Poll::Pending, - } + Pin::new(&mut self.generic_cursor).poll_next(cx) } } impl<'cursor, 'session, T> Drop for SessionCursorStream<'cursor, 'session, T> where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { fn drop(&mut self) { // Update the parent cursor's state based on any iteration performed on this handle. @@ -233,11 +226,11 @@ where /// Enum determining whether a `SessionCursorHandle` is excuting a getMore or not. /// In charge of maintaining ownership of the session reference. -enum ExplicitSessionGetMoreProvider<'session> { +enum ExplicitSessionGetMoreProvider<'session, T> { /// The handle is currently executing a getMore via the future. /// /// This future owns the reference to the session and will return it on completion. - Executing(BoxFuture<'session, ExecutionResult<'session>>), + Executing(BoxFuture<'session, ExecutionResult<'session, T>>), /// No future is being executed. /// @@ -246,15 +239,18 @@ enum ExplicitSessionGetMoreProvider<'session> { Idle(MutableSessionReference<'session>), } -impl<'session> ExplicitSessionGetMoreProvider<'session> { +impl<'session, T> ExplicitSessionGetMoreProvider<'session, T> { fn new(session: &'session mut ClientSession) -> Self { Self::Idle(MutableSessionReference { reference: session }) } } -impl<'session> GetMoreProvider for ExplicitSessionGetMoreProvider<'session> { - type GetMoreResult = ExecutionResult<'session>; - type GetMoreFuture = BoxFuture<'session, ExecutionResult<'session>>; +impl<'session, T: Send + Sync + DeserializeOwned> GetMoreProvider + for ExplicitSessionGetMoreProvider<'session, T> +{ + type DocumentType = T; + type ResultType = ExecutionResult<'session, T>; + type GetMoreFuture = BoxFuture<'session, ExecutionResult<'session, T>>; fn executing_future(&mut self) -> Option<&mut Self::GetMoreFuture> { match self { @@ -289,19 +285,20 @@ impl<'session> GetMoreProvider for ExplicitSessionGetMoreProvider<'session> { /// Struct returned from awaiting on a `GetMoreFuture` containing the result of the getMore as /// well as the reference to the `ClientSession` used for the getMore. -struct ExecutionResult<'session> { - get_more_result: Result, +struct ExecutionResult<'session, T> { + get_more_result: Result>, session: &'session mut ClientSession, } -impl<'session> GetMoreProviderResult for ExecutionResult<'session> { +impl<'session, T> GetMoreProviderResult for ExecutionResult<'session, T> { type Session = &'session mut ClientSession; + type DocumentType = T; - fn as_ref(&self) -> std::result::Result<&GetMoreResult, &Error> { + fn as_ref(&self) -> std::result::Result<&GetMoreResult, &Error> { self.get_more_result.as_ref() } - fn into_parts(self) -> (Result, Self::Session) { + fn into_parts(self) -> (Result>, Self::Session) { (self.get_more_result, self.session) } } diff --git a/src/is_master.rs b/src/is_master.rs index 3f94c1060..1a443d394 100644 --- a/src/is_master.rs +++ b/src/is_master.rs @@ -4,7 +4,10 @@ use serde::Deserialize; use crate::{ bson::{doc, oid::ObjectId, DateTime, Document, Timestamp}, - client::{options::ServerApi, ClusterTime}, + client::{ + options::{ServerAddress, ServerApi}, + ClusterTime, + }, cmap::{Command, Connection}, error::{ErrorKind, Result}, sdam::ServerType, @@ -43,11 +46,14 @@ pub(crate) async fn run_is_master( let response = conn.send_command(command, None).await?; let end_time = Instant::now(); - response.validate()?; - let cluster_time = response.cluster_time().cloned(); - let command_response: IsMasterCommandResponse = response.body()?; + let server_address = response.source_address().clone(); + let basic_response = response.into_document_response()?; + basic_response.validate()?; + let cluster_time = basic_response.cluster_time().cloned(); + let command_response: IsMasterCommandResponse = basic_response.body()?; Ok(IsMasterReply { + server_address, command_response, round_trip_time: Some(end_time.duration_since(start_time)), cluster_time, @@ -56,6 +62,7 @@ pub(crate) async fn run_is_master( #[derive(Debug, Clone)] pub(crate) struct IsMasterReply { + pub server_address: ServerAddress, pub command_response: IsMasterCommandResponse, pub round_trip_time: Option, pub cluster_time: Option, @@ -67,7 +74,6 @@ pub(crate) struct IsMasterCommandResponse { pub is_writable_primary: Option, #[serde(rename = "ismaster")] pub is_master: Option, - pub ok: Option, pub hosts: Option>, pub passives: Option>, pub arbiters: Option>, @@ -115,9 +121,7 @@ impl PartialEq for IsMasterCommandResponse { impl IsMasterCommandResponse { pub(crate) fn server_type(&self) -> ServerType { - if self.ok != Some(1.0) { - ServerType::Unknown - } else if self.msg.as_deref() == Some("isdbgrid") { + if self.msg.as_deref() == Some("isdbgrid") { ServerType::Mongos } else if self.set_name.is_some() { if let Some(true) = self.hidden { diff --git a/src/operation/abort_transaction/mod.rs b/src/operation/abort_transaction/mod.rs index c48dbb42d..5c6206edc 100644 --- a/src/operation/abort_transaction/mod.rs +++ b/src/operation/abort_transaction/mod.rs @@ -1,12 +1,12 @@ use crate::{ bson::doc, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, error::Result, operation::{Operation, Retryability}, options::WriteConcern, }; -use super::WriteConcernOnlyBody; +use super::{CommandResponse, Response, WriteConcernOnlyBody}; pub(crate) struct AbortTransaction { write_concern: Option, @@ -20,6 +20,8 @@ impl AbortTransaction { impl Operation for AbortTransaction { type O = (); + type Response = CommandResponse; + const NAME: &'static str = "abortTransaction"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -39,10 +41,10 @@ impl Operation for AbortTransaction { fn handle_response( &self, - response: CommandResponse, + response: ::Body, _description: &StreamDescription, ) -> Result { - response.body::()?.validate() + response.validate() } fn write_concern(&self) -> Option<&WriteConcern> { diff --git a/src/operation/aggregate/mod.rs b/src/operation/aggregate/mod.rs index 1ac9265a5..66af7025b 100644 --- a/src/operation/aggregate/mod.rs +++ b/src/operation/aggregate/mod.rs @@ -4,14 +4,16 @@ mod test; use crate::{ bson::{doc, Bson, Document}, bson_util, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, cursor::CursorSpecification, error::Result, - operation::{append_options, CursorBody, Operation, Retryability, WriteConcernOnlyBody}, + operation::{append_options, Operation, Retryability}, options::{AggregateOptions, SelectionCriteria, WriteConcern}, Namespace, }; +use super::{CursorBody, CursorResponse}; + #[derive(Debug)] pub(crate) struct Aggregate { target: AggregateTarget, @@ -39,7 +41,8 @@ impl Aggregate { } impl Operation for Aggregate { - type O = CursorSpecification; + type O = CursorSpecification; + type Response = CursorResponse; const NAME: &'static str = "aggregate"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -65,21 +68,16 @@ impl Operation for Aggregate { fn handle_response( &self, - response: CommandResponse, - _description: &StreamDescription, + response: CursorBody, + description: &StreamDescription, ) -> Result { - let source_address = response.source_address().clone(); - if self.is_out_or_merge() { - let error_body: WriteConcernOnlyBody = response.clone().body()?; - error_body.validate()?; + response.write_concern_info.validate()?; }; - let body: CursorBody = response.body()?; - Ok(CursorSpecification::new( - body.cursor, - source_address, + response.cursor, + description.server_address.clone(), self.options.as_ref().and_then(|opts| opts.batch_size), self.options.as_ref().and_then(|opts| opts.max_await_time), )) diff --git a/src/operation/aggregate/test.rs b/src/operation/aggregate/test.rs index b2bec8b2e..336e2b9f8 100644 --- a/src/operation/aggregate/test.rs +++ b/src/operation/aggregate/test.rs @@ -1,13 +1,17 @@ -use std::time::Duration; +use std::{collections::VecDeque, time::Duration}; use super::AggregateTarget; use crate::{ bson::{doc, Document}, bson_util, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, concern::{ReadConcern, ReadConcernLevel}, error::{ErrorKind, WriteFailure}, - operation::{test, Aggregate, Operation}, + operation::{ + test::{self, handle_response_test}, + Aggregate, + Operation, + }, options::{AggregateOptions, Hint, ServerAddress}, Namespace, }; @@ -189,34 +193,21 @@ async fn handle_success() { let aggregate = Aggregate::new(ns.clone(), Vec::new(), None); - let first_batch = vec![doc! {"_id": 1}, doc! {"_id": 2}]; - + let first_batch = VecDeque::from(vec![doc! {"_id": 1}, doc! {"_id": 2}]); let response = doc! { + "ok": 1.0, "cursor": { "id": 123, - "ns": format!("{}.{}", ns.db, ns.coll), - "firstBatch": bson_util::to_bson_array(&first_batch), - }, - "ok": 1.0 + "ns": "test_db.test_coll", + "firstBatch": Vec::from(first_batch.clone()), + } }; - let result = aggregate.handle_response( - CommandResponse::with_document_and_address(address.clone(), response.clone()), - &Default::default(), - ); - assert!(result.is_ok()); - - let cursor_spec = result.unwrap(); + let cursor_spec = handle_response_test(&aggregate, response.clone()).unwrap(); assert_eq!(cursor_spec.address(), &address); assert_eq!(cursor_spec.id(), 123); assert_eq!(cursor_spec.batch_size(), None); - assert_eq!( - cursor_spec - .initial_buffer - .into_iter() - .collect::>(), - first_batch - ); + assert_eq!(cursor_spec.initial_buffer, first_batch); let aggregate = Aggregate::new( ns, @@ -228,46 +219,29 @@ async fn handle_success() { .build(), ), ); - let result = aggregate.handle_response( - CommandResponse::with_document_and_address(address.clone(), response), - &Default::default(), - ); - assert!(result.is_ok()); - let cursor_spec = result.unwrap(); + let cursor_spec = handle_response_test(&aggregate, response).unwrap(); assert_eq!(cursor_spec.address(), &address); assert_eq!(cursor_spec.id(), 123); assert_eq!(cursor_spec.batch_size(), Some(123)); assert_eq!(cursor_spec.max_time(), Some(Duration::from_millis(5))); - assert_eq!( - cursor_spec - .initial_buffer - .into_iter() - .collect::>(), - first_batch - ); + assert_eq!(cursor_spec.initial_buffer, first_batch); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_max_await_time() { - let response = CommandResponse::with_document_and_address( - ServerAddress::default(), - doc! { - "cursor": { - "id": 123, - "ns": "a.b", - "firstBatch": [] - }, - "ok": 1.0 - }, - ); + let response = doc! { + "ok": 1, + "cursor": { + "id": 123, + "ns": "a.b", + "firstBatch": [] + } + }; let aggregate = Aggregate::empty(); - - let spec = aggregate - .handle_response(response.clone(), &Default::default()) - .expect("handle should succeed"); + let spec = handle_response_test(&aggregate, response.clone()).unwrap(); assert!(spec.max_time().is_none()); let max_await = Duration::from_millis(123); @@ -275,31 +249,29 @@ async fn handle_max_await_time() { .max_await_time(max_await) .build(); let aggregate = Aggregate::new(Namespace::empty(), Vec::new(), Some(options)); - let spec = aggregate - .handle_response(response, &Default::default()) - .expect("handle_should_succeed"); + let spec = handle_response_test(&aggregate, response).unwrap(); assert_eq!(spec.max_time(), Some(max_await)); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_write_concern_error() { - let response = CommandResponse::with_document(doc! { - "cursor" : { - "firstBatch" : [ ], - "id" : 0_i64, - "ns" : "test.test" + let response = doc! { + "ok": 1.0, + "cursor": { + "id": 0, + "ns": "test.test", + "firstBatch": [], }, - "writeConcernError" : { - "code" : 64, - "codeName" : "WriteConcernFailed", - "errmsg" : "waiting for replication timed out", - "errInfo" : { - "wtimeout" : true + "writeConcernError": { + "code": 64, + "codeName": "WriteConcernFailed", + "errmsg": "Waiting for replication timed out", + "errInfo": { + "wtimeout": true } - }, - "ok" : 1, - }); + } + }; let aggregate = Aggregate::new( Namespace::empty(), @@ -307,9 +279,7 @@ async fn handle_write_concern_error() { None, ); - let error = aggregate - .handle_response(response, &Default::default()) - .expect_err("should get wc error"); + let error = handle_response_test(&aggregate, response).unwrap_err(); match *error.kind { ErrorKind::Write(WriteFailure::WriteConcernError(_)) => {} ref e => panic!("should have gotten WriteConcernError, got {:?} instead", e), @@ -322,20 +292,14 @@ async fn handle_invalid_response() { let aggregate = Aggregate::empty(); let garbled = doc! { "asdfasf": "ASdfasdf" }; - assert!(aggregate - .handle_response(CommandResponse::with_document(garbled), &Default::default()) - .is_err()); + handle_response_test(&aggregate, garbled).unwrap_err(); let missing_cursor_field = doc! { + "ok": 1.0, "cursor": { "ns": "test.test", "firstBatch": [], } }; - assert!(aggregate - .handle_response( - CommandResponse::with_document(missing_cursor_field), - &Default::default() - ) - .is_err()); + handle_response_test(&aggregate, missing_cursor_field).unwrap_err(); } diff --git a/src/operation/commit_transaction/mod.rs b/src/operation/commit_transaction/mod.rs index c00787644..ca91c7eb7 100644 --- a/src/operation/commit_transaction/mod.rs +++ b/src/operation/commit_transaction/mod.rs @@ -3,13 +3,13 @@ use std::time::Duration; use bson::doc; use crate::{ - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, error::Result, operation::{append_options, Operation, Retryability}, options::{Acknowledgment, TransactionOptions, WriteConcern}, }; -use super::WriteConcernOnlyBody; +use super::{CommandResponse, WriteConcernOnlyBody}; pub(crate) struct CommitTransaction { options: Option, @@ -23,6 +23,8 @@ impl CommitTransaction { impl Operation for CommitTransaction { type O = (); + type Response = CommandResponse; + const NAME: &'static str = "commitTransaction"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -41,10 +43,10 @@ impl Operation for CommitTransaction { fn handle_response( &self, - response: CommandResponse, + response: WriteConcernOnlyBody, _description: &StreamDescription, ) -> Result { - response.body::()?.validate() + response.validate() } fn write_concern(&self) -> Option<&WriteConcern> { diff --git a/src/operation/count/mod.rs b/src/operation/count/mod.rs index 05e065b6d..42c4a4993 100644 --- a/src/operation/count/mod.rs +++ b/src/operation/count/mod.rs @@ -1,17 +1,20 @@ #[cfg(test)] mod test; +use bson::Document; use serde::Deserialize; use crate::{ bson::doc, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, coll::{options::EstimatedDocumentCountOptions, Namespace}, error::{Error, ErrorKind, Result}, operation::{append_options, CursorBody, Operation, Retryability}, selection_criteria::SelectionCriteria, }; +use super::CommandResponse; + const SERVER_4_9_0_WIRE_VERSION: i32 = 12; pub(crate) struct Count { @@ -38,6 +41,8 @@ impl Count { impl Operation for Count { type O = u64; + type Response = CommandResponse; + const NAME: &'static str = "count"; fn build(&mut self, description: &StreamDescription) -> Result { @@ -77,13 +82,13 @@ impl Operation for Count { fn handle_response( &self, - response: CommandResponse, + response: Response, description: &StreamDescription, ) -> Result { - let response_body: ResponseBody = match description.max_wire_version { - Some(v) if v >= SERVER_4_9_0_WIRE_VERSION => { - let CursorBody { mut cursor } = response.body()?; - + let response_body: ResponseBody = match (description.max_wire_version, response) { + (Some(v), Response::Aggregate(CursorBody { mut cursor, .. })) + if v >= SERVER_4_9_0_WIRE_VERSION => + { cursor .first_batch .pop_front() @@ -94,7 +99,13 @@ impl Operation for Count { }) })? } - _ => response.body()?, + (_, Response::Count(body)) => body, + _ => { + return Err(ErrorKind::InvalidResponse { + message: "response from server did not match count command".to_string(), + } + .into()) + } }; Ok(response_body.n) @@ -121,6 +132,13 @@ impl Operation for Count { } #[derive(Debug, Deserialize)] -struct ResponseBody { +#[serde(untagged)] +pub(crate) enum Response { + Aggregate(CursorBody), + Count(ResponseBody), +} + +#[derive(Debug, Deserialize)] +pub(crate) struct ResponseBody { n: u64, } diff --git a/src/operation/count/test.rs b/src/operation/count/test.rs index d5a9468b6..8d89eaacd 100644 --- a/src/operation/count/test.rs +++ b/src/operation/count/test.rs @@ -2,11 +2,15 @@ use std::time::Duration; use crate::{ bson::doc, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, coll::{options::EstimatedDocumentCountOptions, Namespace}, concern::ReadConcern, - error::ErrorKind, - operation::{test, Count, Operation}, + operation::{ + count::SERVER_4_9_0_WIRE_VERSION, + test::{self, handle_response_test, handle_response_test_with_wire_version}, + Count, + Operation, + }, options::ReadConcernLevel, }; @@ -77,25 +81,36 @@ async fn handle_success() { let count_op = Count::empty(); let n = 26; - let response = CommandResponse::with_document(doc! { "n" : n, "ok" : 1 }); - - let actual_values = count_op - .handle_response(response, &Default::default()) - .expect("supposed to succeed"); + let response = doc! { "ok": 1.0, "n": n }; + let actual_values = handle_response_test(&count_op, response).unwrap(); assert_eq!(actual_values, n); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] -async fn handle_response_no_n() { +async fn handle_success_agg() { let count_op = Count::empty(); - let response = CommandResponse::with_document(doc! { "ok" : 1 }); + let n = 26; + let response = doc! { + "ok": 1.0, + "cursor": { + "id": 0, + "ns": "a.b", + "firstBatch": [ { "n": n } ] + } + }; + + let actual_values = + handle_response_test_with_wire_version(&count_op, response, SERVER_4_9_0_WIRE_VERSION) + .unwrap(); + assert_eq!(actual_values, n); +} - let result = count_op.handle_response(response, &Default::default()); - match result.map_err(|e| *e.kind) { - Err(ErrorKind::InvalidResponse { .. }) => {} - other => panic!("expected response error, but got {:?}", other), - } +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn handle_response_no_n() { + let count_op = Count::empty(); + handle_response_test(&count_op, doc! { "ok": 1.0 }).unwrap_err(); } diff --git a/src/operation/count_documents/mod.rs b/src/operation/count_documents/mod.rs index bb147bde1..16b5fe556 100644 --- a/src/operation/count_documents/mod.rs +++ b/src/operation/count_documents/mod.rs @@ -3,10 +3,10 @@ mod test; use bson::{doc, Document}; -use super::{Operation, Retryability}; +use super::{CursorBody, CursorResponse, Operation, Retryability}; use crate::{ bson_util, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, error::{ErrorKind, Result}, operation::aggregate::Aggregate, options::{AggregateOptions, CountOptions}, @@ -65,6 +65,8 @@ impl CountDocuments { impl Operation for CountDocuments { type O = u64; + type Response = CursorResponse; + const NAME: &'static str = Aggregate::NAME; fn build(&mut self, description: &StreamDescription) -> Result { @@ -73,15 +75,10 @@ impl Operation for CountDocuments { fn handle_response( &self, - response: CommandResponse, - description: &StreamDescription, + mut response: CursorBody, + _description: &StreamDescription, ) -> Result { - let result = self - .aggregate - .handle_response(response, description) - .map(|mut spec| spec.initial_buffer.pop_front())?; - - let result_doc = match result { + let result_doc = match response.cursor.first_batch.pop_front() { Some(doc) => doc, None => return Ok(0), }; diff --git a/src/operation/count_documents/test.rs b/src/operation/count_documents/test.rs index 88e3ba8c1..16db9c0d1 100644 --- a/src/operation/count_documents/test.rs +++ b/src/operation/count_documents/test.rs @@ -1,10 +1,13 @@ use crate::{ bson::doc, bson_util, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, coll::Namespace, concern::ReadConcern, - operation::{test, Operation}, + operation::{ + test::{self, handle_response_test}, + Operation, + }, options::{CountOptions, Hint}, }; @@ -100,23 +103,15 @@ async fn handle_success() { let count_op = CountDocuments::new(ns, None, None); let n = 26; - let response = CommandResponse::with_document(doc! { - "cursor" : { - "firstBatch" : [ - { - "_id" : 1, - "n" : n - } - ], - "id" : 0, - "ns" : "test_db.test_coll" - }, - "ok" : 1 - }); - - let actual_values = count_op - .handle_response(response, &Default::default()) - .expect("supposed to succeed"); + let response = doc! { + "ok": 1.0, + "cursor": { + "id": 0, + "ns": "test_db.test_coll", + "firstBatch": [ { "_id": 1, "n": n } ], + } + }; + let actual_values = handle_response_test(&count_op, response).unwrap(); assert_eq!(actual_values, n); } diff --git a/src/operation/create/mod.rs b/src/operation/create/mod.rs index aa3973a3d..97530b900 100644 --- a/src/operation/create/mod.rs +++ b/src/operation/create/mod.rs @@ -3,13 +3,15 @@ mod test; use crate::{ bson::doc, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, error::Result, operation::{append_options, Operation, WriteConcernOnlyBody}, options::{CreateCollectionOptions, WriteConcern}, Namespace, }; +use super::CommandResponse; + #[derive(Debug)] pub(crate) struct Create { ns: Namespace, @@ -35,6 +37,8 @@ impl Create { impl Operation for Create { type O = (); + type Response = CommandResponse; + const NAME: &'static str = "create"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -52,10 +56,10 @@ impl Operation for Create { fn handle_response( &self, - response: CommandResponse, + response: WriteConcernOnlyBody, _description: &StreamDescription, ) -> Result { - response.body::()?.validate() + response.validate() } fn write_concern(&self) -> Option<&WriteConcern> { diff --git a/src/operation/create/test.rs b/src/operation/create/test.rs index 242a5fa1c..3b07ef899 100644 --- a/src/operation/create/test.rs +++ b/src/operation/create/test.rs @@ -1,9 +1,9 @@ use crate::{ bson::{doc, Bson}, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, concern::WriteConcern, error::{ErrorKind, WriteFailure}, - operation::{Create, Operation}, + operation::{test::handle_response_test, Create, Operation}, options::{CreateCollectionOptions, ValidationAction, ValidationLevel}, Namespace, }; @@ -76,11 +76,7 @@ async fn build_validator() { #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_success() { let op = Create::empty(); - - let ok_response = CommandResponse::with_document(doc! { "ok": 1.0 }); - assert!(op.handle_response(ok_response, &Default::default()).is_ok()); - let ok_extra = CommandResponse::with_document(doc! { "ok": 1.0, "hello": "world" }); - assert!(op.handle_response(ok_extra, &Default::default()).is_ok()); + handle_response_test(&op, doc! { "ok": 1.0 }).unwrap(); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] @@ -88,19 +84,17 @@ async fn handle_success() { async fn handle_write_concern_error() { let op = Create::empty(); - let response = CommandResponse::with_document(doc! { + let response = doc! { "writeConcernError": { "code": 100, "codeName": "hello world", "errmsg": "12345" }, "ok": 1 - }); - - let result = op.handle_response(response, &Default::default()); - assert!(result.is_err()); + }; - match *result.unwrap_err().kind { + let err = handle_response_test(&op, response).unwrap_err(); + match *err.kind { ErrorKind::Write(WriteFailure::WriteConcernError(ref wc_err)) => { assert_eq!(wc_err.code, 100); assert_eq!(wc_err.code_name, "hello world"); diff --git a/src/operation/delete/mod.rs b/src/operation/delete/mod.rs index 407a2c12a..6c4806f52 100644 --- a/src/operation/delete/mod.rs +++ b/src/operation/delete/mod.rs @@ -3,7 +3,7 @@ mod test; use crate::{ bson::{doc, Document}, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, coll::Namespace, collation::Collation, error::{convert_bulk_errors, Result}, @@ -12,6 +12,8 @@ use crate::{ results::DeleteResult, }; +use super::CommandResponse; + #[derive(Debug)] pub(crate) struct Delete { ns: Namespace, @@ -55,6 +57,8 @@ impl Delete { impl Operation for Delete { type O = DeleteResult; + type Response = CommandResponse; + const NAME: &'static str = "delete"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -88,14 +92,13 @@ impl Operation for Delete { fn handle_response( &self, - response: CommandResponse, + response: WriteResponseBody, _description: &StreamDescription, ) -> Result { - let body: WriteResponseBody = response.body()?; - body.validate().map_err(convert_bulk_errors)?; + response.validate().map_err(convert_bulk_errors)?; Ok(DeleteResult { - deleted_count: body.n, + deleted_count: response.n, }) } diff --git a/src/operation/delete/test.rs b/src/operation/delete/test.rs index 8201d8a55..57da6ac22 100644 --- a/src/operation/delete/test.rs +++ b/src/operation/delete/test.rs @@ -3,10 +3,10 @@ use pretty_assertions::assert_eq; use crate::{ bson::doc, bson_util, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, concern::{Acknowledgment, WriteConcern}, error::{ErrorKind, WriteConcernError, WriteError, WriteFailure}, - operation::{Delete, Operation}, + operation::{test::handle_response_test, Delete, Operation}, options::DeleteOptions, Namespace, }; @@ -102,15 +102,14 @@ async fn build_one() { async fn handle_success() { let op = Delete::empty(); - let ok_response = CommandResponse::with_document(doc! { - "ok": 1.0, - "n": 3, - }); - - let ok_result = op.handle_response(ok_response, &Default::default()); - assert!(ok_result.is_ok()); - - let delete_result = ok_result.unwrap(); + let delete_result = handle_response_test( + &op, + doc! { + "ok": 1.0, + "n": 3 + }, + ) + .expect("should succeed"); assert_eq!(delete_result.deleted_count, 3); } @@ -118,11 +117,14 @@ async fn handle_success() { #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_invalid_response() { let op = Delete::empty(); - - let invalid_response = CommandResponse::with_document(doc! { "ok": 1.0, "asdfadsf": 123123 }); - assert!(op - .handle_response(invalid_response, &Default::default()) - .is_err()); + handle_response_test( + &op, + doc! { + "ok": 1.0, + "asffasdf": 123123 + }, + ) + .expect_err("should fail"); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] @@ -130,7 +132,7 @@ async fn handle_invalid_response() { async fn handle_write_failure() { let op = Delete::empty(); - let write_error_response = CommandResponse::with_document(doc! { + let write_error_response = doc! { "ok": 1.0, "n": 0, "writeErrors": [ @@ -140,10 +142,9 @@ async fn handle_write_failure() { "errmsg": "my error string" } ] - }); - let write_error_result = op.handle_response(write_error_response, &Default::default()); - assert!(write_error_result.is_err()); - match *write_error_result.unwrap_err().kind { + }; + let write_error = handle_response_test(&op, write_error_response).unwrap_err(); + match *write_error.kind { ErrorKind::Write(WriteFailure::WriteError(ref error)) => { let expected_err = WriteError { code: 1234, @@ -162,7 +163,7 @@ async fn handle_write_failure() { async fn handle_write_concern_failure() { let op = Delete::empty(); - let wc_error_response = CommandResponse::with_document(doc! { + let wc_error_response = doc! { "ok": 1.0, "n": 0, "writeConcernError": { @@ -177,12 +178,11 @@ async fn handle_write_concern_failure() { } } } - }); - - let wc_error_result = op.handle_response(wc_error_response, &Default::default()); - assert!(wc_error_result.is_err()); + }; - match *wc_error_result.unwrap_err().kind { + let wc_error = handle_response_test(&op, wc_error_response) + .expect_err("should fail with write concern error"); + match *wc_error.kind { ErrorKind::Write(WriteFailure::WriteConcernError(ref wc_error)) => { let expected_wc_err = WriteConcernError { code: 456, diff --git a/src/operation/distinct/mod.rs b/src/operation/distinct/mod.rs index 4bdf51a18..2291297b6 100644 --- a/src/operation/distinct/mod.rs +++ b/src/operation/distinct/mod.rs @@ -5,13 +5,15 @@ use serde::Deserialize; use crate::{ bson::{doc, Bson, Document}, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, coll::{options::DistinctOptions, Namespace}, error::Result, operation::{append_options, Operation, Retryability}, selection_criteria::SelectionCriteria, }; +use super::CommandResponse; + pub(crate) struct Distinct { ns: Namespace, field_name: String, @@ -50,6 +52,8 @@ impl Distinct { impl Operation for Distinct { type O = Vec; + type Response = CommandResponse; + const NAME: &'static str = "distinct"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -72,10 +76,10 @@ impl Operation for Distinct { } fn handle_response( &self, - response: CommandResponse, + response: Response, _description: &StreamDescription, ) -> Result { - response.body::().map(|body| body.values) + Ok(response.values) } fn selection_criteria(&self) -> Option<&SelectionCriteria> { @@ -91,6 +95,6 @@ impl Operation for Distinct { } #[derive(Debug, Deserialize)] -struct ResponseBody { +pub(crate) struct Response { values: Vec, } diff --git a/src/operation/distinct/test.rs b/src/operation/distinct/test.rs index 3d45e97d2..8e8b69636 100644 --- a/src/operation/distinct/test.rs +++ b/src/operation/distinct/test.rs @@ -2,10 +2,14 @@ use std::time::Duration; use crate::{ bson::{doc, Bson}, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, coll::{options::DistinctOptions, Namespace}, error::ErrorKind, - operation::{test, Distinct, Operation}, + operation::{ + test::{self, handle_response_test}, + Distinct, + Operation, + }, }; #[cfg_attr(feature = "tokio-runtime", tokio::test)] @@ -100,15 +104,12 @@ async fn handle_success() { let expected_values: Vec = vec![Bson::String("A".to_string()), Bson::String("B".to_string())]; - let response = CommandResponse::with_document(doc! { + let response = doc! { "values" : expected_values.clone(), "ok" : 1 - }); - - let actual_values = distinct_op - .handle_response(response, &Default::default()) - .expect("supposed to succeed"); + }; + let actual_values = handle_response_test(&distinct_op, response).unwrap(); assert_eq!(actual_values, expected_values); } @@ -117,17 +118,13 @@ async fn handle_success() { async fn handle_response_with_empty_values() { let distinct_op = Distinct::empty(); - let response = CommandResponse::with_document(doc! { + let response = doc! { "values" : [], "ok" : 1 - }); + }; let expected_values: Vec = Vec::new(); - - let actual_values = distinct_op - .handle_response(response, &Default::default()) - .expect("supposed to succeed"); - + let actual_values = handle_response_test(&distinct_op, response).unwrap(); assert_eq!(actual_values, expected_values); } @@ -136,11 +133,11 @@ async fn handle_response_with_empty_values() { async fn handle_response_no_values() { let distinct_op = Distinct::empty(); - let response = CommandResponse::with_document(doc! { + let response = doc! { "ok" : 1 - }); + }; - let result = distinct_op.handle_response(response, &Default::default()); + let result = handle_response_test(&distinct_op, response); match result.map_err(|e| *e.kind) { Err(ErrorKind::InvalidResponse { .. }) => {} other => panic!("expected response error, but got {:?}", other), diff --git a/src/operation/drop_collection/mod.rs b/src/operation/drop_collection/mod.rs index 4ac156c0f..96261727a 100644 --- a/src/operation/drop_collection/mod.rs +++ b/src/operation/drop_collection/mod.rs @@ -3,13 +3,15 @@ mod test; use crate::{ bson::doc, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, error::{Error, Result}, operation::{append_options, Operation, WriteConcernOnlyBody}, options::{DropCollectionOptions, WriteConcern}, Namespace, }; +use super::CommandResponse; + #[derive(Debug)] pub(crate) struct DropCollection { ns: Namespace, @@ -35,6 +37,8 @@ impl DropCollection { impl Operation for DropCollection { type O = (); + type Response = CommandResponse; + const NAME: &'static str = "drop"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -53,10 +57,10 @@ impl Operation for DropCollection { fn handle_response( &self, - response: CommandResponse, + response: WriteConcernOnlyBody, _description: &StreamDescription, ) -> Result { - response.body::()?.validate() + response.validate() } fn handle_error(&self, error: Error) -> Result { diff --git a/src/operation/drop_collection/test.rs b/src/operation/drop_collection/test.rs index a372d7ed8..195efec04 100644 --- a/src/operation/drop_collection/test.rs +++ b/src/operation/drop_collection/test.rs @@ -1,9 +1,9 @@ use crate::{ bson::doc, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, concern::{Acknowledgment, WriteConcern}, error::{ErrorKind, WriteFailure}, - operation::{DropCollection, Operation}, + operation::{test::handle_response_test, DropCollection, Operation}, options::DropCollectionOptions, Namespace, }; @@ -55,10 +55,10 @@ async fn build() { async fn handle_success() { let op = DropCollection::empty(); - let ok_response = CommandResponse::with_document(doc! { "ok": 1.0 }); - assert!(op.handle_response(ok_response, &Default::default()).is_ok()); - let ok_extra = CommandResponse::with_document(doc! { "ok": 1.0, "hello": "world" }); - assert!(op.handle_response(ok_extra, &Default::default()).is_ok()); + let ok_response = doc! { "ok": 1.0 }; + handle_response_test(&op, ok_response).unwrap(); + let ok_extra = doc! { "ok": 1.0, "hello": "world" }; + handle_response_test(&op, ok_extra).unwrap(); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] @@ -66,19 +66,17 @@ async fn handle_success() { async fn handle_write_concern_error() { let op = DropCollection::empty(); - let response = CommandResponse::with_document(doc! { + let response = doc! { "writeConcernError": { "code": 100, "codeName": "hello world", "errmsg": "12345" }, "ok": 1 - }); - - let result = op.handle_response(response, &Default::default()); - assert!(result.is_err()); + }; - match *result.unwrap_err().kind { + let err = handle_response_test(&op, response).unwrap_err(); + match *err.kind { ErrorKind::Write(WriteFailure::WriteConcernError(ref wc_err)) => { assert_eq!(wc_err.code, 100); assert_eq!(wc_err.code_name, "hello world"); diff --git a/src/operation/drop_database/mod.rs b/src/operation/drop_database/mod.rs index 99976a63e..e0418ee62 100644 --- a/src/operation/drop_database/mod.rs +++ b/src/operation/drop_database/mod.rs @@ -3,12 +3,14 @@ mod test; use crate::{ bson::doc, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, error::Result, operation::{append_options, Operation, WriteConcernOnlyBody}, options::{DropDatabaseOptions, WriteConcern}, }; +use super::CommandResponse; + #[derive(Debug)] pub(crate) struct DropDatabase { target_db: String, @@ -28,6 +30,8 @@ impl DropDatabase { impl Operation for DropDatabase { type O = (); + type Response = CommandResponse; + const NAME: &'static str = "dropDatabase"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -46,10 +50,10 @@ impl Operation for DropDatabase { fn handle_response( &self, - response: CommandResponse, + response: WriteConcernOnlyBody, _description: &StreamDescription, ) -> Result { - response.body::()?.validate() + response.validate() } fn write_concern(&self) -> Option<&WriteConcern> { diff --git a/src/operation/drop_database/test.rs b/src/operation/drop_database/test.rs index 1547f8548..aac559b44 100644 --- a/src/operation/drop_database/test.rs +++ b/src/operation/drop_database/test.rs @@ -1,9 +1,9 @@ use crate::{ bson::doc, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, concern::{Acknowledgment, WriteConcern}, error::{ErrorKind, WriteFailure}, - operation::{DropDatabase, Operation}, + operation::{test::handle_response_test, DropDatabase, Operation}, options::DropDatabaseOptions, }; @@ -53,10 +53,10 @@ async fn build() { async fn handle_success() { let op = DropDatabase::empty(); - let ok_response = CommandResponse::with_document(doc! { "ok": 1.0 }); - assert!(op.handle_response(ok_response, &Default::default()).is_ok()); - let ok_extra = CommandResponse::with_document(doc! { "ok": 1.0, "hello": "world" }); - assert!(op.handle_response(ok_extra, &Default::default()).is_ok()); + let ok_response = doc! { "ok": 1.0 }; + handle_response_test(&op, ok_response).unwrap(); + let ok_extra = doc! { "ok": 1.0, "hello": "world" }; + handle_response_test(&op, ok_extra).unwrap(); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] @@ -64,19 +64,18 @@ async fn handle_success() { async fn handle_write_concern_error() { let op = DropDatabase::empty(); - let response = CommandResponse::with_document(doc! { + let response = doc! { "writeConcernError": { "code": 100, "codeName": "hello world", "errmsg": "12345" }, "ok": 1 - }); + }; - let result = op.handle_response(response, &Default::default()); - assert!(result.is_err()); + let err = handle_response_test(&op, response).unwrap_err(); - match *result.unwrap_err().kind { + match *err.kind { ErrorKind::Write(WriteFailure::WriteConcernError(ref wc_err)) => { assert_eq!(wc_err.code, 100); assert_eq!(wc_err.code_name, "hello world"); diff --git a/src/operation/find/mod.rs b/src/operation/find/mod.rs index b94593e0a..e4cc64782 100644 --- a/src/operation/find/mod.rs +++ b/src/operation/find/mod.rs @@ -1,9 +1,13 @@ #[cfg(test)] mod test; +use std::marker::PhantomData; + +use serde::de::DeserializeOwned; + use crate::{ bson::{doc, Document}, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, cursor::CursorSpecification, error::{ErrorKind, Result}, operation::{append_options, CursorBody, Operation, Retryability}, @@ -11,14 +15,17 @@ use crate::{ Namespace, }; +use super::CursorResponse; + #[derive(Debug)] -pub(crate) struct Find { +pub(crate) struct Find { ns: Namespace, filter: Option, options: Option, + _phantom: PhantomData, } -impl Find { +impl Find { #[cfg(test)] fn empty() -> Self { Self::new( @@ -40,12 +47,14 @@ impl Find { ns, filter, options, + _phantom: Default::default(), } } } -impl Operation for Find { - type O = CursorSpecification; +impl Operation for Find { + type O = CursorSpecification; + type Response = CursorResponse; const NAME: &'static str = "find"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -97,16 +106,12 @@ impl Operation for Find { fn handle_response( &self, - response: CommandResponse, - _description: &StreamDescription, + response: CursorBody, + description: &StreamDescription, ) -> Result { - let source_address = response.source_address().clone(); - let mut body: CursorBody = response.body()?; - body.cursor.ns = self.ns.clone(); - Ok(CursorSpecification::new( - body.cursor, - source_address, + response.cursor, + description.server_address.clone(), self.options.as_ref().and_then(|opts| opts.batch_size), self.options.as_ref().and_then(|opts| opts.max_await_time), )) diff --git a/src/operation/find/test.rs b/src/operation/find/test.rs index f9c4cca8c..3fe51b9b0 100644 --- a/src/operation/find/test.rs +++ b/src/operation/find/test.rs @@ -3,8 +3,12 @@ use std::time::Duration; use crate::{ bson::{doc, Document}, bson_util, - cmap::{CommandResponse, StreamDescription}, - operation::{test, Find, Operation}, + cmap::StreamDescription, + operation::{ + test::{self, handle_response_test}, + Find, + Operation, + }, options::{CursorType, FindOptions, Hint, ReadConcern, ReadConcernLevel, ServerAddress}, Namespace, }; @@ -15,7 +19,7 @@ fn build_test( options: Option, mut expected_body: Document, ) { - let mut find = Find::new(ns.clone(), filter, options); + let mut find = Find::::new(ns.clone(), filter, options); let mut cmd = find.build(&StreamDescription::new_testing()).unwrap(); @@ -176,7 +180,7 @@ async fn build_batch_size() { let options = FindOptions::builder() .batch_size((std::i32::MAX as u32) + 1) .build(); - let mut op = Find::new(Namespace::empty(), None, Some(options)); + let mut op = Find::::new(Namespace::empty(), None, Some(options)); assert!(op.build(&StreamDescription::new_testing()).is_err()) } @@ -188,7 +192,7 @@ async fn op_selection_criteria() { selection_criteria, ..Default::default() }; - Find::new(Namespace::empty(), None, Some(options)) + Find::::new(Namespace::empty(), None, Some(options)) }); } @@ -205,7 +209,7 @@ async fn handle_success() { port: None, }; - let find = Find::empty(); + let find = Find::::empty(); let first_batch = vec![doc! {"_id": 1}, doc! {"_id": 2}]; @@ -218,13 +222,7 @@ async fn handle_success() { "ok": 1.0 }; - let result = find.handle_response( - CommandResponse::with_document_and_address(address.clone(), response.clone()), - &Default::default(), - ); - assert!(result.is_ok()); - - let cursor_spec = result.unwrap(); + let cursor_spec = handle_response_test(&find, response.clone()).unwrap(); assert_eq!(cursor_spec.address(), &address); assert_eq!(cursor_spec.id(), 123); assert_eq!(cursor_spec.batch_size(), None); @@ -241,13 +239,7 @@ async fn handle_success() { None, Some(FindOptions::builder().batch_size(123).build()), ); - let result = find.handle_response( - CommandResponse::with_document_and_address(address.clone(), response), - &Default::default(), - ); - assert!(result.is_ok()); - - let cursor_spec = result.unwrap(); + let cursor_spec = handle_response_test(&find, response).unwrap(); assert_eq!(cursor_spec.address(), &address); assert_eq!(cursor_spec.id(), 123); assert_eq!(cursor_spec.batch_size(), Some(123)); @@ -262,11 +254,7 @@ async fn handle_success() { fn verify_max_await_time(max_await_time: Option, cursor_type: Option) { let ns = Namespace::empty(); - let address = ServerAddress::Tcp { - host: "localhost".to_string(), - port: None, - }; - let find = Find::new( + let find = Find::::new( ns, None, Some(FindOptions { @@ -276,8 +264,8 @@ fn verify_max_await_time(max_await_time: Option, cursor_type: Option, cursor_type: Option::empty(); let garbled = doc! { "asdfasf": "ASdfasdf" }; - assert!(find - .handle_response(CommandResponse::with_document(garbled), &Default::default()) - .is_err()); + handle_response_test(&find, garbled).unwrap_err(); let missing_cursor_field = doc! { "cursor": { @@ -326,10 +309,5 @@ async fn handle_invalid_response() { "firstBatch": [], } }; - assert!(find - .handle_response( - CommandResponse::with_document(missing_cursor_field), - &Default::default() - ) - .is_err()); + handle_response_test(&find, missing_cursor_field).unwrap_err(); } diff --git a/src/operation/find_and_modify/mod.rs b/src/operation/find_and_modify/mod.rs index bba3b8c38..4bea1816c 100644 --- a/src/operation/find_and_modify/mod.rs +++ b/src/operation/find_and_modify/mod.rs @@ -10,7 +10,7 @@ use self::options::FindAndModifyOptions; use crate::{ bson::{doc, from_document, Bson, Document}, bson_util, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, coll::{ options::{ FindOneAndDeleteOptions, @@ -25,6 +25,8 @@ use crate::{ options::WriteConcern, }; +use super::CommandResponse; + pub(crate) struct FindAndModify where T: DeserializeOwned, @@ -100,6 +102,7 @@ where T: DeserializeOwned, { type O = Option; + type Response = CommandResponse; const NAME: &'static str = "findAndModify"; fn build(&mut self, description: &StreamDescription) -> Result { @@ -128,11 +131,10 @@ where fn handle_response( &self, - response: CommandResponse, + response: Response, _description: &StreamDescription, ) -> Result { - let body: ResponseBody = response.body()?; - match body.value { + match response.value { Bson::Document(doc) => Ok(Some(from_document(doc)?)), Bson::Null => Ok(None), other => Err(ErrorKind::InvalidResponse { @@ -156,6 +158,6 @@ where } #[derive(Debug, Deserialize)] -struct ResponseBody { +pub(crate) struct Response { value: Bson, } diff --git a/src/operation/find_and_modify/test.rs b/src/operation/find_and_modify/test.rs index 5eb47e2bc..5a3a2698f 100644 --- a/src/operation/find_and_modify/test.rs +++ b/src/operation/find_and_modify/test.rs @@ -3,9 +3,9 @@ use std::time::Duration; use crate::{ bson::{doc, oid::ObjectId, Bson, Document}, bson_util, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, coll::options::ReturnDocument, - operation::{FindAndModify, Operation}, + operation::{test::handle_response_test, FindAndModify, Operation}, options::{ FindOneAndDeleteOptions, FindOneAndReplaceOptions, @@ -143,7 +143,7 @@ async fn handle_success_delete() { "rating" : 100, "score" : 5 }; - let ok_response = CommandResponse::with_document(doc! { + let ok_response = doc! { "lastErrorObject" : { "connectionId" : 1, "updatedExisting" : true, @@ -155,13 +155,10 @@ async fn handle_success_delete() { }, "value" : value.clone(), "ok" : 1 - }); + }; - let result = op.handle_response(ok_response, &Default::default()); - assert_eq!( - result.expect("handle failed").expect("result was None"), - value - ); + let result = handle_response_test(&op, ok_response).unwrap(); + assert_eq!(result.unwrap(), value); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] @@ -169,10 +166,8 @@ async fn handle_success_delete() { async fn handle_null_value_delete() { let op = empty_delete(); - let null_value = CommandResponse::with_document(doc! { "ok": 1.0, "value": Bson::Null}); - let result = op.handle_response(null_value, &Default::default()); - assert!(result.is_ok()); - assert_eq!(result.expect("handle failed"), None); + let result = handle_response_test(&op, doc! { "ok": 1.0, "value": Bson::Null }).unwrap(); + assert_eq!(result, None); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] @@ -180,8 +175,7 @@ async fn handle_null_value_delete() { async fn handle_no_value_delete() { let op = empty_delete(); - let no_value = CommandResponse::with_document(doc! { "ok": 1.0 }); - assert!(op.handle_response(no_value, &Default::default()).is_err()); + handle_response_test(&op, doc! { "ok": 1.0 }).unwrap_err(); } // replace tests @@ -334,7 +328,7 @@ async fn handle_success_replace() { "rating" : 100, "score" : 5 }; - let ok_response = CommandResponse::with_document(doc! { + let ok_response = doc! { "lastErrorObject" : { "connectionId" : 1, "updatedExisting" : true, @@ -346,33 +340,25 @@ async fn handle_success_replace() { }, "value" : value.clone(), "ok" : 1 - }); + }; - let result = op.handle_response(ok_response, &Default::default()); - assert_eq!( - result.expect("handle failed").expect("result was None"), - value - ); + let result = handle_response_test(&op, ok_response).unwrap(); + assert_eq!(result.unwrap(), value); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_null_value_replace() { let op = empty_replace(); - - let null_value = CommandResponse::with_document(doc! { "ok": 1.0, "value": Bson::Null}); - let result = op.handle_response(null_value, &Default::default()); - assert!(result.is_ok()); - assert_eq!(result.expect("handle failed"), None); + let result = handle_response_test(&op, doc! { "ok": 1.0, "value": Bson::Null }).unwrap(); + assert_eq!(result, None); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_no_value_replace() { let op = empty_replace(); - - let no_value = CommandResponse::with_document(doc! { "ok": 1.0 }); - assert!(op.handle_response(no_value, &Default::default()).is_err()); + handle_response_test(&op, doc! { "ok": 1.0 }).unwrap_err(); } // update tests @@ -511,7 +497,7 @@ async fn handle_success_update() { "rating" : 100, "score" : 5 }; - let ok_response = CommandResponse::with_document(doc! { + let ok_response = doc! { "lastErrorObject" : { "connectionId" : 1, "updatedExisting" : true, @@ -523,31 +509,23 @@ async fn handle_success_update() { }, "value" : value.clone(), "ok" : 1 - }); + }; - let result = op.handle_response(ok_response, &Default::default()); - assert_eq!( - result.expect("handle failed").expect("result was None"), - value - ); + let result = handle_response_test(&op, ok_response).unwrap(); + assert_eq!(result.unwrap(), value); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_null_value_update() { let op = empty_update(); - - let null_value = CommandResponse::with_document(doc! { "ok": 1.0, "value": Bson::Null}); - let result = op.handle_response(null_value, &Default::default()); - assert!(result.is_ok()); - assert_eq!(result.expect("handle failed"), None); + let result = handle_response_test(&op, doc! { "ok": 1.0, "value": Bson::Null }).unwrap(); + assert_eq!(result, None); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_no_value_update() { let op = empty_update(); - - let no_value = CommandResponse::with_document(doc! { "ok": 1.0 }); - assert!(op.handle_response(no_value, &Default::default()).is_err()); + handle_response_test(&op, doc! { "ok": 1.0 }).unwrap_err(); } diff --git a/src/operation/get_more/mod.rs b/src/operation/get_more/mod.rs index dc7944b21..53455dec7 100644 --- a/src/operation/get_more/mod.rs +++ b/src/operation/get_more/mod.rs @@ -1,13 +1,13 @@ #[cfg(test)] mod test; -use std::{collections::VecDeque, time::Duration}; +use std::{collections::VecDeque, marker::PhantomData, time::Duration}; -use serde::Deserialize; +use serde::{de::DeserializeOwned, Deserialize}; use crate::{ - bson::{doc, Document}, - cmap::{Command, CommandResponse, StreamDescription}, + bson::doc, + cmap::{Command, StreamDescription}, cursor::CursorInformation, error::{ErrorKind, Result}, operation::Operation, @@ -16,16 +16,19 @@ use crate::{ Namespace, }; +use super::CommandResponse; + #[derive(Debug)] -pub(crate) struct GetMore { +pub(crate) struct GetMore { ns: Namespace, cursor_id: i64, selection_criteria: SelectionCriteria, batch_size: Option, max_time: Option, + _phantom: PhantomData, } -impl GetMore { +impl GetMore { pub(crate) fn new(info: CursorInformation) -> Self { Self { ns: info.ns, @@ -33,12 +36,15 @@ impl GetMore { selection_criteria: SelectionCriteria::from_address(info.address), batch_size: info.batch_size, max_time: info.max_time, + _phantom: Default::default(), } } } -impl Operation for GetMore { - type O = GetMoreResult; +impl Operation for GetMore { + type O = GetMoreResult; + type Response = CommandResponse>; + const NAME: &'static str = "getMore"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -71,13 +77,12 @@ impl Operation for GetMore { fn handle_response( &self, - response: CommandResponse, + response: GetMoreResponseBody, _description: &StreamDescription, ) -> Result { - let body: GetMoreResponseBody = response.body()?; Ok(GetMoreResult { - batch: body.cursor.next_batch, - exhausted: body.cursor.id == 0, + batch: response.cursor.next_batch, + exhausted: response.cursor.id == 0, }) } @@ -87,13 +92,13 @@ impl Operation for GetMore { } #[derive(Debug, Deserialize)] -struct GetMoreResponseBody { - cursor: NextBatchBody, +pub(crate) struct GetMoreResponseBody { + cursor: NextBatchBody, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] -struct NextBatchBody { +struct NextBatchBody { id: i64, - next_batch: VecDeque, + next_batch: VecDeque, } diff --git a/src/operation/get_more/test.rs b/src/operation/get_more/test.rs index 4960a56e7..c7abb8a2d 100644 --- a/src/operation/get_more/test.rs +++ b/src/operation/get_more/test.rs @@ -3,9 +3,9 @@ use std::time::Duration; use crate::{ bson::{doc, Document}, bson_util, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, cursor::CursorInformation, - operation::{GetMore, Operation}, + operation::{test::handle_response_test, GetMore, Operation}, options::ServerAddress, sdam::{ServerDescription, ServerInfo, ServerType}, Namespace, @@ -26,7 +26,7 @@ fn build_test( batch_size, max_time, }; - let mut get_more = GetMore::new(info); + let mut get_more = GetMore::::new(info); let build_result = get_more.build(&StreamDescription::new_testing()); assert!(build_result.is_ok()); @@ -117,7 +117,7 @@ async fn build_batch_size() { batch_size: Some((std::i32::MAX as u32) + 1), max_time: None, }; - let mut op = GetMore::new(info); + let mut op = GetMore::::new(info); assert!(op.build(&StreamDescription::new_testing()).is_err()) } @@ -136,7 +136,7 @@ async fn op_selection_criteria() { batch_size: None, max_time: None, }; - let get_more = GetMore::new(info); + let get_more = GetMore::::new(info); let server_description = ServerDescription { address, server_type: ServerType::Unknown, @@ -181,36 +181,32 @@ async fn handle_success() { batch_size: None, max_time: None, }; - let get_more = GetMore::new(info); + let get_more = GetMore::::new(info); let batch = vec![doc! { "_id": 1 }, doc! { "_id": 2 }, doc! { "_id": 3 }]; - let response = CommandResponse::with_document(doc! { + let response = doc! { "cursor": { "id": 123, "ns": "test_db.test_coll", "nextBatch": bson_util::to_bson_array(&batch), }, "ok": 1 - }); + }; - let result = get_more - .handle_response(response, &Default::default()) - .expect("handle success case failed"); + let result = handle_response_test(&get_more, response).unwrap(); assert!(!result.exhausted); assert_eq!(result.batch, batch); - let response = CommandResponse::with_document(doc! { + let response = doc! { "cursor": { "id": 0, "ns": "test_db.test_coll", "nextBatch": bson_util::to_bson_array(&batch), }, "ok": 1 - }); - let result = get_more - .handle_response(response, &Default::default()) - .expect("handle success case failed"); + }; + let result = handle_response_test(&get_more, response).unwrap(); assert!(result.exhausted); assert_eq!(result.batch, batch); } diff --git a/src/operation/insert/mod.rs b/src/operation/insert/mod.rs index 9859b9113..1828b3f1d 100644 --- a/src/operation/insert/mod.rs +++ b/src/operation/insert/mod.rs @@ -9,7 +9,7 @@ use serde::Serialize; use crate::{ bson::{doc, Document}, bson_util, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, error::{BulkWriteFailure, Error, ErrorKind, Result}, operation::{append_options, Operation, Retryability, WriteResponseBody}, options::{InsertManyOptions, WriteConcern}, @@ -17,6 +17,8 @@ use crate::{ Namespace, }; +use super::CommandResponse; + #[derive(Debug)] pub(crate) struct Insert { ns: Namespace, @@ -49,6 +51,8 @@ impl Insert { impl Operation for Insert { type O = InsertManyResult; + type Response = CommandResponse; + const NAME: &'static str = "insert"; fn build(&mut self, description: &StreamDescription) -> Result { @@ -110,15 +114,18 @@ impl Operation for Insert { fn handle_response( &self, - response: CommandResponse, + response: WriteResponseBody, _description: &StreamDescription, ) -> Result { - let body: WriteResponseBody = response.body()?; - let mut map = HashMap::new(); if self.is_ordered() { // in ordered inserts, only the first n were attempted. - for (i, id) in self.inserted_ids.iter().enumerate().take(body.n as usize) { + for (i, id) in self + .inserted_ids + .iter() + .enumerate() + .take(response.n as usize) + { map.insert(i, id.clone()); } } else { @@ -128,21 +135,21 @@ impl Operation for Insert { map.insert(i, id.clone()); } - if let Some(write_errors) = body.write_errors.as_ref() { + if let Some(write_errors) = response.write_errors.as_ref() { for err in write_errors { map.remove(&err.index); } } } - if body.write_errors.is_some() || body.write_concern_error.is_some() { + if response.write_errors.is_some() || response.write_concern_error.is_some() { return Err(Error::new( ErrorKind::BulkWrite(BulkWriteFailure { - write_errors: body.write_errors, - write_concern_error: body.write_concern_error, + write_errors: response.write_errors, + write_concern_error: response.write_concern_error, inserted_ids: map, }), - body.labels, + response.labels, )); } diff --git a/src/operation/insert/test.rs b/src/operation/insert/test.rs index 59fe7c85c..c16f982af 100644 --- a/src/operation/insert/test.rs +++ b/src/operation/insert/test.rs @@ -1,9 +1,9 @@ use crate::{ bson::{doc, Bson, Document}, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, concern::WriteConcern, error::{BulkWriteError, ErrorKind, WriteConcernError}, - operation::{Insert, Operation}, + operation::{test::handle_response_test, Insert, Operation}, options::InsertManyOptions, Namespace, }; @@ -142,13 +142,7 @@ async fn handle_success() { .op .build(&StreamDescription::new_testing()) .unwrap(); - let ok_response = CommandResponse::with_document(doc! { "ok": 1.0, "n": 3 }); - let ok_result = fixtures - .op - .handle_response(ok_response, &Default::default()); - assert!(ok_result.is_ok()); - - let response = ok_result.unwrap(); + let response = handle_response_test(&fixtures.op, doc! { "ok": 1.0, "n": 3 }).unwrap(); let inserted_ids = response.inserted_ids; assert_eq!(inserted_ids.len(), 3); assert_eq!( @@ -161,12 +155,7 @@ async fn handle_success() { #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_invalid_response() { let fixtures = fixtures(); - - let invalid_response = CommandResponse::with_document(doc! { "ok": 1.0, "asdfadsf": 123123 }); - assert!(fixtures - .op - .handle_response(invalid_response, &Default::default()) - .is_err()); + handle_response_test(&fixtures.op, doc! { "ok": 1.0, "asdfadsf": 123123 }).unwrap_err(); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] @@ -180,7 +169,7 @@ async fn handle_write_failure() { .build(&StreamDescription::new_testing()) .unwrap(); - let write_error_response = CommandResponse::with_document(doc! { + let write_error_response = doc! { "ok": 1.0, "n": 1, "writeErrors": [ @@ -205,13 +194,10 @@ async fn handle_write_failure() { } } } - }); - - let write_error_response = fixtures - .op - .handle_response(write_error_response, &Default::default()) - .expect_err("result should be err"); + }; + let write_error_response = + handle_response_test(&fixtures.op, write_error_response).unwrap_err(); match *write_error_response.kind { ErrorKind::BulkWrite(bwe) => { let write_errors = bwe.write_errors.expect("write errors should be present"); diff --git a/src/operation/list_collections/mod.rs b/src/operation/list_collections/mod.rs index 1b4b5d3bb..3e2141a68 100644 --- a/src/operation/list_collections/mod.rs +++ b/src/operation/list_collections/mod.rs @@ -1,24 +1,31 @@ #[cfg(test)] mod test; +use std::marker::PhantomData; + +use serde::de::DeserializeOwned; + use crate::{ bson::{doc, Document}, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, cursor::CursorSpecification, error::Result, operation::{append_options, CursorBody, Operation, Retryability}, options::{ListCollectionsOptions, ReadPreference, SelectionCriteria}, }; +use super::CursorResponse; + #[derive(Debug)] -pub(crate) struct ListCollections { +pub(crate) struct ListCollections { db: String, filter: Option, name_only: bool, options: Option, + _phantom: PhantomData, } -impl ListCollections { +impl ListCollections { #[cfg(test)] fn empty() -> Self { Self::new(String::new(), None, false, None) @@ -35,12 +42,18 @@ impl ListCollections { filter, name_only, options, + _phantom: PhantomData::default(), } } } -impl Operation for ListCollections { - type O = CursorSpecification; +impl Operation for ListCollections +where + T: DeserializeOwned + Unpin + Send + Sync, +{ + type O = CursorSpecification; + type Response = CursorResponse; + const NAME: &'static str = "listCollections"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -65,15 +78,12 @@ impl Operation for ListCollections { fn handle_response( &self, - response: CommandResponse, - _description: &StreamDescription, + response: CursorBody, + description: &StreamDescription, ) -> Result { - let source_address = response.source_address().clone(); - let body: CursorBody = response.body()?; - Ok(CursorSpecification::new( - body.cursor, - source_address, + response.cursor, + description.server_address.clone(), self.options.as_ref().and_then(|opts| opts.batch_size), None, )) diff --git a/src/operation/list_collections/test.rs b/src/operation/list_collections/test.rs index 237c1432c..98cc10903 100644 --- a/src/operation/list_collections/test.rs +++ b/src/operation/list_collections/test.rs @@ -1,13 +1,17 @@ use crate::{ bson::{doc, Document}, bson_util, - cmap::{CommandResponse, StreamDescription}, - operation::{ListCollections, Operation}, + cmap::StreamDescription, + operation::{test::handle_response_test, ListCollections, Operation}, options::{ListCollectionsOptions, ServerAddress}, Namespace, }; -fn build_test(db_name: &str, mut list_collections: ListCollections, mut expected_body: Document) { +fn build_test( + db_name: &str, + mut list_collections: ListCollections, + mut expected_body: Document, +) { let mut cmd = list_collections .build(&StreamDescription::new_testing()) .expect("build should succeed"); @@ -124,7 +128,7 @@ async fn build_batch_size() { #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn op_selection_criteria() { - assert!(ListCollections::empty() + assert!(ListCollections::::empty() .selection_criteria() .expect("should have criteria") .is_read_pref_primary()); @@ -169,13 +173,8 @@ async fn handle_success() { "ok": 1.0 }; - let cursor_spec = list_collections - .handle_response( - CommandResponse::with_document_and_address(ServerAddress::default(), response.clone()), - &Default::default(), - ) - .expect("handle should succeed"); - + let cursor_spec = + handle_response_test(&list_collections, response.clone()).expect("handle should succeed"); assert_eq!(cursor_spec.address(), &ServerAddress::default()); assert_eq!(cursor_spec.id(), 123); assert_eq!(cursor_spec.batch_size(), None); @@ -194,13 +193,9 @@ async fn handle_success() { false, Some(ListCollectionsOptions::builder().batch_size(123).build()), ); - let cursor_spec = list_collections - .handle_response( - CommandResponse::with_document_and_address(ServerAddress::default(), response), - &Default::default(), - ) - .expect("handle should succeed"); + let cursor_spec = + handle_response_test(&list_collections, response).expect("handle should succeed"); assert_eq!(cursor_spec.address(), &ServerAddress::default()); assert_eq!(cursor_spec.id(), 123); assert_eq!(cursor_spec.batch_size(), Some(123)); @@ -214,15 +209,52 @@ async fn handle_success() { ); } +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn handle_success_name_only() { + let ns = Namespace { + db: "test_db".to_string(), + coll: "test_coll".to_string(), + }; + + let list_collections = ListCollections::new("test_db".to_string(), None, false, None); + + let first_batch = vec![doc! { + "name" : "test", + "type" : "collection", + }]; + + let response = doc! { + "cursor": { + "id": 123, + "ns": format!("{}.{}", ns.db, ns.coll), + "firstBatch": bson_util::to_bson_array(&first_batch), + }, + "ok": 1.0 + }; + + let cursor_spec = + handle_response_test(&list_collections, response).expect("handle should succeed"); + assert_eq!(cursor_spec.address(), &ServerAddress::default()); + assert_eq!(cursor_spec.id(), 123); + assert_eq!(cursor_spec.batch_size(), None); + assert_eq!(cursor_spec.max_time(), None); + assert_eq!( + cursor_spec + .initial_buffer + .into_iter() + .collect::>(), + first_batch + ); +} + #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_invalid_response() { - let list_collections = ListCollections::empty(); + let list_collections = ListCollections::::empty(); let garbled = doc! { "asdfasf": "ASdfasdf" }; - assert!(list_collections - .handle_response(CommandResponse::with_document(garbled), &Default::default()) - .is_err()); + handle_response_test(&list_collections, garbled).expect_err("garbled response should fail"); let missing_cursor_field = doc! { "cursor": { @@ -230,10 +262,6 @@ async fn handle_invalid_response() { "firstBatch": [], } }; - assert!(list_collections - .handle_response( - CommandResponse::with_document(missing_cursor_field), - &Default::default() - ) - .is_err()); + handle_response_test(&list_collections, missing_cursor_field) + .expect_err("missing cursor field should fail"); } diff --git a/src/operation/list_databases/mod.rs b/src/operation/list_databases/mod.rs index 093224e9e..356a59e11 100644 --- a/src/operation/list_databases/mod.rs +++ b/src/operation/list_databases/mod.rs @@ -5,13 +5,15 @@ use serde::Deserialize; use crate::{ bson::{doc, Document}, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, error::Result, operation::{append_options, Operation, Retryability}, options::ListDatabasesOptions, selection_criteria::{ReadPreference, SelectionCriteria}, }; +use super::CommandResponse; + #[derive(Debug)] pub(crate) struct ListDatabases { filter: Option, @@ -44,6 +46,8 @@ impl ListDatabases { impl Operation for ListDatabases { type O = Vec; + type Response = CommandResponse; + const NAME: &'static str = "listDatabases"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -67,10 +71,10 @@ impl Operation for ListDatabases { fn handle_response( &self, - response: CommandResponse, + response: Response, _description: &StreamDescription, ) -> Result { - response.body::().map(|body| body.databases) + Ok(response.databases) } fn selection_criteria(&self) -> Option<&SelectionCriteria> { @@ -83,6 +87,6 @@ impl Operation for ListDatabases { } #[derive(Debug, Deserialize)] -struct ResponseBody { +pub(crate) struct Response { databases: Vec, } diff --git a/src/operation/list_databases/test.rs b/src/operation/list_databases/test.rs index aa89ae4c7..1d0895b11 100644 --- a/src/operation/list_databases/test.rs +++ b/src/operation/list_databases/test.rs @@ -1,9 +1,9 @@ use crate::{ bson::{doc, Bson, Document}, bson_util, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, error::ErrorKind, - operation::{ListDatabases, Operation}, + operation::{test::handle_response_test, ListDatabases, Operation}, options::ListDatabasesOptions, selection_criteria::ReadPreference, }; @@ -90,7 +90,6 @@ async fn build_with_options() { #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_success() { - let list_databases_op = ListDatabases::empty(); let total_size = 251658240; let databases: Vec = vec![ @@ -111,31 +110,23 @@ async fn handle_success() { }, ]; - let expected_values: Vec = databases.clone(); - - let response = CommandResponse::with_document(doc! { - "databases" : bson_util::to_bson_array(&databases), - "totalSize" : total_size, - "ok" : 1 - }); - - let actual_values = list_databases_op - .handle_response(response, &Default::default()) - .expect("supposed to succeed"); + let actual_values = handle_response_test( + &ListDatabases::empty(), + doc! { + "databases" : bson_util::to_bson_array(&databases), + "totalSize" : total_size, + "ok" : 1 + }, + ) + .expect("should succeed"); - assert_eq!(actual_values, expected_values); + assert_eq!(actual_values, databases); } #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_response_no_databases() { - let list_databases_op = ListDatabases::empty(); - - let response = CommandResponse::with_document(doc! { - "ok" : 1 - }); - - let result = list_databases_op.handle_response(response, &Default::default()); + let result = handle_response_test(&ListDatabases::empty(), doc! { "ok": 1 }); match result.map_err(|e| *e.kind) { Err(ErrorKind::InvalidResponse { .. }) => {} other => panic!("expected response error, but got {:?}", other), diff --git a/src/operation/mod.rs b/src/operation/mod.rs index 9be2fb2ae..ec530a5b2 100644 --- a/src/operation/mod.rs +++ b/src/operation/mod.rs @@ -17,16 +17,23 @@ mod list_databases; mod run_command; mod update; +#[cfg(test)] +mod test; + use std::{collections::VecDeque, fmt::Debug, ops::Deref}; -use serde::{Deserialize, Serialize}; +use bson::Timestamp; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::{ bson::{self, Bson, Document}, - cmap::{Command, CommandResponse, StreamDescription}, + bson_util, + client::ClusterTime, + cmap::{Command, RawCommandResponse, StreamDescription}, error::{ BulkWriteError, BulkWriteFailure, + CommandError, Error, ErrorKind, Result, @@ -62,6 +69,9 @@ pub(crate) trait Operation { /// The output type of this operation. type O; + /// The format of the command response from the server. + type Response: Response; + /// The name of the server side command associated with this operation. const NAME: &'static str; @@ -72,7 +82,7 @@ pub(crate) trait Operation { /// Interprets the server response to the command. fn handle_response( &self, - response: CommandResponse, + response: ::Body, description: &StreamDescription, ) -> Result; @@ -109,7 +119,7 @@ pub(crate) trait Operation { Retryability::None } - // Updates this operation as needed for a retry. + /// Updates this operation as needed for a retry. fn update_for_retry(&mut self) {} fn name(&self) -> &str { @@ -117,6 +127,132 @@ pub(crate) trait Operation { } } +/// Trait modeling the behavior of a command response to a server operation. +pub(crate) trait Response: Sized { + /// The command-specific portion of a command response. + /// This type will be passed to the associated operation's `handle_response` method. + type Body; + + /// Deserialize a response from the given raw response. + fn deserialize_response(raw: &RawCommandResponse) -> Result; + + /// The `ok` field of the response. + fn ok(&self) -> Option<&Bson>; + + /// Whether the command succeeeded or not (i.e. if this response is ok: 1). + fn is_success(&self) -> bool { + match self.ok() { + Some(b) => bson_util::get_int(&b) == Some(1), + None => false, + } + } + + /// The `clusterTime` field of the response. + fn cluster_time(&self) -> Option<&ClusterTime>; + + /// The `atClusterTime` field of the response. + fn at_cluster_time(&self) -> Option; + + /// Convert into the body of the response. + fn into_body(self) -> Self::Body; +} + +/// A response to a command with a body shaped deserialized to a `T`. +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub(crate) struct CommandResponse { + pub(crate) ok: Bson, + + #[serde(rename = "$clusterTime")] + pub(crate) cluster_time: Option, + + pub(crate) at_cluster_time: Option, + + #[serde(flatten)] + pub(crate) body: T, +} + +impl CommandResponse { + pub(crate) fn is_success(&self) -> bool { + ::is_success(self) + } +} + +impl Response for CommandResponse { + type Body = T; + + fn deserialize_response(raw: &RawCommandResponse) -> Result { + raw.body() + } + + fn ok(&self) -> Option<&Bson> { + Some(&self.ok) + } + + fn cluster_time(&self) -> Option<&ClusterTime> { + self.cluster_time.as_ref() + } + + fn at_cluster_time(&self) -> Option { + self.at_cluster_time + } + + fn into_body(self) -> Self::Body { + self.body + } +} + +/// A response to commands that return cursors. +#[derive(Debug)] +pub(crate) struct CursorResponse { + response: CommandResponse>, +} + +impl Response for CursorResponse { + type Body = CursorBody; + + fn deserialize_response(raw: &RawCommandResponse) -> Result { + Ok(Self { + response: raw.body()?, + }) + } + + fn ok(&self) -> Option<&Bson> { + self.response.ok() + } + + fn cluster_time(&self) -> Option<&ClusterTime> { + self.response.cluster_time() + } + + fn at_cluster_time(&self) -> Option { + self.response.body.cursor.at_cluster_time + } + + fn into_body(self) -> Self::Body { + self.response.body + } +} + +/// A response body useful for deserializing command errors. +#[derive(Deserialize, Debug)] +pub(crate) struct CommandErrorBody { + #[serde(rename = "errorLabels")] + pub(crate) error_labels: Option>, + + #[serde(flatten)] + pub(crate) command_error: CommandError, +} + +impl From for Error { + fn from(command_error_response: CommandErrorBody) -> Error { + Error::new( + ErrorKind::Command(command_error_response.command_error), + command_error_response.error_labels, + ) + } +} + /// Appends a serializable struct to the input document. /// The serializable struct MUST serialize to a Document, otherwise an error will be thrown. pub(crate) fn append_options( @@ -142,11 +278,11 @@ pub(crate) fn append_options( } #[derive(Deserialize, Debug)] -struct EmptyBody {} +pub(crate) struct EmptyBody {} /// Body of a write response that could possibly have a write concern error but not write errors. -#[derive(Deserialize)] -struct WriteConcernOnlyBody { +#[derive(Debug, Deserialize, Default, Clone)] +pub(crate) struct WriteConcernOnlyBody { #[serde(rename = "writeConcernError")] write_concern_error: Option, @@ -167,7 +303,7 @@ impl WriteConcernOnlyBody { } #[derive(Deserialize, Debug)] -struct WriteResponseBody { +pub(crate) struct WriteResponseBody { #[serde(flatten)] body: T, @@ -211,17 +347,24 @@ impl Deref for WriteResponseBody { } #[derive(Debug, Deserialize)] -struct CursorBody { - cursor: CursorInfo, +pub(crate) struct CursorBody { + cursor: CursorInfo, + + #[serde(flatten)] + write_concern_info: WriteConcernOnlyBody, } -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct CursorInfo { +#[derive(Debug, Deserialize, Clone)] +pub(crate) struct CursorInfo { pub(crate) id: i64, + pub(crate) ns: Namespace, + #[serde(rename = "firstBatch")] - pub(crate) first_batch: VecDeque, + pub(crate) first_batch: VecDeque, + + #[serde(rename = "atClusterTime")] + pub(crate) at_cluster_time: Option, } #[derive(Debug, PartialEq)] @@ -230,28 +373,3 @@ pub(crate) enum Retryability { Read, None, } - -#[cfg(test)] -mod test { - use crate::{ - operation::Operation, - options::{ReadPreference, SelectionCriteria}, - }; - - pub(crate) fn op_selection_criteria(constructor: F) - where - T: Operation, - F: Fn(Option) -> T, - { - let op = constructor(None); - assert_eq!(op.selection_criteria(), None); - - let read_pref: SelectionCriteria = ReadPreference::Secondary { - options: Default::default(), - } - .into(); - - let op = constructor(Some(read_pref.clone())); - assert_eq!(op.selection_criteria(), Some(&read_pref)); - } -} diff --git a/src/operation/run_command/mod.rs b/src/operation/run_command/mod.rs index a602a6e55..fbe8a72dc 100644 --- a/src/operation/run_command/mod.rs +++ b/src/operation/run_command/mod.rs @@ -1,11 +1,13 @@ #[cfg(test)] mod test; +use bson::{Bson, Timestamp}; + use super::Operation; use crate::{ bson::Document, - client::SESSIONS_UNSUPPORTED_COMMANDS, - cmap::{Command, CommandResponse, StreamDescription}, + client::{ClusterTime, SESSIONS_UNSUPPORTED_COMMANDS}, + cmap::{Command, RawCommandResponse, StreamDescription}, error::{ErrorKind, Result}, options::WriteConcern, selection_criteria::SelectionCriteria, @@ -45,6 +47,7 @@ impl RunCommand { impl Operation for RunCommand { type O = Document; + type Response = Response; // Since we can't actually specify a string statically here, we just put a descriptive string // that should fail loudly if accidentally passed to the server. @@ -66,10 +69,10 @@ impl Operation for RunCommand { fn handle_response( &self, - response: CommandResponse, + response: Document, _description: &StreamDescription, ) -> Result { - Ok(response.raw_response) + Ok(response) } fn selection_criteria(&self) -> Option<&SelectionCriteria> { @@ -88,3 +91,47 @@ impl Operation for RunCommand { .unwrap_or(false) } } + +#[derive(Debug)] +pub(crate) struct Response { + doc: Document, + cluster_time: Option, +} + +impl super::Response for Response { + type Body = Document; + + fn deserialize_response(raw: &RawCommandResponse) -> Result { + let doc: Document = raw.body()?; + + let cluster_time = doc + .get_document("$clusterTime") + .ok() + .and_then(|doc| bson::from_document(doc.clone()).ok()); + + Ok(Self { doc, cluster_time }) + } + + fn ok(&self) -> Option<&Bson> { + self.doc.get("ok") + } + + fn cluster_time(&self) -> Option<&ClusterTime> { + self.cluster_time.as_ref() + } + + fn at_cluster_time(&self) -> Option { + self.doc + .get_timestamp("atClusterTime") + .or_else(|_| { + self.doc + .get_document("cursor") + .and_then(|subdoc| subdoc.get_timestamp("atClusterTime")) + }) + .ok() + } + + fn into_body(self) -> Self::Body { + self.doc + } +} diff --git a/src/operation/run_command/test.rs b/src/operation/run_command/test.rs index a65b29b60..082ef3367 100644 --- a/src/operation/run_command/test.rs +++ b/src/operation/run_command/test.rs @@ -1,8 +1,11 @@ +use bson::Timestamp; + use super::RunCommand; use crate::{ bson::doc, - cmap::{CommandResponse, StreamDescription}, - operation::Operation, + client::ClusterTime, + cmap::{RawCommandResponse, StreamDescription}, + operation::{test::handle_response_test, Operation, Response}, }; #[cfg_attr(feature = "tokio-runtime", tokio::test)] @@ -26,17 +29,50 @@ async fn build() { #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] -async fn no_error_ok_0() { - let op = RunCommand::new("foo".into(), doc! { "isMaster": 1 }, None).unwrap(); - assert!(op.selection_criteria().is_none()); +async fn handle_success() { + let op = RunCommand::new("foo".into(), doc! { "hello": 1 }, None).unwrap(); + + let doc = doc! { + "ok": 1, + "some": "field", + "other": true, + "$clusterTime": { + "clusterTime": Timestamp { + time: 123, + increment: 345, + }, + "signature": {} + } + }; + let result_doc = handle_response_test(&op, doc.clone()).unwrap(); + assert_eq!(result_doc, doc); +} - let command_response = CommandResponse::with_document(doc! { - "ok": 0 - }); +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn response() { + let cluster_timestamp = Timestamp { + time: 123, + increment: 345, + }; + let doc = doc! { + "ok": 1, + "some": "field", + "other": true, + "$clusterTime": { + "clusterTime": cluster_timestamp, + "signature": {} + } + }; + let raw = RawCommandResponse::with_document(doc).unwrap(); + let response = ::Response::deserialize_response(&raw).unwrap(); + assert!(response.is_success()); assert_eq!( - op.handle_response(command_response, &Default::default()) - .ok(), - Some(doc! { "ok": 0 }) + response.cluster_time(), + Some(&ClusterTime { + cluster_time: cluster_timestamp, + signature: doc! {}, + }) ); } diff --git a/src/operation/test.rs b/src/operation/test.rs new file mode 100644 index 000000000..5f4b7d22c --- /dev/null +++ b/src/operation/test.rs @@ -0,0 +1,168 @@ +use bson::{doc, Document, Timestamp}; +use serde::Deserialize; + +use crate::{ + client::ClusterTime, + cmap::{RawCommandResponse, StreamDescription}, + error::{Result, TRANSIENT_TRANSACTION_ERROR}, + operation::{CommandErrorBody, CommandResponse, Operation, Response}, + options::{ReadPreference, SelectionCriteria}, +}; + +pub(crate) fn handle_response_test(op: &T, response_doc: Document) -> Result { + let raw = RawCommandResponse::with_document(response_doc).unwrap(); + let response = T::Response::deserialize_response(&raw)?; + op.handle_response(response.into_body(), &StreamDescription::new_testing()) +} + +pub(crate) fn handle_response_test_with_wire_version( + op: &T, + response_doc: Document, + wire_version: i32, +) -> Result { + let raw = RawCommandResponse::with_document(response_doc).unwrap(); + let response = T::Response::deserialize_response(&raw)?; + op.handle_response( + response.into_body(), + &StreamDescription::with_wire_version(wire_version), + ) +} + +pub(crate) fn op_selection_criteria(constructor: F) +where + T: Operation, + F: Fn(Option) -> T, +{ + let op = constructor(None); + assert_eq!(op.selection_criteria(), None); + + let read_pref: SelectionCriteria = ReadPreference::Secondary { + options: Default::default(), + } + .into(); + + let op = constructor(Some(read_pref.clone())); + assert_eq!(op.selection_criteria(), Some(&read_pref)); +} + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn response_success() { + let cluster_timestamp = Timestamp { + time: 123, + increment: 345, + }; + let doc = doc! { + "ok": 1, + "some": "field", + "other": true, + "$clusterTime": { + "clusterTime": cluster_timestamp, + "signature": {} + } + }; + let raw = RawCommandResponse::with_document(doc.clone()).unwrap(); + let response = CommandResponse::::deserialize_response(&raw).unwrap(); + + assert!(response.is_success()); + assert_eq!( + response.cluster_time(), + Some(&ClusterTime { + cluster_time: cluster_timestamp, + signature: doc! {}, + }) + ); + assert_eq!( + response.into_body(), + doc! { "some": "field", "other": true } + ); + + #[derive(Deserialize, Debug, PartialEq)] + struct Body { + some: String, + #[serde(rename = "other")] + o: bool, + #[serde(default)] + default: Option, + } + + let raw = RawCommandResponse::with_document(doc).unwrap(); + let response = CommandResponse::::deserialize_response(&raw).unwrap(); + + assert!(response.is_success()); + assert_eq!( + response.cluster_time(), + Some(&ClusterTime { + cluster_time: cluster_timestamp, + signature: doc! {}, + }) + ); + assert_eq!( + response.into_body(), + Body { + some: "field".to_string(), + o: true, + default: None, + } + ); +} + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn response_failure() { + let cluster_timestamp = Timestamp { + time: 123, + increment: 345, + }; + let doc = doc! { + "ok": 0, + "code": 123, + "codeName": "name", + "errmsg": "some message", + "errorLabels": [TRANSIENT_TRANSACTION_ERROR], + "$clusterTime": { + "clusterTime": cluster_timestamp, + "signature": {} + } + }; + let raw = RawCommandResponse::with_document(doc.clone()).unwrap(); + let response = CommandResponse::::deserialize_response(&raw).unwrap(); + + assert!(!response.is_success()); + assert_eq!( + response.cluster_time(), + Some(&ClusterTime { + cluster_time: cluster_timestamp, + signature: doc! {}, + }) + ); + assert_eq!( + response.into_body(), + doc! { + "code": 123, + "codeName": "name", + "errmsg": "some message", + "errorLabels": [TRANSIENT_TRANSACTION_ERROR], + } + ); + + let raw = RawCommandResponse::with_document(doc).unwrap(); + let response = CommandResponse::::deserialize_response(&raw).unwrap(); + + assert!(!response.is_success()); + assert_eq!( + response.cluster_time(), + Some(&ClusterTime { + cluster_time: cluster_timestamp, + signature: doc! {}, + }) + ); + let command_error: CommandErrorBody = response.into_body(); + assert_eq!(command_error.command_error.code, 123); + assert_eq!(command_error.command_error.code_name, "name"); + assert_eq!(command_error.command_error.message, "some message"); + assert_eq!( + command_error.error_labels, + Some(vec![TRANSIENT_TRANSACTION_ERROR.to_string()]) + ); +} diff --git a/src/operation/update/mod.rs b/src/operation/update/mod.rs index 882f1026c..ebe0bd961 100644 --- a/src/operation/update/mod.rs +++ b/src/operation/update/mod.rs @@ -6,7 +6,7 @@ use serde::Deserialize; use crate::{ bson::{doc, Bson, Document}, bson_util, - cmap::{Command, CommandResponse, StreamDescription}, + cmap::{Command, StreamDescription}, error::{convert_bulk_errors, Result}, operation::{Operation, Retryability, WriteResponseBody}, options::{UpdateModifications, UpdateOptions, WriteConcern}, @@ -14,6 +14,8 @@ use crate::{ Namespace, }; +use super::CommandResponse; + #[derive(Debug)] pub(crate) struct Update { ns: Namespace, @@ -57,6 +59,8 @@ impl Update { impl Operation for Update { type O = UpdateResult; + type Response = CommandResponse>; + const NAME: &'static str = "update"; fn build(&mut self, _description: &StreamDescription) -> Result { @@ -111,21 +115,20 @@ impl Operation for Update { fn handle_response( &self, - response: CommandResponse, + response: WriteResponseBody, _description: &StreamDescription, ) -> Result { - let body: WriteResponseBody = response.body()?; - body.validate().map_err(convert_bulk_errors)?; + response.validate().map_err(convert_bulk_errors)?; - let modified_count = body.n_modified; - let upserted_id = body + let modified_count = response.n_modified; + let upserted_id = response .upserted .as_ref() .and_then(|v| v.first()) .and_then(|doc| doc.get("_id")) .map(Clone::clone); - let matched_count = if upserted_id.is_some() { 0 } else { body.n }; + let matched_count = if upserted_id.is_some() { 0 } else { response.n }; Ok(UpdateResult { matched_count, @@ -150,7 +153,7 @@ impl Operation for Update { } #[derive(Deserialize)] -struct UpdateBody { +pub(crate) struct UpdateBody { #[serde(rename = "nModified")] n_modified: u64, upserted: Option>, diff --git a/src/operation/update/test.rs b/src/operation/update/test.rs index 6142cfdee..6919f0cd1 100644 --- a/src/operation/update/test.rs +++ b/src/operation/update/test.rs @@ -3,11 +3,11 @@ use pretty_assertions::assert_eq; use crate::{ bson::{doc, Bson}, bson_util, - cmap::{CommandResponse, StreamDescription}, + cmap::StreamDescription, coll::options::Hint, concern::{Acknowledgment, WriteConcern}, error::{ErrorKind, WriteConcernError, WriteError, WriteFailure}, - operation::{Operation, Update}, + operation::{test::handle_response_test, Operation, Update}, options::{UpdateModifications, UpdateOptions}, Namespace, }; @@ -158,19 +158,16 @@ async fn build_many() { async fn handle_success() { let op = Update::empty(); - let ok_response = CommandResponse::with_document(doc! { + let ok_response = doc! { "ok": 1.0, "n": 3, "nModified": 1, "upserted": [ { "index": 0, "_id": 1 } ] - }); - - let ok_result = op.handle_response(ok_response, &Default::default()); - assert!(ok_result.is_ok()); + }; - let update_result = ok_result.unwrap(); + let update_result = handle_response_test(&op, ok_response).unwrap(); assert_eq!(update_result.matched_count, 0); assert_eq!(update_result.modified_count, 1); assert_eq!(update_result.upserted_id, Some(Bson::Int32(1))); @@ -181,38 +178,24 @@ async fn handle_success() { async fn handle_success_no_upsert() { let op = Update::empty(); - let ok_response = CommandResponse::with_document(doc! { + let ok_response = doc! { "ok": 1.0, "n": 5, "nModified": 2 - }); - - let ok_result = op.handle_response(ok_response, &Default::default()); - assert!(ok_result.is_ok()); + }; - let update_result = ok_result.unwrap(); + let update_result = handle_response_test(&op, ok_response).unwrap(); assert_eq!(update_result.matched_count, 5); assert_eq!(update_result.modified_count, 2); assert_eq!(update_result.upserted_id, None); } -#[cfg_attr(feature = "tokio-runtime", tokio::test)] -#[cfg_attr(feature = "async-std-runtime", async_std::test)] -async fn handle_invalid_response() { - let op = Update::empty(); - - let invalid_response = CommandResponse::with_document(doc! { "ok": 1.0, "asdfadsf": 123123 }); - assert!(op - .handle_response(invalid_response, &Default::default()) - .is_err()); -} - #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn handle_write_failure() { let op = Update::empty(); - let write_error_response = CommandResponse::with_document(doc! { + let write_error_response = doc! { "ok": 1.0, "n": 12, "nModified": 0, @@ -223,10 +206,10 @@ async fn handle_write_failure() { "errmsg": "my error string" } ] - }); - let write_error_result = op.handle_response(write_error_response, &Default::default()); - assert!(write_error_result.is_err()); - match *write_error_result.unwrap_err().kind { + }; + + let write_error = handle_response_test(&op, write_error_response).unwrap_err(); + match *write_error.kind { ErrorKind::Write(WriteFailure::WriteError(ref error)) => { let expected_err = WriteError { code: 1234, @@ -245,7 +228,7 @@ async fn handle_write_failure() { async fn handle_write_concern_failure() { let op = Update::empty(); - let wc_error_response = CommandResponse::with_document(doc! { + let wc_error_response = doc! { "ok": 1.0, "n": 0, "nModified": 0, @@ -261,12 +244,10 @@ async fn handle_write_concern_failure() { } } } - }); - - let wc_error_result = op.handle_response(wc_error_response, &Default::default()); - assert!(wc_error_result.is_err()); + }; - match *wc_error_result.unwrap_err().kind { + let wc_error = handle_response_test(&op, wc_error_response).unwrap_err(); + match *wc_error.kind { ErrorKind::Write(WriteFailure::WriteConcernError(ref wc_error)) => { let expected_wc_err = WriteConcernError { code: 456, diff --git a/src/results.rs b/src/results.rs index e2d7292d9..2521e3181 100644 --- a/src/results.rs +++ b/src/results.rs @@ -76,8 +76,8 @@ pub struct DeleteResult { } #[derive(Debug, Clone)] -pub(crate) struct GetMoreResult { - pub(crate) batch: VecDeque, +pub(crate) struct GetMoreResult { + pub(crate) batch: VecDeque, pub(crate) exhausted: bool, } diff --git a/src/runtime/async_read_ext.rs b/src/runtime/async_read_ext.rs index 05d05705f..100c7bb08 100644 --- a/src/runtime/async_read_ext.rs +++ b/src/runtime/async_read_ext.rs @@ -1,3 +1,5 @@ +use std::io::Read; + use async_trait::async_trait; use futures_io::AsyncRead; @@ -28,3 +30,27 @@ pub(crate) trait AsyncLittleEndianRead: Unpin + futures_util::AsyncReadExt { } impl AsyncLittleEndianRead for R {} + +pub(crate) trait SyncLittleEndianRead: Read { + /// Read an `i32` in little-endian order. + fn read_i32(&mut self) -> Result { + let mut buf: [u8; 4] = [0; 4]; + self.read_exact(&mut buf)?; + Ok(i32::from_le_bytes(buf)) + } + + /// Read a `u32` in little-endian orer. + fn read_u32(&mut self) -> Result { + let mut buf: [u8; 4] = [0; 4]; + self.read_exact(&mut buf)?; + Ok(u32::from_le_bytes(buf)) + } + + fn read_u8(&mut self) -> Result { + let mut buf: [u8; 1] = [0; 1]; + self.read_exact(&mut buf)?; + Ok(buf[0]) + } +} + +impl SyncLittleEndianRead for R {} diff --git a/src/runtime/async_write_ext.rs b/src/runtime/async_write_ext.rs index 216a0f006..af5e0aea8 100644 --- a/src/runtime/async_write_ext.rs +++ b/src/runtime/async_write_ext.rs @@ -1,3 +1,5 @@ +use std::io::Write; + use async_trait::async_trait; use futures_io::AsyncWrite; @@ -25,3 +27,25 @@ pub(crate) trait AsyncLittleEndianWrite: Unpin + futures_util::AsyncWriteExt { } impl AsyncLittleEndianWrite for W {} + +/// Trait providing helpers that write various integer types in little-endian order. +pub(crate) trait SyncLittleEndianWrite: Write { + /// Write an `i32` in little-endian order. + fn write_i32(&mut self, n: i32) -> Result<()> { + self.write_all(&n.to_le_bytes())?; + Ok(()) + } + + /// Write a `u32` in little-endian order. + fn write_u32(&mut self, n: u32) -> Result<()> { + self.write_all(&n.to_le_bytes())?; + Ok(()) + } + + fn write_u8(&mut self, n: u8) -> Result<()> { + self.write_all(&[n])?; + Ok(()) + } +} + +impl SyncLittleEndianWrite for W {} diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 733934585..3a22ce7e7 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -12,8 +12,8 @@ use std::{future::Future, net::SocketAddr, time::Duration}; pub(crate) use self::{ acknowledged_message::AcknowledgedMessage, - async_read_ext::AsyncLittleEndianRead, - async_write_ext::AsyncLittleEndianWrite, + async_read_ext::{AsyncLittleEndianRead, SyncLittleEndianRead}, + async_write_ext::{AsyncLittleEndianWrite, SyncLittleEndianWrite}, join_handle::AsyncJoinHandle, resolver::AsyncResolver, stream::AsyncStream, diff --git a/src/sdam/description/topology/server_selection/test/mod.rs b/src/sdam/description/topology/server_selection/test/mod.rs index fcc563a25..e77932f4c 100644 --- a/src/sdam/description/topology/server_selection/test/mod.rs +++ b/src/sdam/description/topology/server_selection/test/mod.rs @@ -86,22 +86,24 @@ impl TestServerDescription { None => return None, }; - let mut command_response = is_master_response_from_server_type(server_type); - command_response.tags = self.tags; - command_response.last_write = self.last_write.map(|last_write| LastWrite { - last_write_date: DateTime::from_millis(last_write.last_write_date), + let server_address = ServerAddress::parse(self.address).ok()?; + let tags = self.tags; + let last_write = self.last_write; + let avg_rtt_ms = self.avg_rtt_ms; + let reply = is_master_response_from_server_type(server_type).map(|mut command_response| { + command_response.tags = tags; + command_response.last_write = last_write.map(|last_write| LastWrite { + last_write_date: DateTime::from_millis(last_write.last_write_date), + }); + Ok(IsMasterReply { + server_address: server_address.clone(), + command_response, + round_trip_time: avg_rtt_ms.map(f64_ms_as_duration), + cluster_time: None, + }) }); - let is_master = IsMasterReply { - command_response, - round_trip_time: self.avg_rtt_ms.map(f64_ms_as_duration), - cluster_time: None, - }; - - let mut server_desc = ServerDescription::new( - ServerAddress::parse(&self.address).unwrap(), - Some(Ok(is_master)), - ); + let mut server_desc = ServerDescription::new(server_address, reply); server_desc.last_update_time = self .last_update_time .map(|i| DateTime::from_millis(i.into())); @@ -150,45 +152,37 @@ impl TestServerType { } } -fn is_master_response_from_server_type(server_type: ServerType) -> IsMasterCommandResponse { +fn is_master_response_from_server_type(server_type: ServerType) -> Option { let mut response = IsMasterCommandResponse::default(); match server_type { ServerType::Unknown => { - response.ok = Some(0.0); + return None; } ServerType::Mongos => { - response.ok = Some(1.0); response.msg = Some("isdbgrid".into()); } ServerType::RsPrimary => { - response.ok = Some(1.0); response.set_name = Some("foo".into()); response.is_writable_primary = Some(true); } ServerType::RsOther => { - response.ok = Some(1.0); response.set_name = Some("foo".into()); response.hidden = Some(true); } ServerType::RsSecondary => { - response.ok = Some(1.0); response.set_name = Some("foo".into()); response.secondary = Some(true); } ServerType::RsArbiter => { - response.ok = Some(1.0); response.set_name = Some("foo".into()); response.arbiter_only = Some(true); } ServerType::RsGhost => { - response.ok = Some(1.0); response.is_replica_set = Some(true); } - ServerType::Standalone => { - response.ok = Some(1.0); - } + ServerType::Standalone => {} }; - response + Some(response) } diff --git a/src/sdam/description/topology/test/sdam.rs b/src/sdam/description/topology/test/sdam.rs index e20889b93..0497d516d 100644 --- a/src/sdam/description/topology/test/sdam.rs +++ b/src/sdam/description/topology/test/sdam.rs @@ -43,7 +43,7 @@ pub struct Phase { #[derive(Debug, Deserialize)] pub struct Response(String, TestIsMasterCommandResponse); -#[derive(Debug, Clone, Default, Deserialize)] +#[derive(Debug, Clone, Default, Deserialize, PartialEq)] #[serde(rename_all = "camelCase")] pub(crate) struct TestIsMasterCommandResponse { pub is_writable_primary: Option, @@ -80,7 +80,6 @@ impl From for IsMasterCommandResponse { IsMasterCommandResponse { is_writable_primary: test.is_writable_primary, is_master: test.is_master, - ok: test.ok, hosts: test.hosts, passives: test.passives, arbiters: test.arbiters, @@ -107,12 +106,6 @@ impl From for IsMasterCommandResponse { } } -impl PartialEq for TestIsMasterCommandResponse { - fn eq(&self, other: &Self) -> bool { - IsMasterCommandResponse::from(self.clone()) == other.clone().into() - } -} - #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ApplicationError { @@ -227,16 +220,6 @@ async fn run_test(test_file: TestFile) { for (i, phase) in test_file.phases.into_iter().enumerate() { for Response(address, command_response) in phase.responses { - let is_master_reply = if command_response == Default::default() { - Err("dummy error".to_string()) - } else { - Ok(IsMasterReply { - command_response: command_response.into(), - round_trip_time: Some(Duration::from_millis(1234)), // Doesn't matter for tests. - cluster_time: None, - }) - }; - let address = ServerAddress::parse(&address).unwrap_or_else(|_| { panic!( "{}: couldn't parse address \"{:?}\"", @@ -245,11 +228,37 @@ async fn run_test(test_file: TestFile) { ) }); + let is_master_reply = if command_response.ok != Some(1.0) { + Err(Error::from(ErrorKind::Command(CommandError { + code: 1234, + code_name: "dummy error".to_string(), + message: "dummy".to_string(), + }))) + } else if command_response == Default::default() { + Err(Error::from(ErrorKind::Io(Arc::new( + std::io::ErrorKind::BrokenPipe.into(), + )))) + } else { + Ok(IsMasterReply { + server_address: address.clone(), + command_response: command_response.into(), + round_trip_time: Some(Duration::from_millis(1234)), // Doesn't matter for tests. + cluster_time: None, + }) + }; + // only update server if we have strong reference to it like the monitors do if let Some(server) = servers.get(&address).and_then(|s| s.upgrade()) { - let new_sd = ServerDescription::new(address.clone(), Some(is_master_reply)); - if topology.update(&server, new_sd).await { - servers = topology.get_servers().await + match is_master_reply { + Ok(reply) => { + let new_sd = ServerDescription::new(address.clone(), Some(Ok(reply))); + if topology.update(&server, new_sd).await { + servers = topology.get_servers().await + } + } + Err(e) => { + topology.handle_monitor_error(e, &server).await; + } } } } @@ -528,6 +537,7 @@ async fn pool_cleared_error_does_not_mark_unknown() { ServerDescription::new( address.clone(), Some(Ok(IsMasterReply { + server_address: address.clone(), command_response: heartbeat_response, round_trip_time: Some(Duration::from_secs(1)), cluster_time: None, diff --git a/src/sync/coll.rs b/src/sync/coll.rs index a0343bf38..5259e96fc 100644 --- a/src/sync/coll.rs +++ b/src/sync/coll.rs @@ -390,7 +390,7 @@ impl Collection { impl Collection where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { /// Finds the documents in the collection matching `filter`. pub fn find( diff --git a/src/sync/cursor.rs b/src/sync/cursor.rs index bd48c34c3..dc3d0fdec 100644 --- a/src/sync/cursor.rs +++ b/src/sync/cursor.rs @@ -71,14 +71,14 @@ use crate::{ #[derive(Debug)] pub struct Cursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { async_cursor: AsyncCursor, } impl Cursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { pub(crate) fn new(async_cursor: AsyncCursor) -> Self { Self { async_cursor } @@ -87,7 +87,7 @@ where impl Iterator for Cursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { type Item = Result; @@ -125,7 +125,7 @@ where impl SessionCursor where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { pub(crate) fn new(async_cursor: AsyncSessionCursor) -> Self { Self { async_cursor } @@ -174,14 +174,14 @@ where /// This updates the buffer of the parent `SessionCursor` when dropped. pub struct SessionCursorIter<'cursor, 'session, T = Document> where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { async_stream: SessionCursorStream<'cursor, 'session, T>, } impl Iterator for SessionCursorIter<'_, '_, T> where - T: DeserializeOwned + Unpin, + T: DeserializeOwned + Unpin + Send + Sync, { type Item = Result; diff --git a/src/test/coll.rs b/src/test/coll.rs index f1a290705..4f9f1b4c1 100644 --- a/src/test/coll.rs +++ b/src/test/coll.rs @@ -796,7 +796,7 @@ async fn typed_insert_one() { async fn insert_one_and_find(coll: &Collection, insert_data: T) where - T: Serialize + DeserializeOwned + Clone + PartialEq + Debug + Unpin, + T: Serialize + DeserializeOwned + Clone + PartialEq + Debug + Unpin + Send + Sync, { coll.insert_one(insert_data.clone(), None).await.unwrap(); let result = coll diff --git a/src/test/spec/write_error.rs b/src/test/spec/write_error.rs index 0fc294904..a91f3b170 100644 --- a/src/test/spec/write_error.rs +++ b/src/test/spec/write_error.rs @@ -11,7 +11,9 @@ use crate::{ async fn details() { let _guard = LOCK.run_concurrently().await; let client = EventClient::new().await; - if client.server_version_lt(5, 0) { + + // TODO: RUST-894 unskip once SERVER-58399 is fixed. + if client.server_version_lt(5, 0) || client.is_sharded() { return; }