From f8ba83bbfe84b01b088b4a6a2708cee1db9203d3 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Mon, 25 Oct 2021 13:29:46 -0400 Subject: [PATCH 01/21] implement deserialize for RawBson, RawDocument, etc. --- src/de/mod.rs | 49 ++- src/de/raw.rs | 781 ++++++++++++++++++++++++++++++++++++---- src/de/serde.rs | 37 +- src/raw/array.rs | 17 + src/raw/bson.rs | 326 ++++++++++++++++- src/raw/document.rs | 23 +- src/raw/document_buf.rs | 18 +- src/raw/mod.rs | 4 + 8 files changed, 1165 insertions(+), 90 deletions(-) diff --git a/src/de/mod.rs b/src/de/mod.rs index 8a20a342..4a24d4c3 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -35,6 +35,7 @@ use std::io::Read; use crate::{ bson::{Array, Binary, Bson, DbPointer, Document, JavaScriptCodeWithScope, Regex, Timestamp}, oid::{self, ObjectId}, + raw::RawBinary, ser::write_i32, spec::{self, BinarySubtype}, Decimal128, @@ -45,7 +46,7 @@ use ::serde::{ Deserialize, }; -pub(crate) use self::serde::BsonVisitor; +pub(crate) use self::serde::{convert_unsigned_to_signed_raw, BsonVisitor}; pub(crate) const MAX_BSON_SIZE: i32 = 16 * 1024 * 1024; pub(crate) const MIN_BSON_DOCUMENT_SIZE: i32 = 4 + 1; // 4 bytes for length, one byte for null terminator @@ -277,6 +278,7 @@ impl Binary { Self::from_reader_with_len_and_payload(reader, len, subtype) } + // TODO: RUST-976: call through to the RawBinary version of this instead of duplicating code pub(crate) fn from_reader_with_len_and_payload( mut reader: R, mut len: i32, @@ -317,6 +319,51 @@ impl Binary { } } +impl<'a> RawBinary<'a> { + pub(crate) fn from_slice_with_len_and_payload( + mut bytes: &'a [u8], + mut len: i32, + subtype: BinarySubtype, + ) -> Result { + if !(0..=MAX_BSON_SIZE).contains(&len) { + return Err(Error::invalid_length( + len as usize, + &format!("binary length must be between 0 and {}", MAX_BSON_SIZE).as_str(), + )); + } else if len as usize > bytes.len() { + return Err(Error::invalid_length( + len as usize, + &format!( + "binary length {} exceeds buffer length {}", + 
len, + bytes.len() + ) + .as_str(), + )); + } + + // Skip length data in old binary. + if let BinarySubtype::BinaryOld = subtype { + let data_len = read_i32(&mut bytes)?; + println!("data_len={}", data_len); + + if data_len + 4 != len { + return Err(Error::invalid_length( + data_len as usize, + &"0x02 length did not match top level binary length", + )); + } + + len -= 4; + } + + Ok(Self { + bytes: &bytes[0..len as usize], + subtype, + }) + } +} + impl DbPointer { pub(crate) fn from_reader(mut reader: R) -> Result { let ns = read_string(&mut reader, false)?; diff --git a/src/de/raw.rs b/src/de/raw.rs index c48ba3c7..0139f5bc 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -1,16 +1,18 @@ use std::{ borrow::Cow, io::{ErrorKind, Read}, + sync::Arc, }; use serde::{ - de::{EnumAccess, Error as SerdeError, IntoDeserializer, VariantAccess}, + de::{EnumAccess, Error as SerdeError, IntoDeserializer, MapAccess, VariantAccess}, forward_to_deserialize_any, Deserializer as SerdeDeserializer, }; use crate::{ oid::ObjectId, + raw::{RawBinary, RawBson, RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, spec::{BinarySubtype, ElementType}, uuid::UUID_NEWTYPE_NAME, Binary, @@ -19,6 +21,7 @@ use crate::{ DbPointer, Decimal128, JavaScriptCodeWithScope, + RawDocument, Regex, Timestamp, }; @@ -37,6 +40,13 @@ use super::{ }; use crate::de::serde::MapDeserializer; +#[derive(Debug, Clone, Copy)] +enum DeserializerHint { + None, + BinarySubtype(BinarySubtype), + RawBson, +} + /// Deserializer used to parse and deserialize raw BSON bytes. 
pub(crate) struct Deserializer<'de> { bytes: BsonBuf<'de>, @@ -50,6 +60,11 @@ pub(crate) struct Deserializer<'de> { current_type: ElementType, } +enum DocumentType { + Array, + EmbeddedDocument, +} + impl<'de> Deserializer<'de> { pub(crate) fn new(buf: &'de [u8], utf8_lossy: bool) -> Self { Self { @@ -95,16 +110,52 @@ impl<'de> Deserializer<'de> { self.bytes.read_str() } - fn deserialize_document_key(&mut self) -> Result> { + fn deserialize_cstr(&mut self) -> Result> { self.bytes.read_cstr() } + fn deserialize_document( + &mut self, + visitor: V, + hint: DeserializerHint, + document_type: DocumentType, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + let is_array = match document_type { + DocumentType::Array => true, + DocumentType::EmbeddedDocument => false, + }; + + match hint { + DeserializerHint::RawBson => { + let mut len = self.bytes.slice(4)?; + let len = read_i32(&mut len)?; + + let doc = RawDocument::new(self.bytes.read_slice(len as usize)?) + .map_err(Error::custom)?; + + let access = if is_array { + RawDocumentAccess::for_array(doc) + } else { + RawDocumentAccess::new(doc) + }; + + visitor.visit_map(access) + } + _ if is_array => self.access_document(|access| visitor.visit_seq(access)), + _ => self.access_document(|access| visitor.visit_map(access)), + } + } + /// Construct a `DocumentAccess` and pass it into the provided closure, returning the /// result of the closure if no other errors are encountered. - fn deserialize_document(&mut self, f: F) -> Result + fn access_document(&mut self, f: F) -> Result where F: FnOnce(DocumentAccess<'_, 'de>) -> Result, { + println!("in access"); let mut length_remaining = read_i32(&mut self.bytes)? 
- 4; let out = f(DocumentAccess { root_deserializer: self, @@ -132,15 +183,11 @@ impl<'de> Deserializer<'de> { Ok(Some(element_type)) } - fn deserialize_next( - &mut self, - visitor: V, - binary_subtype_hint: Option, - ) -> Result + fn deserialize_next(&mut self, visitor: V, hint: DeserializerHint) -> Result where V: serde::de::Visitor<'de>, { - if let Some(expected_st) = binary_subtype_hint { + if let DeserializerHint::BinarySubtype(expected_st) = hint { if self.current_type != ElementType::Binary { return Err(Error::custom(format!( "expected Binary with subtype {:?}, instead got {:?}", @@ -161,12 +208,12 @@ impl<'de> Deserializer<'de> { ElementType::Null => visitor.visit_unit(), ElementType::ObjectId => { let oid = ObjectId::from_reader(&mut self.bytes)?; - visitor.visit_map(ObjectIdAccess::new(oid)) + visitor.visit_map(ObjectIdAccess::new(oid, hint)) } ElementType::EmbeddedDocument => { - self.deserialize_document(|access| visitor.visit_map(access)) + self.deserialize_document(visitor, hint, DocumentType::EmbeddedDocument) } - ElementType::Array => self.deserialize_document(|access| visitor.visit_seq(access)), + ElementType::Array => self.deserialize_document(visitor, hint, DocumentType::Array), ElementType::Binary => { let len = read_i32(&mut self.bytes)?; if !(0..=MAX_BSON_SIZE).contains(&len) { @@ -177,7 +224,7 @@ impl<'de> Deserializer<'de> { } let subtype = BinarySubtype::from(read_u8(&mut self.bytes)?); - if let Some(expected_subtype) = binary_subtype_hint { + if let DeserializerHint::BinarySubtype(expected_subtype) = hint { if subtype != expected_subtype { return Err(Error::custom(format!( "expected binary subtype {:?} instead got {:?}", @@ -190,6 +237,17 @@ impl<'de> Deserializer<'de> { BinarySubtype::Generic => { visitor.visit_borrowed_bytes(self.bytes.read_slice(len as usize)?) 
} + _ if matches!(hint, DeserializerHint::RawBson) => { + let binary = RawBinary::from_slice_with_len_and_payload( + self.bytes.read_slice(len as usize)?, + len, + subtype, + )?; + let mut d = BinaryDeserializer::borrowed(binary); + visitor.visit_map(BinaryAccess { + deserializer: &mut d, + }) + } _ => { let binary = Binary::from_reader_with_len_and_payload( &mut self.bytes, @@ -204,45 +262,58 @@ impl<'de> Deserializer<'de> { } } ElementType::Undefined => { - let doc = Bson::Undefined.into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + visitor.visit_map(RawBsonAccess::new("$undefined", BsonContent::Boolean(true))) } ElementType::DateTime => { let dti = read_i64(&mut self.bytes)?; let dt = DateTime::from_millis(dti); - let mut d = DateTimeDeserializer::new(dt); + let mut d = DateTimeDeserializer::new(dt, hint); visitor.visit_map(DateTimeAccess { deserializer: &mut d, }) } ElementType::RegularExpression => { - let doc = Bson::RegularExpression(Regex::from_reader(&mut self.bytes)?) - .into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + let mut de = RegexDeserializer::new(&mut *self); + visitor.visit_map(RegexAccess::new(&mut de)) } ElementType::DbPointer => { - let doc = Bson::DbPointer(DbPointer::from_reader(&mut self.bytes)?) 
- .into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + let mut de = DbPointerDeserializer::new(&mut *self, hint); + visitor.visit_map(DbPointerAccess::new(&mut de)) } ElementType::JavaScriptCode => { let utf8_lossy = self.bytes.utf8_lossy; - let code = read_string(&mut self.bytes, utf8_lossy)?; - let doc = Bson::JavaScriptCode(code).into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + + match hint { + DeserializerHint::RawBson => visitor.visit_map(RawBsonAccess::new( + "$code", + BsonContent::Str(self.bytes.read_borrowed_str()?), + )), + _ => { + let code = read_string(&mut self.bytes, utf8_lossy)?; + let doc = Bson::JavaScriptCode(code).into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + } } ElementType::JavaScriptCodeWithScope => { - let utf8_lossy = self.bytes.utf8_lossy; - let code_w_scope = - JavaScriptCodeWithScope::from_reader(&mut self.bytes, utf8_lossy)?; - let doc = Bson::JavaScriptCodeWithScope(code_w_scope).into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + let _len = read_i32(&mut self.bytes)?; + let mut de = CodeWithScopeDeserializer::new(&mut *self, hint); + visitor.visit_map(CodeWithScopeAccess::new(&mut de)) } ElementType::Symbol => { let utf8_lossy = self.bytes.utf8_lossy; - let symbol = read_string(&mut self.bytes, utf8_lossy)?; - let doc = Bson::Symbol(symbol).into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + + match hint { + DeserializerHint::RawBson => visitor.visit_map(RawBsonAccess::new( + "$symbol", + BsonContent::Str(self.bytes.read_borrowed_str()?), + )), + _ => { + let symbol = read_string(&mut self.bytes, utf8_lossy)?; + let doc = Bson::Symbol(symbol).into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + } } ElementType::Timestamp => { let ts = Timestamp::from_reader(&mut self.bytes)?; @@ -256,12 +327,10 @@ impl<'de> Deserializer<'de> { visitor.visit_map(Decimal128Access::new(d128)) } 
ElementType::MaxKey => { - let doc = Bson::MaxKey.into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + visitor.visit_map(RawBsonAccess::new("$maxKey", BsonContent::Int32(1))) } ElementType::MinKey => { - let doc = Bson::MinKey.into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + visitor.visit_map(RawBsonAccess::new("$minKey", BsonContent::Int32(1))) } } } @@ -275,7 +344,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: serde::de::Visitor<'de>, { - self.deserialize_next(visitor, None) + self.deserialize_next(visitor, DeserializerHint::None) } #[inline] @@ -301,7 +370,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { match self.current_type { ElementType::String => visitor.visit_enum(self.deserialize_str()?.into_deserializer()), ElementType::EmbeddedDocument => { - self.deserialize_document(|access| visitor.visit_enum(access)) + self.access_document(|access| visitor.visit_enum(access)) } t => Err(Error::custom(format!("expected enum, instead got {:?}", t))), } @@ -321,10 +390,13 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: serde::de::Visitor<'de>, { - if name == UUID_NEWTYPE_NAME { - self.deserialize_next(visitor, Some(BinarySubtype::Uuid)) - } else { - visitor.visit_newtype_struct(self) + match name { + UUID_NEWTYPE_NAME => self.deserialize_next( + visitor, + DeserializerHint::BinarySubtype(BinarySubtype::Uuid), + ), + RAW_BSON_NEWTYPE => self.deserialize_next(visitor, DeserializerHint::RawBson), + _ => visitor.visit_newtype_struct(self), } } @@ -428,7 +500,7 @@ impl<'d, 'de> serde::de::SeqAccess<'de> for DocumentAccess<'d, 'de> { if self.read_next_type()?.is_none() { return Ok(None); } - let _index = self.read(|s| s.root_deserializer.deserialize_document_key())?; + let _index = self.read(|s| s.root_deserializer.deserialize_cstr())?; self.read_next_value(seed).map(Some) } } @@ -498,7 +570,7 @@ impl<'d, 'de> 
serde::de::Deserializer<'de> for DocumentKeyDeserializer<'d, 'de> where V: serde::de::Visitor<'de>, { - let s = self.root_deserializer.deserialize_document_key()?; + let s = self.root_deserializer.deserialize_cstr()?; match s { Cow::Borrowed(b) => visitor.visit_borrowed_str(b), Cow::Owned(string) => visitor.visit_string(string), @@ -534,16 +606,94 @@ impl<'de> serde::de::Deserializer<'de> for FieldDeserializer { } } +struct RawDocumentAccess<'d> { + deserializer: RawDocumentDeserializer<'d>, + first: bool, + array: bool, +} + +impl<'de> RawDocumentAccess<'de> { + fn new(doc: &'de RawDocument) -> Self { + Self { + deserializer: RawDocumentDeserializer { raw_doc: doc }, + first: true, + array: false, + } + } + + fn for_array(doc: &'de RawDocument) -> Self { + Self { + deserializer: RawDocumentDeserializer { raw_doc: doc }, + first: true, + array: true, + } + } +} + +impl<'de> serde::de::MapAccess<'de> for RawDocumentAccess<'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + if self.first { + self.first = false; + seed.deserialize(FieldDeserializer { + field_name: if self.array { + RAW_ARRAY_NEWTYPE + } else { + RAW_DOCUMENT_NEWTYPE + }, + }) + .map(Some) + } else { + Ok(None) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(self.deserializer) + } +} + +#[derive(Clone, Copy)] +struct RawDocumentDeserializer<'a> { + raw_doc: &'a RawDocument, +} + +impl<'de> serde::de::Deserializer<'de> for RawDocumentDeserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_borrowed_bytes(self.raw_doc.as_bytes()) + } + + serde::forward_to_deserialize_any! 
{ + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + struct ObjectIdAccess { oid: ObjectId, visited: bool, + hint: DeserializerHint, } impl ObjectIdAccess { - fn new(oid: ObjectId) -> Self { + fn new(oid: ObjectId, hint: DeserializerHint) -> Self { Self { oid, visited: false, + hint, } } } @@ -567,11 +717,17 @@ impl<'de> serde::de::MapAccess<'de> for ObjectIdAccess { where V: serde::de::DeserializeSeed<'de>, { - seed.deserialize(ObjectIdDeserializer(self.oid)) + seed.deserialize(ObjectIdDeserializer { + oid: self.oid, + hint: self.hint, + }) } } -struct ObjectIdDeserializer(ObjectId); +struct ObjectIdDeserializer { + oid: ObjectId, + hint: DeserializerHint, +} impl<'de> serde::de::Deserializer<'de> for ObjectIdDeserializer { type Error = Error; @@ -580,7 +736,13 @@ impl<'de> serde::de::Deserializer<'de> for ObjectIdDeserializer { where V: serde::de::Visitor<'de>, { - visitor.visit_string(self.0.to_hex()) + println!("oid hint {:?}", self.hint); + println!("visitor: {:?}", std::any::type_name::()); + // save an allocation when deserializing to raw bson + match self.hint { + DeserializerHint::RawBson => visitor.visit_bytes(&self.oid.bytes()), + _ => visitor.visit_string(self.oid.to_hex()), + } } serde::forward_to_deserialize_any! 
{ @@ -749,6 +911,14 @@ struct DateTimeAccess<'d> { deserializer: &'d mut DateTimeDeserializer, } +// impl<'d> DateTimeAccess<'d> { +// fn new(deserializer: &'d mut DateTimeDeserializer) -> Self { +// Self { + +// } +// } +// } + impl<'de, 'd> serde::de::MapAccess<'de> for DateTimeAccess<'d> { type Error = Error; @@ -782,13 +952,15 @@ impl<'de, 'd> serde::de::MapAccess<'de> for DateTimeAccess<'d> { struct DateTimeDeserializer { dt: DateTime, stage: DateTimeDeserializationStage, + hint: DeserializerHint, } impl DateTimeDeserializer { - fn new(dt: DateTime) -> Self { + fn new(dt: DateTime, hint: DeserializerHint) -> Self { Self { dt, stage: DateTimeDeserializationStage::TopLevel, + hint, } } } @@ -801,12 +973,18 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut DateTimeDeserializer { V: serde::de::Visitor<'de>, { match self.stage { - DateTimeDeserializationStage::TopLevel => { - self.stage = DateTimeDeserializationStage::NumberLong; - visitor.visit_map(DateTimeAccess { - deserializer: &mut self, - }) - } + DateTimeDeserializationStage::TopLevel => match self.hint { + DeserializerHint::RawBson => { + self.stage = DateTimeDeserializationStage::Done; + visitor.visit_i64(self.dt.timestamp_millis()) + } + _ => { + self.stage = DateTimeDeserializationStage::NumberLong; + visitor.visit_map(DateTimeAccess { + deserializer: &mut self, + }) + } + }, DateTimeDeserializationStage::NumberLong => { self.stage = DateTimeDeserializationStage::Done; visitor.visit_string(self.dt.timestamp_millis().to_string()) @@ -824,11 +1002,11 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut DateTimeDeserializer { } } -struct BinaryAccess<'d> { - deserializer: &'d mut BinaryDeserializer, +struct BinaryAccess<'d, 'de> { + deserializer: &'d mut BinaryDeserializer<'de>, } -impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d> { +impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d, 'de> { type Error = Error; fn next_key_seed(&mut self, seed: K) -> Result> @@ -863,21 
+1041,35 @@ impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d> { } } -struct BinaryDeserializer { - binary: Binary, +enum BinaryContent<'a> { + Borrowed(RawBinary<'a>), + Owned(Binary), +} + +struct BinaryDeserializer<'a> { + binary: BinaryContent<'a>, stage: BinaryDeserializationStage, } -impl BinaryDeserializer { +impl BinaryDeserializer<'static> { fn new(binary: Binary) -> Self { Self { - binary, + binary: BinaryContent::Owned(binary), + stage: BinaryDeserializationStage::TopLevel, + } + } +} + +impl<'a> BinaryDeserializer<'a> { + fn borrowed(binary: RawBinary<'a>) -> Self { + Self { + binary: BinaryContent::Borrowed(binary), stage: BinaryDeserializationStage::TopLevel, } } } -impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer { +impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer<'de> { type Error = Error; fn deserialize_any(mut self, visitor: V) -> Result @@ -893,11 +1085,21 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer { } BinaryDeserializationStage::Subtype => { self.stage = BinaryDeserializationStage::Bytes; - visitor.visit_string(hex::encode([u8::from(self.binary.subtype)])) + match self.binary { + BinaryContent::Owned(ref b) => { + visitor.visit_string(hex::encode([u8::from(b.subtype)])) + } + BinaryContent::Borrowed(b) => visitor.visit_u8(b.subtype().into()), + } } BinaryDeserializationStage::Bytes => { self.stage = BinaryDeserializationStage::Done; - visitor.visit_string(base64::encode(self.binary.bytes.as_slice())) + match self.binary { + BinaryContent::Owned(ref b) => { + visitor.visit_string(base64::encode(b.bytes.as_slice())) + } + BinaryContent::Borrowed(b) => visitor.visit_borrowed_bytes(b.as_bytes()), + } } BinaryDeserializationStage::Done => { Err(Error::custom("Binary fully deserialized already")) @@ -919,6 +1121,402 @@ enum BinaryDeserializationStage { Done, } +struct CodeWithScopeAccess<'de, 'd, 'a> { + deserializer: &'a mut CodeWithScopeDeserializer<'de, 
'd>, +} + +impl<'de, 'd, 'a> CodeWithScopeAccess<'de, 'd, 'a> { + fn new(deserializer: &'a mut CodeWithScopeDeserializer<'de, 'd>) -> Self { + Self { deserializer } + } +} + +impl<'de, 'd, 'a> serde::de::MapAccess<'de> for CodeWithScopeAccess<'de, 'd, 'a> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + println!("key: {:?}", self.deserializer.stage); + match self.deserializer.stage { + CodeWithScopeDeserializationStage::Code => seed + .deserialize(FieldDeserializer { + field_name: "$code", + }) + .map(Some), + CodeWithScopeDeserializationStage::Scope => seed + .deserialize(FieldDeserializer { + field_name: "$scope", + }) + .map(Some), + CodeWithScopeDeserializationStage::Done => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.deserializer) + } +} + +struct CodeWithScopeDeserializer<'de, 'a> { + root_deserializer: &'a mut Deserializer<'de>, + stage: CodeWithScopeDeserializationStage, + hint: DeserializerHint, +} + +impl<'de, 'a> CodeWithScopeDeserializer<'de, 'a> { + fn new(root_deserializer: &'a mut Deserializer<'de>, hint: DeserializerHint) -> Self { + Self { + root_deserializer, + stage: CodeWithScopeDeserializationStage::Code, + hint, + } + } +} + +impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut CodeWithScopeDeserializer<'de, 'a> { + type Error = Error; + + fn deserialize_any(mut self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.stage { + CodeWithScopeDeserializationStage::Code => { + self.stage = CodeWithScopeDeserializationStage::Scope; + match self.root_deserializer.deserialize_str()? 
{ + Cow::Borrowed(s) => { + println!("visiting code: {}", s); + visitor.visit_borrowed_str(s) + } + Cow::Owned(s) => visitor.visit_string(s), + } + } + CodeWithScopeDeserializationStage::Scope => { + self.stage = CodeWithScopeDeserializationStage::Done; + println!("deserializing scope"); + self.root_deserializer.deserialize_document( + visitor, + self.hint, + DocumentType::EmbeddedDocument, + ) + } + CodeWithScopeDeserializationStage::Done => Err(Error::custom( + "JavaScriptCodeWithScope fully deserialized already", + )), + } + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +#[derive(Debug)] +enum CodeWithScopeDeserializationStage { + Code, + Scope, + Done, +} + +struct DbPointerAccess<'de, 'd, 'a> { + deserializer: &'a mut DbPointerDeserializer<'de, 'd>, +} + +impl<'de, 'd, 'a> DbPointerAccess<'de, 'd, 'a> { + fn new(deserializer: &'a mut DbPointerDeserializer<'de, 'd>) -> Self { + Self { deserializer } + } +} + +impl<'de, 'd, 'a> serde::de::MapAccess<'de> for DbPointerAccess<'de, 'd, 'a> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + println!("key: {:?}", self.deserializer.stage); + match self.deserializer.stage { + DbPointerDeserializationStage::TopLevel => seed + .deserialize(FieldDeserializer { + field_name: "$dbPointer", + }) + .map(Some), + DbPointerDeserializationStage::Namespace => seed + .deserialize(FieldDeserializer { field_name: "$ref" }) + .map(Some), + DbPointerDeserializationStage::Id => seed + .deserialize(FieldDeserializer { field_name: "$id" }) + .map(Some), + DbPointerDeserializationStage::Done => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.deserializer) + } +} + +struct 
DbPointerDeserializer<'de, 'a> { + root_deserializer: &'a mut Deserializer<'de>, + stage: DbPointerDeserializationStage, + hint: DeserializerHint, +} + +impl<'de, 'a> DbPointerDeserializer<'de, 'a> { + fn new(root_deserializer: &'a mut Deserializer<'de>, hint: DeserializerHint) -> Self { + Self { + root_deserializer, + stage: DbPointerDeserializationStage::TopLevel, + hint, + } + } +} + +impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut DbPointerDeserializer<'de, 'a> { + type Error = Error; + + fn deserialize_any(mut self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + println!("deserializing {:?}", self.stage); + match self.stage { + DbPointerDeserializationStage::TopLevel => { + self.stage = DbPointerDeserializationStage::Namespace; + visitor.visit_map(DbPointerAccess::new(self)) + } + DbPointerDeserializationStage::Namespace => { + self.stage = DbPointerDeserializationStage::Id; + match self.root_deserializer.deserialize_str()? { + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + Cow::Owned(s) => visitor.visit_string(s), + } + } + DbPointerDeserializationStage::Id => { + self.stage = DbPointerDeserializationStage::Done; + visitor.visit_borrowed_bytes(self.root_deserializer.bytes.read_slice(12)?) + } + DbPointerDeserializationStage::Done => { + Err(Error::custom("DbPointer fully deserialized already")) + } + } + } + + serde::forward_to_deserialize_any! 
{ + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +#[derive(Debug)] +enum DbPointerDeserializationStage { + TopLevel, + Namespace, + Id, + Done, +} + +struct RegexAccess<'de, 'd, 'a> { + deserializer: &'a mut RegexDeserializer<'de, 'd>, +} + +impl<'de, 'd, 'a> RegexAccess<'de, 'd, 'a> { + fn new(deserializer: &'a mut RegexDeserializer<'de, 'd>) -> Self { + Self { deserializer } + } +} + +impl<'de, 'd, 'a> serde::de::MapAccess<'de> for RegexAccess<'de, 'd, 'a> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + println!("key: {:?}", self.deserializer.stage); + match self.deserializer.stage { + RegexDeserializationStage::TopLevel => seed + .deserialize(FieldDeserializer { + field_name: "$regularExpression", + }) + .map(Some), + RegexDeserializationStage::Pattern => seed + .deserialize(FieldDeserializer { + field_name: "pattern", + }) + .map(Some), + RegexDeserializationStage::Options => seed + .deserialize(FieldDeserializer { + field_name: "options", + }) + .map(Some), + RegexDeserializationStage::Done => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.deserializer) + } +} + +struct RegexDeserializer<'de, 'a> { + root_deserializer: &'a mut Deserializer<'de>, + stage: RegexDeserializationStage, +} + +impl<'de, 'a> RegexDeserializer<'de, 'a> { + fn new(root_deserializer: &'a mut Deserializer<'de>) -> Self { + Self { + root_deserializer, + stage: RegexDeserializationStage::TopLevel, + } + } +} + +impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut RegexDeserializer<'de, 'a> { + type Error = Error; + + fn deserialize_any(mut self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.stage { + 
RegexDeserializationStage::TopLevel => { + self.stage.advance(); + visitor.visit_map(RegexAccess::new(self)) + } + RegexDeserializationStage::Pattern | RegexDeserializationStage::Options => { + self.stage.advance(); + match self.root_deserializer.deserialize_cstr()? { + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + Cow::Owned(s) => visitor.visit_string(s), + } + } + RegexDeserializationStage::Done => { + Err(Error::custom("DbPointer fully deserialized already")) + } + } + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +#[derive(Debug)] +enum RegexDeserializationStage { + TopLevel, + Pattern, + Options, + Done, +} + +impl RegexDeserializationStage { + fn advance(&mut self) { + *self = match self { + RegexDeserializationStage::TopLevel => RegexDeserializationStage::Pattern, + RegexDeserializationStage::Pattern => RegexDeserializationStage::Options, + RegexDeserializationStage::Options => RegexDeserializationStage::Done, + RegexDeserializationStage::Done => RegexDeserializationStage::Done, + } + } +} + +/// Helper access struct for visiting the extended JSON model of simple BSON types. +/// e.g. Symbol, Timestamp, etc. +struct RawBsonAccess<'a> { + key: &'static str, + value: BsonContent<'a>, + first: bool, +} + +/// Enum value representing some cached BSON data needed to represent a given +/// BSON type's extended JSON model. 
+#[derive(Debug, Clone, Copy)] +enum BsonContent<'a> { + Str(&'a str), + Int32(i32), + Boolean(bool), +} + +impl<'a> RawBsonAccess<'a> { + fn new(key: &'static str, value: BsonContent<'a>) -> Self { + Self { + key, + value, + first: true, + } + } +} + +impl<'de> MapAccess<'de> for RawBsonAccess<'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + if self.first { + self.first = false; + seed.deserialize(FieldDeserializer { + field_name: self.key, + }) + .map(Some) + } else { + Ok(None) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(RawBsonDeserializer { value: self.value }) + } +} + +struct RawBsonDeserializer<'a> { + value: BsonContent<'a>, +} + +impl<'de, 'a> serde::de::Deserializer<'de> for RawBsonDeserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.value { + BsonContent::Boolean(b) => visitor.visit_bool(b), + BsonContent::Str(s) => visitor.visit_borrowed_str(s), + BsonContent::Int32(i) => visitor.visit_i32(i), + } + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + /// Struct wrapping a slice of BSON bytes. struct BsonBuf<'a> { bytes: &'a [u8], @@ -960,9 +1558,11 @@ impl<'a> BsonBuf<'a> { } /// Get the string starting at the provided index and ending at the buffer's current index. - fn str(&mut self, start: usize) -> Result> { + /// + /// Can optionally override the global UTF-8 lossy setting to ensure bytes are not allocated. 
+ fn str(&mut self, start: usize, utf8_lossy_override: Option) -> Result> { let bytes = &self.bytes[start..self.index]; - let s = if self.utf8_lossy { + let s = if utf8_lossy_override.unwrap_or(self.utf8_lossy) { String::from_utf8_lossy(bytes) } else { Cow::Borrowed(std::str::from_utf8(bytes).map_err(Error::custom)?) @@ -991,15 +1591,10 @@ impl<'a> BsonBuf<'a> { self.index_check()?; - self.str(start) + self.str(start, None) } - /// Attempts to read a null-terminated UTF-8 string from the data. - /// - /// If invalid UTF-8 is encountered, the unicode replacement character will be inserted in place - /// of the offending data, resulting in an owned `String`. Otherwise, the data will be - /// borrowed as-is. - fn read_str(&mut self) -> Result> { + fn advance_to_str(&mut self) -> Result { let len = read_i32(self)?; let start = self.index; @@ -1014,13 +1609,41 @@ impl<'a> BsonBuf<'a> { self.index += (len - 1) as usize; self.index_check()?; - self.str(start) + Ok(start) + } + + /// Attempts to read a null-terminated UTF-8 string from the data. + /// + /// If invalid UTF-8 is encountered, the unicode replacement character will be inserted in place + /// of the offending data, resulting in an owned `String`. Otherwise, the data will be + /// borrowed as-is. + fn read_str(&mut self) -> Result> { + let start = self.advance_to_str()?; + self.str(start, None) + } + + /// Attempts to read a null-terminated UTF-8 string from the data. + fn read_borrowed_str(&mut self) -> Result<&'a str> { + let start = self.advance_to_str()?; + match self.str(start, Some(false))? 
{ + Cow::Borrowed(s) => Ok(s), + Cow::Owned(_) => panic!("should have errored when encountering invalid UTF-8"), + } + } + + fn slice(&self, length: usize) -> Result<&'a [u8]> { + if self.index + length > self.bytes.len() { + return Err(Error::Io(Arc::new( + std::io::ErrorKind::UnexpectedEof.into(), + ))); + } + + Ok(&self.bytes[self.index..(self.index + length)]) } fn read_slice(&mut self, length: usize) -> Result<&'a [u8]> { - let start = self.index; + let slice = self.slice(length)?; self.index += length; - self.index_check()?; - Ok(&self.bytes[start..self.index]) + Ok(slice) } } diff --git a/src/de/serde.rs b/src/de/serde.rs index 507391d0..758a0bad 100644 --- a/src/de/serde.rs +++ b/src/de/serde.rs @@ -24,6 +24,7 @@ use crate::{ datetime::DateTime, document::{Document, IntoIter}, oid::ObjectId, + raw::RawBson, spec::BinarySubtype, uuid::UUID_NEWTYPE_NAME, Decimal128, @@ -471,14 +472,16 @@ impl<'de> Visitor<'de> for BsonVisitor { } } -fn convert_unsigned_to_signed(value: u64) -> Result -where - E: Error, -{ +enum BsonInteger { + Int32(i32), + Int64(i64), +} + +fn _convert_unsigned(value: u64) -> Result { if let Ok(int32) = i32::try_from(value) { - Ok(Bson::Int32(int32)) + Ok(BsonInteger::Int32(int32)) } else if let Ok(int64) = i64::try_from(value) { - Ok(Bson::Int64(int64)) + Ok(BsonInteger::Int64(int64)) } else { Err(Error::custom(format!( "cannot represent {} as a signed number", @@ -487,6 +490,28 @@ where } } +fn convert_unsigned_to_signed(value: u64) -> Result +where + E: Error, +{ + let bi = _convert_unsigned(value)?; + match bi { + BsonInteger::Int32(i) => Ok(Bson::Int32(i)), + BsonInteger::Int64(i) => Ok(Bson::Int64(i)), + } +} + +pub(crate) fn convert_unsigned_to_signed_raw<'a, E>(value: u64) -> Result, E> +where + E: Error, +{ + let bi = _convert_unsigned(value)?; + match bi { + BsonInteger::Int32(i) => Ok(RawBson::Int32(i)), + BsonInteger::Int64(i) => Ok(RawBson::Int64(i)), + } +} + /// Serde Deserializer pub struct Deserializer { value: Option, diff 
--git a/src/raw/array.rs b/src/raw/array.rs index 684a4a4c..329651e7 100644 --- a/src/raw/array.rs +++ b/src/raw/array.rs @@ -1,5 +1,7 @@ use std::convert::TryFrom; +use serde::Deserialize; + use super::{ error::{ValueAccessError, ValueAccessErrorKind, ValueAccessResult}, Error, @@ -240,3 +242,18 @@ impl<'a> Iterator for RawArrayIter<'a> { } } } + +impl<'de: 'a, 'a> Deserialize<'de> for &'a RawArray { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? { + RawBson::Array(d) => Ok(d), + b => Err(serde::de::Error::custom(format!( + "expected raw array reference, instead got {:?}", + b + ))), + } + } +} diff --git a/src/raw/bson.rs b/src/raw/bson.rs index 05ae4e19..5ac9d980 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -1,10 +1,18 @@ use std::convert::{TryFrom, TryInto}; +use serde::{ + de::{MapAccess, Unexpected, Visitor}, + Deserialize, +}; + use super::{Error, RawArray, RawDocument, Result}; use crate::{ + extjson, oid::{self, ObjectId}, + raw::{RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, spec::{BinarySubtype, ElementType}, Bson, + DateTime, DbPointer, Decimal128, Timestamp, @@ -239,6 +247,259 @@ impl<'a> RawBson<'a> { } } +impl<'de: 'a, 'a> Deserialize<'de> for RawBson<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error as SerdeError; + + struct RawBsonVisitor; + + impl<'de> Visitor<'de> for RawBsonVisitor { + type Value = RawBson<'de>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a raw BSON reference") + } + + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::String(v)) + } + + fn visit_borrowed_bytes( + self, + bytes: &'de [u8], + ) -> std::result::Result + where + E: SerdeError, + { + Ok(RawBson::Binary(RawBinary { + bytes, + subtype: 
BinarySubtype::Generic, + })) + } + + fn visit_i8(self, v: i8) -> std::result::Result + where + E: SerdeError, + { + Ok(RawBson::Int32(v.into())) + } + + fn visit_i16(self, v: i16) -> std::result::Result + where + E: SerdeError, + { + Ok(RawBson::Int32(v.into())) + } + + fn visit_i32(self, v: i32) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int32(v)) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int64(v)) + } + + fn visit_u8(self, value: u8) -> std::result::Result + where + E: serde::de::Error, + { + crate::de::convert_unsigned_to_signed_raw(value.into()) + } + + fn visit_u16(self, value: u16) -> std::result::Result + where + E: serde::de::Error, + { + crate::de::convert_unsigned_to_signed_raw(value.into()) + } + + fn visit_u32(self, value: u32) -> std::result::Result + where + E: serde::de::Error, + { + crate::de::convert_unsigned_to_signed_raw(value.into()) + } + + fn visit_u64(self, value: u64) -> std::result::Result + where + E: serde::de::Error, + { + crate::de::convert_unsigned_to_signed_raw(value) + } + + fn visit_bool(self, v: bool) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Boolean(v)) + } + + fn visit_f64(self, v: f64) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Double(v)) + } + + fn visit_none(self) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Null) + } + + fn visit_unit(self) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Null) + } + + fn visit_newtype_struct( + self, + deserializer: D, + ) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + // use extjson for: ObjectId, datetime, timestamp, symbol, minkey, maxkey + fn visit_map(self, mut map: A) -> std::result::Result + where + A: serde::de::MapAccess<'de>, + { + let k = map.next_key::<&str>()?.ok_or_else(|| { + SerdeError::custom("expected a key 
when deserializing RawBson") + })?; + match k { + "$oid" => { + let oid: ObjectId = map.next_value()?; + Ok(RawBson::ObjectId(oid)) + } + "$symbol" => { + let s: &str = map.next_value()?; + Ok(RawBson::Symbol(s)) + } + "$numberDecimalBytes" => Ok(RawBson::Decimal128(map.next_value()?)), + "$regularExpression" => { + #[derive(Debug, Deserialize)] + struct BorrowedRegexBody<'a> { + pattern: &'a str, + + options: &'a str, + } + let body: BorrowedRegexBody = map.next_value()?; + Ok(RawBson::RegularExpression(RawRegex { + pattern: body.pattern, + options: body.options, + })) + } + "$undefined" => { + let _: bool = map.next_value()?; + Ok(RawBson::Undefined) + } + "$binary" => { + #[derive(Debug, Deserialize)] + struct BorrowedBinaryBody<'a> { + base64: &'a [u8], + + #[serde(rename = "subType")] + subtype: u8, + } + + let v = map.next_value::()?; + + Ok(RawBson::Binary(RawBinary { + bytes: v.base64, + subtype: v.subtype.into(), + })) + } + "$date" => { + let v = map.next_value::()?; + Ok(RawBson::DateTime(DateTime::from_millis(v))) + } + "$timestamp" => { + let v = map.next_value::()?; + Ok(RawBson::Timestamp(Timestamp { + time: v.t, + increment: v.i, + })) + } + "$minKey" => { + let _ = map.next_value::()?; + Ok(RawBson::MinKey) + } + "$maxKey" => { + let _ = map.next_value::()?; + Ok(RawBson::MaxKey) + } + "$code" => { + let code = map.next_value::<&str>()?; + if let Some(key) = map.next_key::<&str>()? 
{ + if key == "$scope" { + let scope = map.next_value::<&RawDocument>()?; + Ok(RawBson::JavaScriptCodeWithScope( + RawJavaScriptCodeWithScope { code, scope }, + )) + } else { + Err(SerdeError::unknown_field(key, &["$scope"])) + } + } else { + Ok(RawBson::JavaScriptCode(code)) + } + } + "$dbPointer" => { + #[derive(Deserialize)] + struct BorrowedDbPointerBody<'a> { + #[serde(rename = "$ref")] + ns: &'a str, + + #[serde(rename = "$id")] + id: ObjectId, + } + + let body: BorrowedDbPointerBody = map.next_value()?; + Ok(RawBson::DbPointer(RawDbPointer { + namespace: body.ns, + id: body.id, + })) + } + RAW_DOCUMENT_NEWTYPE => { + let bson = map.next_value::<&[u8]>()?; + let doc = RawDocument::new(bson).map_err(SerdeError::custom)?; + Ok(RawBson::Document(doc)) + } + RAW_ARRAY_NEWTYPE => { + let bson = map.next_value::<&[u8]>()?; + let doc = RawDocument::new(bson).map_err(SerdeError::custom)?; + Ok(RawBson::Array(RawArray::from_doc(doc))) + } + k => Err(SerdeError::custom(format!( + "can't deserialize RawBson from map, key={}", + k + ))), + } + } + } + + deserializer.deserialize_newtype_struct(RAW_BSON_NEWTYPE, RawBsonVisitor) + } +} + impl<'a> TryFrom> for Bson { type Error = Error; @@ -302,8 +563,8 @@ impl<'a> TryFrom> for Bson { /// A BSON binary value referencing raw bytes stored elsewhere. #[derive(Clone, Copy, Debug, PartialEq)] pub struct RawBinary<'a> { - pub(super) subtype: BinarySubtype, - pub(super) bytes: &'a [u8], + pub(crate) subtype: BinarySubtype, + pub(crate) bytes: &'a [u8], } impl<'a> RawBinary<'a> { @@ -318,6 +579,21 @@ impl<'a> RawBinary<'a> { } } +impl<'de: 'a, 'a> Deserialize<'de> for RawBinary<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? { + RawBson::Binary(b) => Ok(b), + c => Err(serde::de::Error::custom(format!( + "expected binary, but got {:?} instead", + c + ))), + } + } +} + /// A BSON regex referencing raw bytes stored elsewhere. 
#[derive(Clone, Copy, Debug, PartialEq)] pub struct RawRegex<'a> { @@ -337,10 +613,26 @@ impl<'a> RawRegex<'a> { } } +impl<'de: 'a, 'a> Deserialize<'de> for RawRegex<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? { + RawBson::RegularExpression(b) => Ok(b), + c => Err(serde::de::Error::custom(format!( + "expected Regex, but got {:?} instead", + c + ))), + } + } +} + /// A BSON "code with scope" value referencing raw bytes stored elsewhere. #[derive(Clone, Copy, Debug, PartialEq)] pub struct RawJavaScriptCodeWithScope<'a> { pub(crate) code: &'a str, + pub(crate) scope: &'a RawDocument, } @@ -356,9 +648,39 @@ impl<'a> RawJavaScriptCodeWithScope<'a> { } } +impl<'de: 'a, 'a> Deserialize<'de> for RawJavaScriptCodeWithScope<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? { + RawBson::JavaScriptCodeWithScope(b) => Ok(b), + c => Err(serde::de::Error::custom(format!( + "expected CodeWithScope, but got {:?} instead", + c + ))), + } + } +} + /// A BSON DB pointer value referencing raw bytes stored elesewhere. #[derive(Debug, Clone, Copy, PartialEq)] pub struct RawDbPointer<'a> { pub(crate) namespace: &'a str, pub(crate) id: ObjectId, } + +impl<'de: 'a, 'a> Deserialize<'de> for RawDbPointer<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? 
{ + RawBson::DbPointer(b) => Ok(b), + c => Err(serde::de::Error::custom(format!( + "expected DbPointer, but got {:?} instead", + c + ))), + } + } +} diff --git a/src/raw/document.rs b/src/raw/document.rs index e2141bcc..69af5d41 100644 --- a/src/raw/document.rs +++ b/src/raw/document.rs @@ -3,7 +3,13 @@ use std::{ convert::{TryFrom, TryInto}, }; -use crate::{raw::error::ErrorKind, DateTime, Timestamp}; +use serde::Deserialize; + +use crate::{ + raw::{error::ErrorKind, RAW_DOCUMENT_NEWTYPE}, + DateTime, + Timestamp, +}; use super::{ error::{ValueAccessError, ValueAccessErrorKind, ValueAccessResult}, @@ -482,6 +488,21 @@ impl RawDocument { } } +impl<'de: 'a, 'a> Deserialize<'de> for &'a RawDocument { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? { + RawBson::Document(d) => Ok(d), + b => Err(serde::de::Error::custom(format!( + "expected raw document reference, instead got {:?}", + b + ))), + } + } +} + impl std::fmt::Debug for RawDocument { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RawDocument") diff --git a/src/raw/document_buf.rs b/src/raw/document_buf.rs index 019a6e25..15cdba6e 100644 --- a/src/raw/document_buf.rs +++ b/src/raw/document_buf.rs @@ -4,7 +4,12 @@ use std::{ ops::Deref, }; -use crate::Document; +use serde::{de::Visitor, Deserialize, Deserializer}; + +use crate::{ + raw::{RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, + Document, +}; use super::{Error, ErrorKind, Iter, RawBson, RawDocument, Result}; @@ -139,6 +144,17 @@ impl RawDocumentBuf { } } +impl<'de> Deserialize<'de> for RawDocumentBuf { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + // TODO: RUST-1045 implement visit_map to deserialize from arbitrary maps. 
+ let doc: &'de RawDocument = Deserialize::deserialize(deserializer)?; + Ok(doc.to_owned()) + } +} + impl std::fmt::Debug for RawDocumentBuf { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RawDocumentBuf") diff --git a/src/raw/mod.rs b/src/raw/mod.rs index 59f36595..1787fc04 100644 --- a/src/raw/mod.rs +++ b/src/raw/mod.rs @@ -134,6 +134,10 @@ pub use self::{ iter::Iter, }; +pub(crate) const RAW_DOCUMENT_NEWTYPE: &str = "$__private__bson_RawDocument"; +pub(crate) const RAW_ARRAY_NEWTYPE: &str = "$__private__bson_RawArray"; +pub(crate) const RAW_BSON_NEWTYPE: &str = "$__private__bson_RawBson"; + /// Given a u8 slice, return an i32 calculated from the first four bytes in /// little endian order. fn f64_from_slice(val: &[u8]) -> Result { From fbe1a87a60e952e657201c3fb03f1e0710e7aacc Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Wed, 27 Oct 2021 17:30:34 -0400 Subject: [PATCH 02/21] cleanup debug print, add comments --- src/de/mod.rs | 1 - src/de/raw.rs | 116 +++++++++++++++++++++++++++++++++---------------- src/raw/mod.rs | 5 +++ 3 files changed, 83 insertions(+), 39 deletions(-) diff --git a/src/de/mod.rs b/src/de/mod.rs index 4a24d4c3..658a876d 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -345,7 +345,6 @@ impl<'a> RawBinary<'a> { // Skip length data in old binary. if let BinarySubtype::BinaryOld = subtype { let data_len = read_i32(&mut bytes)?; - println!("data_len={}", data_len); if data_len + 4 != len { return Err(Error::invalid_length( diff --git a/src/de/raw.rs b/src/de/raw.rs index 0139f5bc..6bdaf0bf 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -40,10 +40,19 @@ use super::{ }; use crate::de::serde::MapDeserializer; +/// Hint provided to the deserializer via `deserialize_newtype_struct` as to the type of thing +/// being deserialized. #[derive(Debug, Clone, Copy)] enum DeserializerHint { + /// No hint provided, deserialize normally. 
None, + + /// The type being deserialized expects the BSON to contain a binary value with the provided + /// subtype. This is currently used to deserialize `bson::Uuid` values. BinarySubtype(BinarySubtype), + + /// The type being deserialized is raw BSON, meaning no allocations should occur as part of + /// deserializing and everything should be visited via borrowing or `Copy`. RawBson, } @@ -60,6 +69,8 @@ pub(crate) struct Deserializer<'de> { current_type: ElementType, } +/// Enum used to determine what the type of document being deserialized is in +/// `Deserializer::deserialize_document`. enum DocumentType { Array, EmbeddedDocument, @@ -110,10 +121,19 @@ impl<'de> Deserializer<'de> { self.bytes.read_str() } + /// Read a null-terminated C style string from the underling BSON. + /// + /// If utf8_lossy, this will be an owned string if invalid UTF-8 is encountered in the string, + /// otherwise it will be borrowed. fn deserialize_cstr(&mut self) -> Result> { self.bytes.read_cstr() } + /// Read a document from the underling BSON, whether it's an array or an actual document. + /// + /// If hinted to use raw BSON, the bytes themselves will be visited using a special newtype + /// name. Otherwise, the key-value pairs will be accessed in order, either as part of a + /// `MapAccess` for documents or a `SeqAccess` for arrays. fn deserialize_document( &mut self, visitor: V, @@ -155,7 +175,6 @@ impl<'de> Deserializer<'de> { where F: FnOnce(DocumentAccess<'_, 'de>) -> Result, { - println!("in access"); let mut length_remaining = read_i32(&mut self.bytes)? - 4; let out = f(DocumentAccess { root_deserializer: self, @@ -183,6 +202,8 @@ impl<'de> Deserializer<'de> { Ok(Some(element_type)) } + /// Deserialize the next element in the BSON, using the type of the element along with the + /// provided hint to determine how to visit the data. 
fn deserialize_next(&mut self, visitor: V, hint: DeserializerHint) -> Result where V: serde::de::Visitor<'de>, @@ -277,7 +298,7 @@ impl<'de> Deserializer<'de> { visitor.visit_map(RegexAccess::new(&mut de)) } ElementType::DbPointer => { - let mut de = DbPointerDeserializer::new(&mut *self, hint); + let mut de = DbPointerDeserializer::new(&mut *self); visitor.visit_map(DbPointerAccess::new(&mut de)) } ElementType::JavaScriptCode => { @@ -606,9 +627,15 @@ impl<'de> serde::de::Deserializer<'de> for FieldDeserializer { } } +/// A `MapAccess` used to deserialize entire documents as chunks of bytes without deserializing +/// the individual key/value pairs. struct RawDocumentAccess<'d> { deserializer: RawDocumentDeserializer<'d>, - first: bool, + + /// Whether the first key has been deserialized yet or not. + deserialized_first: bool, + + /// Whether or not this document being deserialized is for anarray or not. array: bool, } @@ -616,7 +643,7 @@ impl<'de> RawDocumentAccess<'de> { fn new(doc: &'de RawDocument) -> Self { Self { deserializer: RawDocumentDeserializer { raw_doc: doc }, - first: true, + deserialized_first: false, array: false, } } @@ -624,7 +651,7 @@ impl<'de> RawDocumentAccess<'de> { fn for_array(doc: &'de RawDocument) -> Self { Self { deserializer: RawDocumentDeserializer { raw_doc: doc }, - first: true, + deserialized_first: false, array: true, } } @@ -637,8 +664,11 @@ impl<'de> serde::de::MapAccess<'de> for RawDocumentAccess<'de> { where K: serde::de::DeserializeSeed<'de>, { - if self.first { - self.first = false; + if !self.deserialized_first { + self.deserialized_first = true; + + // the newtype name will indicate to the `RawBson` enum that the incoming + // bytes are meant to be treated as a document or array instead of a binary value. 
seed.deserialize(FieldDeserializer { field_name: if self.array { RAW_ARRAY_NEWTYPE @@ -736,8 +766,6 @@ impl<'de> serde::de::Deserializer<'de> for ObjectIdDeserializer { where V: serde::de::Visitor<'de>, { - println!("oid hint {:?}", self.hint); - println!("visitor: {:?}", std::any::type_name::()); // save an allocation when deserializing to raw bson match self.hint { DeserializerHint::RawBson => visitor.visit_bytes(&self.oid.bytes()), @@ -901,24 +929,17 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut TimestampDeserializer { } } -enum DateTimeDeserializationStage { - TopLevel, - NumberLong, - Done, -} - +/// A `MapAccess` providing access to a BSON datetime being deserialized. +/// +/// If hinted to be raw BSON, this deserializes the serde data model equivalent +/// of { "$date": }. +/// +/// Otherwise, this deserializes the serde data model equivalent of +/// { "$date": { "$numberLong": } }. struct DateTimeAccess<'d> { deserializer: &'d mut DateTimeDeserializer, } -// impl<'d> DateTimeAccess<'d> { -// fn new(deserializer: &'d mut DateTimeDeserializer) -> Self { -// Self { - -// } -// } -// } - impl<'de, 'd> serde::de::MapAccess<'de> for DateTimeAccess<'d> { type Error = Error; @@ -955,6 +976,12 @@ struct DateTimeDeserializer { hint: DeserializerHint, } +enum DateTimeDeserializationStage { + TopLevel, + NumberLong, + Done, +} + impl DateTimeDeserializer { fn new(dt: DateTime, hint: DeserializerHint) -> Self { Self { @@ -1002,6 +1029,13 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut DateTimeDeserializer { } } +/// A `MapAccess` providing access to a BSON binary being deserialized. +/// +/// If hinted to be raw BSON, this deserializes the serde data model equivalent +/// of { "$binary": { "subType": , "base64": } }. +/// +/// Otherwise, this deserializes the serde data model equivalent of +/// { "$binary": { "subType": , "base64": } }. 
struct BinaryAccess<'d, 'de> { deserializer: &'d mut BinaryDeserializer<'de>, } @@ -1041,6 +1075,7 @@ impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d, 'de> { } } +/// Storage of possibly borrowed, possibly owned binary data. enum BinaryContent<'a> { Borrowed(RawBinary<'a>), Owned(Binary), @@ -1121,6 +1156,13 @@ enum BinaryDeserializationStage { Done, } +/// A `MapAccess` providing access to a BSON code with scope being deserialized. +/// +/// If hinted to be raw BSON, this deserializes the serde data model equivalent +/// of { "$code": , "$scope": <&RawDocument> } }. +/// +/// Otherwise, this deserializes the serde data model equivalent of +/// { "$code": "$scope": }. struct CodeWithScopeAccess<'de, 'd, 'a> { deserializer: &'a mut CodeWithScopeDeserializer<'de, 'd>, } @@ -1138,7 +1180,6 @@ impl<'de, 'd, 'a> serde::de::MapAccess<'de> for CodeWithScopeAccess<'de, 'd, 'a> where K: serde::de::DeserializeSeed<'de>, { - println!("key: {:?}", self.deserializer.stage); match self.deserializer.stage { CodeWithScopeDeserializationStage::Code => seed .deserialize(FieldDeserializer { @@ -1189,16 +1230,12 @@ impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut CodeWithScopeDeserial CodeWithScopeDeserializationStage::Code => { self.stage = CodeWithScopeDeserializationStage::Scope; match self.root_deserializer.deserialize_str()? { - Cow::Borrowed(s) => { - println!("visiting code: {}", s); - visitor.visit_borrowed_str(s) - } + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), Cow::Owned(s) => visitor.visit_string(s), } } CodeWithScopeDeserializationStage::Scope => { self.stage = CodeWithScopeDeserializationStage::Done; - println!("deserializing scope"); self.root_deserializer.deserialize_document( visitor, self.hint, @@ -1225,6 +1262,10 @@ enum CodeWithScopeDeserializationStage { Done, } +/// A `MapAccess` providing access to a BSON DB pointer being deserialized. 
+/// +/// Regardless of the hint, this deserializes the serde data model equivalent +/// of { "$dbPointer": { "$ref": , "$id": } }. struct DbPointerAccess<'de, 'd, 'a> { deserializer: &'a mut DbPointerDeserializer<'de, 'd>, } @@ -1242,7 +1283,6 @@ impl<'de, 'd, 'a> serde::de::MapAccess<'de> for DbPointerAccess<'de, 'd, 'a> { where K: serde::de::DeserializeSeed<'de>, { - println!("key: {:?}", self.deserializer.stage); match self.deserializer.stage { DbPointerDeserializationStage::TopLevel => seed .deserialize(FieldDeserializer { @@ -1270,15 +1310,13 @@ impl<'de, 'd, 'a> serde::de::MapAccess<'de> for DbPointerAccess<'de, 'd, 'a> { struct DbPointerDeserializer<'de, 'a> { root_deserializer: &'a mut Deserializer<'de>, stage: DbPointerDeserializationStage, - hint: DeserializerHint, } impl<'de, 'a> DbPointerDeserializer<'de, 'a> { - fn new(root_deserializer: &'a mut Deserializer<'de>, hint: DeserializerHint) -> Self { + fn new(root_deserializer: &'a mut Deserializer<'de>) -> Self { Self { root_deserializer, stage: DbPointerDeserializationStage::TopLevel, - hint, } } } @@ -1290,7 +1328,6 @@ impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut DbPointerDeserializer where V: serde::de::Visitor<'de>, { - println!("deserializing {:?}", self.stage); match self.stage { DbPointerDeserializationStage::TopLevel => { self.stage = DbPointerDeserializationStage::Namespace; @@ -1328,6 +1365,10 @@ enum DbPointerDeserializationStage { Done, } +/// A `MapAccess` providing access to a BSON regular expression being deserialized. +/// +/// Regardless of the hint, this deserializes the serde data model equivalent +/// of { "$regularExpression": { "pattern": , "options": } }. 
struct RegexAccess<'de, 'd, 'a> { deserializer: &'a mut RegexDeserializer<'de, 'd>, } @@ -1345,7 +1386,6 @@ impl<'de, 'd, 'a> serde::de::MapAccess<'de> for RegexAccess<'de, 'd, 'a> { where K: serde::de::DeserializeSeed<'de>, { - println!("key: {:?}", self.deserializer.stage); match self.deserializer.stage { RegexDeserializationStage::TopLevel => seed .deserialize(FieldDeserializer { @@ -1594,7 +1634,7 @@ impl<'a> BsonBuf<'a> { self.str(start, None) } - fn advance_to_str(&mut self) -> Result { + fn _advance_to_len_encoded_str(&mut self) -> Result { let len = read_i32(self)?; let start = self.index; @@ -1618,13 +1658,13 @@ impl<'a> BsonBuf<'a> { /// of the offending data, resulting in an owned `String`. Otherwise, the data will be /// borrowed as-is. fn read_str(&mut self) -> Result> { - let start = self.advance_to_str()?; + let start = self._advance_to_len_encoded_str()?; self.str(start, None) } /// Attempts to read a null-terminated UTF-8 string from the data. fn read_borrowed_str(&mut self) -> Result<&'a str> { - let start = self.advance_to_str()?; + let start = self._advance_to_len_encoded_str()?; match self.str(start, Some(false))? { Cow::Borrowed(s) => Ok(s), Cow::Owned(_) => panic!("should have errored when encountering invalid UTF-8"), diff --git a/src/raw/mod.rs b/src/raw/mod.rs index 1787fc04..48359110 100644 --- a/src/raw/mod.rs +++ b/src/raw/mod.rs @@ -134,8 +134,13 @@ pub use self::{ iter::Iter, }; +/// Special newtype name indicating that the type being (de)serialized is a raw BSON document. pub(crate) const RAW_DOCUMENT_NEWTYPE: &str = "$__private__bson_RawDocument"; + +/// Special newtype name indicating that the type being (de)serialized is a raw BSON array. pub(crate) const RAW_ARRAY_NEWTYPE: &str = "$__private__bson_RawArray"; + +/// Special newtype name indicating that the type being (de)serialized is a raw BSON value. 
pub(crate) const RAW_BSON_NEWTYPE: &str = "$__private__bson_RawBson"; /// Given a u8 slice, return an i32 calculated from the first four bytes in From 176d5c66b4353efe62fed7bdb2fd207f5c1533d8 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Wed, 27 Oct 2021 20:20:48 -0400 Subject: [PATCH 03/21] WIP serialize --- src/de/raw.rs | 2 +- src/raw/array.rs | 29 ++++++- src/raw/bson.rs | 137 +++++++++++++++++++++++++++++++- src/raw/document.rs | 27 ++++++- src/raw/mod.rs | 3 + src/ser/raw/mod.rs | 35 ++++++-- src/ser/raw/value_serializer.rs | 35 +++++++- src/ser/serde.rs | 28 ++++--- 8 files changed, 267 insertions(+), 29 deletions(-) diff --git a/src/de/raw.rs b/src/de/raw.rs index 6bdaf0bf..62f1005d 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -1060,7 +1060,7 @@ impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d, 'de> { .map(Some), BinaryDeserializationStage::Bytes => seed .deserialize(FieldDeserializer { - field_name: "base64", + field_name: "bytes", }) .map(Some), BinaryDeserializationStage::Done => Ok(None), diff --git a/src/raw/array.rs b/src/raw/array.rs index 329651e7..49852fdc 100644 --- a/src/raw/array.rs +++ b/src/raw/array.rs @@ -1,6 +1,6 @@ use std::convert::TryFrom; -use serde::Deserialize; +use serde::{ser::SerializeSeq, Deserialize, Serialize}; use super::{ error::{ValueAccessError, ValueAccessErrorKind, ValueAccessResult}, @@ -12,7 +12,7 @@ use super::{ RawRegex, Result, }; -use crate::{oid::ObjectId, spec::ElementType, Bson, DateTime, Timestamp}; +use crate::{oid::ObjectId, raw::RAW_ARRAY_NEWTYPE, spec::ElementType, Bson, DateTime, Timestamp}; /// A slice of a BSON document containing a BSON array value (akin to [`std::str`]). This can be /// retrieved from a [`RawDocument`] via [`RawDocument::get`]. 
@@ -257,3 +257,28 @@ impl<'de: 'a, 'a> Deserialize<'de> for &'a RawArray { } } } + +impl<'a> Serialize for &'a RawArray { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + struct SeqSerializer<'a>(&'a RawArray); + + impl<'a> Serialize for SeqSerializer<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_seq(None)?; + for v in self.0 { + let v = v.map_err(serde::ser::Error::custom)?; + seq.serialize_element(&v)?; + } + seq.end() + } + } + + serializer.serialize_newtype_struct(RAW_ARRAY_NEWTYPE, &SeqSerializer(self)) + } +} diff --git a/src/raw/bson.rs b/src/raw/bson.rs index 5ac9d980..63928be6 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -2,14 +2,17 @@ use std::convert::{TryFrom, TryInto}; use serde::{ de::{MapAccess, Unexpected, Visitor}, + ser::SerializeStruct, Deserialize, + Serialize, }; +use serde_bytes::Bytes; use super::{Error, RawArray, RawDocument, Result}; use crate::{ extjson, oid::{self, ObjectId}, - raw::{RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, + raw::{RAW_ARRAY_NEWTYPE, RAW_BINARY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, spec::{BinarySubtype, ElementType}, Bson, DateTime, @@ -415,7 +418,7 @@ impl<'de: 'a, 'a> Deserialize<'de> for RawBson<'a> { "$binary" => { #[derive(Debug, Deserialize)] struct BorrowedBinaryBody<'a> { - base64: &'a [u8], + bytes: &'a [u8], #[serde(rename = "subType")] subtype: u8, @@ -424,7 +427,7 @@ impl<'de: 'a, 'a> Deserialize<'de> for RawBson<'a> { let v = map.next_value::()?; Ok(RawBson::Binary(RawBinary { - bytes: v.base64, + bytes: v.bytes, subtype: v.subtype.into(), })) } @@ -500,6 +503,61 @@ impl<'de: 'a, 'a> Deserialize<'de> for RawBson<'a> { } } +impl<'a> Serialize for RawBson<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + match self { + RawBson::Double(v) => serializer.serialize_f64(*v), + 
RawBson::String(v) => serializer.serialize_str(v), + RawBson::Array(v) => v.serialize(serializer), + RawBson::Document(v) => v.serialize(serializer), + RawBson::Boolean(v) => serializer.serialize_bool(*v), + RawBson::Null => serializer.serialize_unit(), + RawBson::Int32(v) => serializer.serialize_i32(*v), + RawBson::Int64(v) => serializer.serialize_i64(*v), + RawBson::ObjectId(oid) => oid.serialize(serializer), + RawBson::DateTime(dt) => dt.serialize(serializer), + RawBson::Binary(b) => b.serialize(serializer), + RawBson::JavaScriptCode(c) => { + let mut state = serializer.serialize_struct("$code", 1)?; + state.serialize_field("$code", c)?; + state.end() + } + RawBson::JavaScriptCodeWithScope(code_w_scope) => code_w_scope.serialize(serializer), + RawBson::DbPointer(dbp) => dbp.serialize(serializer), + RawBson::Symbol(s) => { + let mut state = serializer.serialize_struct("$symbol", 1)?; + state.serialize_field("$symbol", s)?; + state.end() + } + RawBson::RegularExpression(re) => re.serialize(serializer), + RawBson::Timestamp(t) => t.serialize(serializer), + RawBson::Decimal128(d) => { + let mut state = serializer.serialize_struct("$numberDecimal", 1)?; + state.serialize_field("$numberDecimalBytes", Bytes::new(&d.bytes))?; + state.end() + } + RawBson::Undefined => { + let mut state = serializer.serialize_struct("$undefined", 1)?; + state.serialize_field("$undefined", &true)?; + state.end() + } + RawBson::MaxKey => { + let mut state = serializer.serialize_struct("$maxKey", 1)?; + state.serialize_field("$maxKey", &1)?; + state.end() + } + RawBson::MinKey => { + let mut state = serializer.serialize_struct("$minKey", 1)?; + state.serialize_field("$minKey", &1)?; + state.end() + } + } + } +} + impl<'a> TryFrom> for Bson { type Error = Error; @@ -594,6 +652,52 @@ impl<'de: 'a, 'a> Deserialize<'de> for RawBinary<'a> { } } +impl<'a> Serialize for RawBinary<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + struct 
BinarySerializer<'a>(RawBinary<'a>); + + impl<'a> Serialize for BinarySerializer<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + if let BinarySubtype::Generic = self.0.subtype { + serializer.serialize_bytes(self.0.bytes) + } else if !serializer.is_human_readable() { + #[derive(Serialize)] + struct BorrowedBinary<'a> { + bytes: &'a Bytes, + + #[serde(rename = "subType")] + subtype: u8, + } + + let mut state = serializer.serialize_struct("$binary", 1)?; + let body = BorrowedBinary { + bytes: Bytes::new(self.0.bytes), + subtype: self.0.subtype.into(), + }; + state.serialize_field("$binary", &body)?; + state.end() + } else { + let mut state = serializer.serialize_struct("$binary", 1)?; + let body = extjson::models::BinaryBody { + base64: base64::encode(self.0.bytes), + subtype: hex::encode([self.0.subtype.into()]), + }; + state.serialize_field("$binary", &body)?; + state.end() + } + } + } + + serializer.serialize_newtype_struct(RAW_BINARY_NEWTYPE, &BinarySerializer(*self)) + } +} + /// A BSON regex referencing raw bytes stored elsewhere. #[derive(Clone, Copy, Debug, PartialEq)] pub struct RawRegex<'a> { @@ -628,6 +732,15 @@ impl<'de: 'a, 'a> Deserialize<'de> for RawRegex<'a> { } } +impl<'a> Serialize for RawRegex<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + todo!() + } +} + /// A BSON "code with scope" value referencing raw bytes stored elsewhere. #[derive(Clone, Copy, Debug, PartialEq)] pub struct RawJavaScriptCodeWithScope<'a> { @@ -663,6 +776,15 @@ impl<'de: 'a, 'a> Deserialize<'de> for RawJavaScriptCodeWithScope<'a> { } } +impl<'a> Serialize for RawJavaScriptCodeWithScope<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + todo!() + } +} + /// A BSON DB pointer value referencing raw bytes stored elesewhere. 
#[derive(Debug, Clone, Copy, PartialEq)] pub struct RawDbPointer<'a> { @@ -684,3 +806,12 @@ impl<'de: 'a, 'a> Deserialize<'de> for RawDbPointer<'a> { } } } + +impl<'a> Serialize for RawDbPointer<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + todo!() + } +} diff --git a/src/raw/document.rs b/src/raw/document.rs index 69af5d41..3aa9057f 100644 --- a/src/raw/document.rs +++ b/src/raw/document.rs @@ -3,7 +3,7 @@ use std::{ convert::{TryFrom, TryInto}, }; -use serde::Deserialize; +use serde::{ser::SerializeMap, Deserialize, Serialize}; use crate::{ raw::{error::ErrorKind, RAW_DOCUMENT_NEWTYPE}, @@ -503,6 +503,31 @@ impl<'de: 'a, 'a> Deserialize<'de> for &'a RawDocument { } } +impl<'a> Serialize for &'a RawDocument { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + struct KvpSerializer<'a>(&'a RawDocument); + + impl<'a> Serialize for KvpSerializer<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let mut map = serializer.serialize_map(None)?; + for kvp in self.0 { + let (k, v) = kvp.map_err(serde::ser::Error::custom)?; + map.serialize_entry(k, &v)?; + } + map.end() + } + } + + serializer.serialize_newtype_struct(RAW_DOCUMENT_NEWTYPE, &KvpSerializer(self)) + } +} + impl std::fmt::Debug for RawDocument { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RawDocument") diff --git a/src/raw/mod.rs b/src/raw/mod.rs index 48359110..4faa046a 100644 --- a/src/raw/mod.rs +++ b/src/raw/mod.rs @@ -143,6 +143,9 @@ pub(crate) const RAW_ARRAY_NEWTYPE: &str = "$__private__bson_RawArray"; /// Special newtype name indicating that the type being (de)serialized is a raw BSON value. pub(crate) const RAW_BSON_NEWTYPE: &str = "$__private__bson_RawBson"; +/// Special newtype name indicating that the type being (de)serialized is a raw BSON value. 
+pub(crate) const RAW_BINARY_NEWTYPE: &str = "$__private__bson_RawBinary"; + /// Given a u8 slice, return an i32 calculated from the first four bytes in /// little endian order. fn f64_from_slice(val: &[u8]) -> Result { diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index eef34538..91327d84 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -10,6 +10,7 @@ use self::value_serializer::{ValueSerializer, ValueType}; use super::{write_binary, write_cstring, write_f64, write_i32, write_i64, write_string}; use crate::{ + raw::RAW_BINARY_NEWTYPE, ser::{Error, Result}, spec::{BinarySubtype, ElementType}, uuid::UUID_NEWTYPE_NAME, @@ -25,9 +26,23 @@ pub(crate) struct Serializer { /// but in serde, the serializer learns of the type after serializing the key. type_index: usize, - /// Whether the binary value about to be serialized is a UUID or not. - /// This is indicated by serializing a newtype with name UUID_NEWTYPE_NAME; - is_uuid: bool, + // /// Whether the binary value about to be serialized is a UUID or not. 
+ // /// This is indicated by serializing a newtype with name UUID_NEWTYPE_NAME; + // is_uuid: bool, + hint: SerializerHint, +} + +#[derive(Debug, Clone, Copy)] +enum SerializerHint { + None, + Uuid, + RawBinary, +} + +impl SerializerHint { + fn take(&mut self) -> SerializerHint { + std::mem::replace(self, SerializerHint::None) + } } impl Serializer { @@ -35,7 +50,7 @@ impl Serializer { Self { bytes: Vec::new(), type_index: 0, - is_uuid: false, + hint: SerializerHint::None, } } @@ -178,8 +193,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { fn serialize_bytes(self, v: &[u8]) -> Result { self.update_element_type(ElementType::Binary)?; - let subtype = if self.is_uuid { - self.is_uuid = false; + let subtype = if matches!(self.hint.take(), SerializerHint::Uuid) { BinarySubtype::Uuid } else { BinarySubtype::Generic @@ -233,7 +247,12 @@ impl<'a> serde::Serializer for &'a mut Serializer { T: serde::Serialize, { if name == UUID_NEWTYPE_NAME { - self.is_uuid = true; + self.hint = SerializerHint::Uuid; + } + match name { + UUID_NEWTYPE_NAME => self.hint = SerializerHint::Uuid, + RAW_BINARY_NEWTYPE => self.hint = SerializerHint::RawBinary, + _ => {} } value.serialize(self) } @@ -290,12 +309,14 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_map(self, _len: Option) -> Result { + println!("map"); self.update_element_type(ElementType::EmbeddedDocument)?; DocumentSerializer::start(&mut *self) } #[inline] fn serialize_struct(self, name: &'static str, _len: usize) -> Result { + println!("struct {}", name); let value_type = match name { "$oid" => Some(ValueType::ObjectId), "$date" => Some(ValueType::DateTime), diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index 9af76549..dab7e29d 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -33,6 +33,10 @@ enum SerializationStep { BinaryBase64, BinarySubType { base64: String }, + RawBinary, + RawBinaryBytes, + RawBinarySubType { bytes: Vec 
}, + Symbol, RegEx, @@ -106,7 +110,7 @@ impl<'a> ValueSerializer<'a> { pub(super) fn new(rs: &'a mut Serializer, value_type: ValueType) -> Self { let state = match value_type { ValueType::DateTime => SerializationStep::DateTime, - ValueType::Binary => SerializationStep::Binary, + ValueType::Binary => SerializationStep::RawBinary, ValueType::ObjectId => SerializationStep::Oid, ValueType::Symbol => SerializationStep::Symbol, ValueType::RegularExpression => SerializationStep::RegEx, @@ -185,8 +189,15 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { } #[inline] - fn serialize_u8(self, _v: u8) -> Result { - Err(self.invalid_step("u8")) + fn serialize_u8(self, v: u8) -> Result { + match self.state { + SerializationStep::RawBinarySubType { ref bytes } => { + write_binary(&mut self.root_serializer.bytes, bytes.as_slice(), v.into())?; + self.state = SerializationStep::Done; + Ok(()) + } + _ => Err(self.invalid_step("u8")), + } } #[inline] @@ -273,6 +284,10 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { self.root_serializer.bytes.write_all(v)?; Ok(()) } + SerializationStep::RawBinaryBytes => { + self.state = SerializationStep::RawBinarySubType { bytes: v.to_vec() }; + Ok(()) + } _ => Err(self.invalid_step("&[u8]")), } } @@ -338,7 +353,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { #[inline] fn serialize_seq(self, _len: Option) -> Result { - Err(self.invalid_step("newtype_seq")) + Err(self.invalid_step("seq")) } #[inline] @@ -414,6 +429,18 @@ impl<'a, 'b> SerializeStruct for &'b mut ValueSerializer<'a> { value.serialize(&mut **self)?; self.state = SerializationStep::Done; } + (SerializationStep::RawBinary, "$binary") => { + self.state = SerializationStep::RawBinaryBytes; + value.serialize(&mut **self)?; + } + (SerializationStep::RawBinaryBytes, "bytes") => { + // state is updated in serialize + value.serialize(&mut **self)?; + } + (SerializationStep::RawBinarySubType { .. 
}, "subType") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } (SerializationStep::Binary, "$binary") => { self.state = SerializationStep::BinaryBase64; value.serialize(&mut **self)?; diff --git a/src/ser/serde.rs b/src/ser/serde.rs index f5ff7797..59ac7a34 100644 --- a/src/ser/serde.rs +++ b/src/ser/serde.rs @@ -17,6 +17,7 @@ use crate::{ datetime::DateTime, extjson, oid::ObjectId, + raw::RawBinary, spec::BinarySubtype, uuid::UUID_NEWTYPE_NAME, Binary, @@ -575,17 +576,22 @@ impl Serialize for Binary { where S: ser::Serializer, { - if let BinarySubtype::Generic = self.subtype { - serializer.serialize_bytes(self.bytes.as_slice()) - } else { - let mut state = serializer.serialize_struct("$binary", 1)?; - let body = extjson::models::BinaryBody { - base64: base64::encode(self.bytes.as_slice()), - subtype: hex::encode([self.subtype.into()]), - }; - state.serialize_field("$binary", &body)?; - state.end() - } + // if let BinarySubtype::Generic = self.subtype { + // serializer.serialize_bytes(self.bytes.as_slice()) + // } else { + // let mut state = serializer.serialize_struct("$binary", 1)?; + // let body = extjson::models::BinaryBody { + // base64: base64::encode(self.bytes.as_slice()), + // subtype: hex::encode([self.subtype.into()]), + // }; + // state.serialize_field("$binary", &body)?; + // state.end() + // } + let raw_binary = RawBinary { + bytes: self.bytes.as_slice(), + subtype: self.subtype, + }; + raw_binary.serialize(serializer) } } From f1a56509a04c8bfd1bdaf696a86f0880d0f2b9bc Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 28 Oct 2021 14:08:45 -0400 Subject: [PATCH 04/21] finish RawBinary serialization --- src/raw/bson.rs | 61 +++++++++++++++++++--------------------------- src/ser/raw/mod.rs | 7 +----- 2 files changed, 26 insertions(+), 42 deletions(-) diff --git a/src/raw/bson.rs b/src/raw/bson.rs index 63928be6..51c206e9 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -657,44 +657,33 @@ impl<'a> Serialize for 
RawBinary<'a> { where S: serde::Serializer, { - struct BinarySerializer<'a>(RawBinary<'a>); - - impl<'a> Serialize for BinarySerializer<'a> { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - if let BinarySubtype::Generic = self.0.subtype { - serializer.serialize_bytes(self.0.bytes) - } else if !serializer.is_human_readable() { - #[derive(Serialize)] - struct BorrowedBinary<'a> { - bytes: &'a Bytes, - - #[serde(rename = "subType")] - subtype: u8, - } - - let mut state = serializer.serialize_struct("$binary", 1)?; - let body = BorrowedBinary { - bytes: Bytes::new(self.0.bytes), - subtype: self.0.subtype.into(), - }; - state.serialize_field("$binary", &body)?; - state.end() - } else { - let mut state = serializer.serialize_struct("$binary", 1)?; - let body = extjson::models::BinaryBody { - base64: base64::encode(self.0.bytes), - subtype: hex::encode([self.0.subtype.into()]), - }; - state.serialize_field("$binary", &body)?; - state.end() - } + if let BinarySubtype::Generic = self.subtype { + serializer.serialize_bytes(self.bytes) + } else if !serializer.is_human_readable() { + #[derive(Serialize)] + struct BorrowedBinary<'a> { + bytes: &'a Bytes, + + #[serde(rename = "subType")] + subtype: u8, } - } - serializer.serialize_newtype_struct(RAW_BINARY_NEWTYPE, &BinarySerializer(*self)) + let mut state = serializer.serialize_struct("$binary", 1)?; + let body = BorrowedBinary { + bytes: Bytes::new(self.bytes), + subtype: self.subtype.into(), + }; + state.serialize_field("$binary", &body)?; + state.end() + } else { + let mut state = serializer.serialize_struct("$binary", 1)?; + let body = extjson::models::BinaryBody { + base64: base64::encode(self.bytes), + subtype: hex::encode([self.subtype.into()]), + }; + state.serialize_field("$binary", &body)?; + state.end() + } } } diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index 91327d84..e24910a4 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -246,12 +246,8 @@ 
impl<'a> serde::Serializer for &'a mut Serializer { where T: serde::Serialize, { - if name == UUID_NEWTYPE_NAME { - self.hint = SerializerHint::Uuid; - } match name { UUID_NEWTYPE_NAME => self.hint = SerializerHint::Uuid, - RAW_BINARY_NEWTYPE => self.hint = SerializerHint::RawBinary, _ => {} } value.serialize(self) @@ -309,14 +305,13 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_map(self, _len: Option) -> Result { - println!("map"); self.update_element_type(ElementType::EmbeddedDocument)?; DocumentSerializer::start(&mut *self) } #[inline] fn serialize_struct(self, name: &'static str, _len: usize) -> Result { - println!("struct {}", name); + // println!("struct {}", name); let value_type = match name { "$oid" => Some(ValueType::ObjectId), "$date" => Some(ValueType::DateTime), From 766f370dd1a27076574b53b7e2fde98d6dd20d28 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 28 Oct 2021 14:17:42 -0400 Subject: [PATCH 05/21] RawRegex serialization --- src/raw/bson.rs | 18 +++++++++++++++--- src/ser/serde.rs | 12 +++++------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/raw/bson.rs b/src/raw/bson.rs index 51c206e9..b9f9acdc 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -690,8 +690,8 @@ impl<'a> Serialize for RawBinary<'a> { /// A BSON regex referencing raw bytes stored elsewhere. 
#[derive(Clone, Copy, Debug, PartialEq)] pub struct RawRegex<'a> { - pub(super) pattern: &'a str, - pub(super) options: &'a str, + pub(crate) pattern: &'a str, + pub(crate) options: &'a str, } impl<'a> RawRegex<'a> { @@ -726,7 +726,19 @@ impl<'a> Serialize for RawRegex<'a> { where S: serde::Serializer, { - todo!() + #[derive(Serialize)] + struct BorrowedRegexBody<'a> { + pattern: &'a str, + options: &'a str, + } + + let mut state = serializer.serialize_struct("$regularExpression", 1)?; + let body = BorrowedRegexBody { + pattern: self.pattern, + options: self.options, + }; + state.serialize_field("$regularExpression", &body)?; + state.end() } } diff --git a/src/ser/serde.rs b/src/ser/serde.rs index 59ac7a34..593a6b49 100644 --- a/src/ser/serde.rs +++ b/src/ser/serde.rs @@ -17,7 +17,7 @@ use crate::{ datetime::DateTime, extjson, oid::ObjectId, - raw::RawBinary, + raw::{RawBinary, RawRegex}, spec::BinarySubtype, uuid::UUID_NEWTYPE_NAME, Binary, @@ -547,13 +547,11 @@ impl Serialize for Regex { where S: ser::Serializer, { - let mut state = serializer.serialize_struct("$regularExpression", 1)?; - let body = extjson::models::RegexBody { - pattern: self.pattern.clone(), - options: self.options.clone(), + let raw = RawRegex { + pattern: self.pattern.as_str(), + options: self.options.as_str(), }; - state.serialize_field("$regularExpression", &body)?; - state.end() + raw.serialize(serializer) } } From f0ce8837f86490f994a88e9a644b95bdfd6bf65e Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 28 Oct 2021 14:41:16 -0400 Subject: [PATCH 06/21] RawJavaScriptCodeWithScope serialization --- src/raw/bson.rs | 5 ++++- src/raw/document.rs | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/raw/bson.rs b/src/raw/bson.rs index b9f9acdc..f2f58110 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -782,7 +782,10 @@ impl<'a> Serialize for RawJavaScriptCodeWithScope<'a> { where S: serde::Serializer, { - todo!() + let mut state = 
serializer.serialize_struct("$codeWithScope", 2)?; + state.serialize_field("$code", &self.code)?; + state.serialize_field("$scope", &self.scope)?; + state.end() } } diff --git a/src/raw/document.rs b/src/raw/document.rs index 3aa9057f..89359a65 100644 --- a/src/raw/document.rs +++ b/src/raw/document.rs @@ -508,23 +508,23 @@ impl<'a> Serialize for &'a RawDocument { where S: serde::Serializer, { - struct KvpSerializer<'a>(&'a RawDocument); + // struct KvpSerializer<'a>(&'a RawDocument); - impl<'a> Serialize for KvpSerializer<'a> { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { + // impl<'a> Serialize for KvpSerializer<'a> { + // fn serialize(&self, serializer: S) -> std::result::Result + // where + // S: serde::Serializer, + // { let mut map = serializer.serialize_map(None)?; - for kvp in self.0 { + for kvp in *self { let (k, v) = kvp.map_err(serde::ser::Error::custom)?; map.serialize_entry(k, &v)?; } map.end() - } - } + // } + // } - serializer.serialize_newtype_struct(RAW_DOCUMENT_NEWTYPE, &KvpSerializer(self)) + // serializer.serialize_newtype_struct(RAW_DOCUMENT_NEWTYPE, &KvpSerializer(self)) } } From abba24433f19f6194da22decd86bdff9ad7829e2 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 28 Oct 2021 14:47:48 -0400 Subject: [PATCH 07/21] RawDbPointer serialize --- src/raw/bson.rs | 17 ++++++++++++++++- src/raw/document.rs | 14 +++++++------- src/ser/serde.rs | 12 +++++------- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/src/raw/bson.rs b/src/raw/bson.rs index f2f58110..f1841f15 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -816,6 +816,21 @@ impl<'a> Serialize for RawDbPointer<'a> { where S: serde::Serializer, { - todo!() + #[derive(Serialize)] + struct BorrowedDbPointerBody<'a> { + #[serde(rename = "$ref")] + ref_ns: &'a str, + + #[serde(rename = "$id")] + id: ObjectId, + } + + let mut state = serializer.serialize_struct("$dbPointer", 1)?; + let body = BorrowedDbPointerBody { + 
ref_ns: self.namespace, + id: self.id, + }; + state.serialize_field("$dbPointer", &body)?; + state.end() } } diff --git a/src/raw/document.rs b/src/raw/document.rs index 89359a65..67d80dea 100644 --- a/src/raw/document.rs +++ b/src/raw/document.rs @@ -515,13 +515,13 @@ impl<'a> Serialize for &'a RawDocument { // where // S: serde::Serializer, // { - let mut map = serializer.serialize_map(None)?; - for kvp in *self { - let (k, v) = kvp.map_err(serde::ser::Error::custom)?; - map.serialize_entry(k, &v)?; - } - map.end() - // } + let mut map = serializer.serialize_map(None)?; + for kvp in *self { + let (k, v) = kvp.map_err(serde::ser::Error::custom)?; + map.serialize_entry(k, &v)?; + } + map.end() + // } // } // serializer.serialize_newtype_struct(RAW_DOCUMENT_NEWTYPE, &KvpSerializer(self)) diff --git a/src/ser/serde.rs b/src/ser/serde.rs index 593a6b49..b68add82 100644 --- a/src/ser/serde.rs +++ b/src/ser/serde.rs @@ -17,7 +17,7 @@ use crate::{ datetime::DateTime, extjson, oid::ObjectId, - raw::{RawBinary, RawRegex}, + raw::{RawBinary, RawDbPointer, RawRegex}, spec::BinarySubtype, uuid::UUID_NEWTYPE_NAME, Binary, @@ -624,12 +624,10 @@ impl Serialize for DbPointer { where S: ser::Serializer, { - let mut state = serializer.serialize_struct("$dbPointer", 1)?; - let body = extjson::models::DbPointerBody { - ref_ns: self.namespace.clone(), - id: self.id.into(), + let raw = RawDbPointer { + namespace: self.namespace.as_str(), + id: self.id, }; - state.serialize_field("$dbPointer", &body)?; - state.end() + raw.serialize(serializer) } } From da88ab2ad3ccaaa4bf978501fb2b76c65a727754 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 28 Oct 2021 15:12:39 -0400 Subject: [PATCH 08/21] RawDocument serialize --- src/raw/document.rs | 35 +++++++++++++++-------------- src/raw/document_buf.rs | 12 +++++++++- src/ser/raw/mod.rs | 27 ++++++++++++++++------- src/ser/raw/value_serializer.rs | 39 +++++++++++++++++++++++++-------- 4 files changed, 79 insertions(+), 34 deletions(-) diff 
--git a/src/raw/document.rs b/src/raw/document.rs index 67d80dea..d803b17f 100644 --- a/src/raw/document.rs +++ b/src/raw/document.rs @@ -508,23 +508,26 @@ impl<'a> Serialize for &'a RawDocument { where S: serde::Serializer, { - // struct KvpSerializer<'a>(&'a RawDocument); - - // impl<'a> Serialize for KvpSerializer<'a> { - // fn serialize(&self, serializer: S) -> std::result::Result - // where - // S: serde::Serializer, - // { - let mut map = serializer.serialize_map(None)?; - for kvp in *self { - let (k, v) = kvp.map_err(serde::ser::Error::custom)?; - map.serialize_entry(k, &v)?; + struct KvpSerializer<'a>(&'a RawDocument); + + impl<'a> Serialize for KvpSerializer<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + if serializer.is_human_readable() { + let mut map = serializer.serialize_map(None)?; + for kvp in self.0 { + let (k, v) = kvp.map_err(serde::ser::Error::custom)?; + map.serialize_entry(k, &v)?; + } + map.end() + } else { + serializer.serialize_bytes(self.0.as_bytes()) + } + } } - map.end() - // } - // } - - // serializer.serialize_newtype_struct(RAW_DOCUMENT_NEWTYPE, &KvpSerializer(self)) + serializer.serialize_newtype_struct(RAW_DOCUMENT_NEWTYPE, &KvpSerializer(self)) } } diff --git a/src/raw/document_buf.rs b/src/raw/document_buf.rs index 15cdba6e..a4bd12b6 100644 --- a/src/raw/document_buf.rs +++ b/src/raw/document_buf.rs @@ -4,7 +4,7 @@ use std::{ ops::Deref, }; -use serde::{de::Visitor, Deserialize, Deserializer}; +use serde::{de::Visitor, Deserialize, Deserializer, Serialize}; use crate::{ raw::{RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, @@ -155,6 +155,16 @@ impl<'de> Deserialize<'de> for RawDocumentBuf { } } +impl Serialize for RawDocumentBuf { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let doc: &RawDocument = &self; + doc.serialize(serializer) + } +} + impl std::fmt::Debug for RawDocumentBuf { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) 
-> std::fmt::Result { f.debug_struct("RawDocumentBuf") diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index e24910a4..9584f206 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -1,6 +1,8 @@ mod document_serializer; mod value_serializer; +use std::io::Write; + use serde::{ ser::{Error as SerdeError, SerializeMap, SerializeStruct}, Serialize, @@ -36,7 +38,7 @@ pub(crate) struct Serializer { enum SerializerHint { None, Uuid, - RawBinary, + RawDocument, } impl SerializerHint { @@ -191,15 +193,23 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_bytes(self, v: &[u8]) -> Result { - self.update_element_type(ElementType::Binary)?; + match self.hint { + SerializerHint::RawDocument => { + self.update_element_type(ElementType::EmbeddedDocument)?; + self.bytes.write_all(v)?; + } + _ => { + self.update_element_type(ElementType::Binary)?; - let subtype = if matches!(self.hint.take(), SerializerHint::Uuid) { - BinarySubtype::Uuid - } else { - BinarySubtype::Generic - }; + let subtype = if matches!(self.hint.take(), SerializerHint::Uuid) { + BinarySubtype::Uuid + } else { + BinarySubtype::Generic + }; - write_binary(&mut self.bytes, v, subtype)?; + write_binary(&mut self.bytes, v, subtype)?; + } + }; Ok(()) } @@ -248,6 +258,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { { match name { UUID_NEWTYPE_NAME => self.hint = SerializerHint::Uuid, + RAW_DOCUMENT_NEWTYPE => self.hint = SerializerHint::RawDocument, _ => {} } value.serialize(self) diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index dab7e29d..4d8a286a 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -54,7 +54,7 @@ enum SerializationStep { Code, CodeWithScopeCode, - CodeWithScopeScope { code: String }, + CodeWithScopeScope { code: String, raw: bool }, MinKey, @@ -265,6 +265,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { SerializationStep::CodeWithScopeCode => { self.state = 
SerializationStep::CodeWithScopeScope { code: v.to_string(), + raw: false, }; } s => { @@ -288,6 +289,14 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { self.state = SerializationStep::RawBinarySubType { bytes: v.to_vec() }; Ok(()) } + SerializationStep::CodeWithScopeScope { ref code, raw } if raw => { + let len = 4 + 4 + code.len() as i32 + 1 + v.len() as i32; + write_i32(&mut self.root_serializer.bytes, len)?; + write_string(&mut self.root_serializer.bytes, code)?; + self.root_serializer.bytes.write_all(v)?; + self.state = SerializationStep::Done; + Ok(()) + } _ => Err(self.invalid_step("&[u8]")), } } @@ -326,15 +335,23 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { } #[inline] - fn serialize_newtype_struct( - self, - _name: &'static str, - _value: &T, - ) -> Result + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result where T: Serialize, { - Err(self.invalid_step("newtype_struct")) + match (&mut self.state, name) { + ( + SerializationStep::CodeWithScopeScope { + ref code, + ref mut raw, + }, + RAW_DOCUMENT_NEWTYPE, + ) => { + *raw = true; + value.serialize(self) + } + _ => Err(self.invalid_step("newtype_struct")), + } } #[inline] @@ -384,10 +401,10 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { #[inline] fn serialize_map(self, _len: Option) -> Result { match self.state { - SerializationStep::CodeWithScopeScope { ref code } => { + SerializationStep::CodeWithScopeScope { ref code, raw } if !raw => { CodeWithScopeSerializer::start(code.as_str(), self.root_serializer) } - _ => Err(self.invalid_step("tuple_map")), + _ => Err(self.invalid_step("map")), } } @@ -406,6 +423,10 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { ) -> Result { Err(self.invalid_step("struct_variant")) } + + fn is_human_readable(&self) -> bool { + false + } } impl<'a, 'b> SerializeStruct for &'b mut ValueSerializer<'a> { From d42f46757bcebcbf23364750a33420e1ce17c4c9 Mon Sep 17 00:00:00 2001 
From: Patrick Freed Date: Thu, 28 Oct 2021 15:59:18 -0400 Subject: [PATCH 09/21] RawBson serialize --- src/raw/bson.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/raw/bson.rs b/src/raw/bson.rs index f1841f15..711f3b03 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -534,11 +534,7 @@ impl<'a> Serialize for RawBson<'a> { } RawBson::RegularExpression(re) => re.serialize(serializer), RawBson::Timestamp(t) => t.serialize(serializer), - RawBson::Decimal128(d) => { - let mut state = serializer.serialize_struct("$numberDecimal", 1)?; - state.serialize_field("$numberDecimalBytes", Bytes::new(&d.bytes))?; - state.end() - } + RawBson::Decimal128(d) => d.serialize(serializer), RawBson::Undefined => { let mut state = serializer.serialize_struct("$undefined", 1)?; state.serialize_field("$undefined", &true)?; From 58a6d31ca1ab4779424d0002117762b197016784 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 28 Oct 2021 16:05:43 -0400 Subject: [PATCH 10/21] RawArray serialization --- src/raw/array.rs | 14 +++++++++----- src/ser/raw/mod.rs | 10 +++++++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/raw/array.rs b/src/raw/array.rs index 49852fdc..44fa903d 100644 --- a/src/raw/array.rs +++ b/src/raw/array.rs @@ -270,12 +270,16 @@ impl<'a> Serialize for &'a RawArray { where S: serde::Serializer, { - let mut seq = serializer.serialize_seq(None)?; - for v in self.0 { - let v = v.map_err(serde::ser::Error::custom)?; - seq.serialize_element(&v)?; + if serializer.is_human_readable() { + let mut seq = serializer.serialize_seq(None)?; + for v in self.0 { + let v = v.map_err(serde::ser::Error::custom)?; + seq.serialize_element(&v)?; + } + seq.end() + } else { + serializer.serialize_bytes(self.0.as_bytes()) } - seq.end() } } diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index 9584f206..5c559721 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -12,7 +12,7 @@ use self::value_serializer::{ValueSerializer, ValueType}; 
use super::{write_binary, write_cstring, write_f64, write_i32, write_i64, write_string}; use crate::{ - raw::RAW_BINARY_NEWTYPE, + raw::{RAW_ARRAY_NEWTYPE, RAW_BINARY_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, ser::{Error, Result}, spec::{BinarySubtype, ElementType}, uuid::UUID_NEWTYPE_NAME, @@ -37,8 +37,15 @@ pub(crate) struct Serializer { #[derive(Debug, Clone, Copy)] enum SerializerHint { None, + + /// The next call to `serialize_bytes` is for the purposes of serializing a UUID. Uuid, + + /// The next call to `serialize_bytes` is for the purposes of serializing a raw document. RawDocument, + + /// The next call to `serialize_bytes` is for the purposes of serializing a raw array. + RawArray, } impl SerializerHint { @@ -259,6 +266,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { match name { UUID_NEWTYPE_NAME => self.hint = SerializerHint::Uuid, RAW_DOCUMENT_NEWTYPE => self.hint = SerializerHint::RawDocument, + RAW_ARRAY_NEWTYPE => self.hint = SerializerHint::RawArray, _ => {} } value.serialize(self) From 3c969659742ad06e24002295b894d95b54eae584 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Thu, 28 Oct 2021 17:45:02 -0400 Subject: [PATCH 11/21] add corpus tests, fix discovered bugs --- src/de/raw.rs | 122 +++++++++++++++++++++++++++----- src/de/serde.rs | 5 +- src/raw/test/props.rs | 6 +- src/ser/raw/value_serializer.rs | 9 ++- src/tests/spec/corpus.rs | 39 +++++++++- 5 files changed, 149 insertions(+), 32 deletions(-) diff --git a/src/de/raw.rs b/src/de/raw.rs index 62f1005d..9a6afdd9 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -1,5 +1,6 @@ use std::{ borrow::Cow, + convert::TryInto, io::{ErrorKind, Read}, sync::Arc, }; @@ -37,6 +38,7 @@ use super::{ Error, Result, MAX_BSON_SIZE, + MIN_CODE_WITH_SCOPE_SIZE, }; use crate::de::serde::MapDeserializer; @@ -129,6 +131,18 @@ impl<'de> Deserializer<'de> { self.bytes.read_cstr() } + /// Read an ObjectId from the underlying BSON. + /// + /// If hinted to use raw BSON, the bytes of the ObjectId will be visited.
+ /// Otherwise, a map in the shape of the extended JSON format of an ObjectId will be. + fn deserialize_objectid(&mut self, visitor: V, hint: DeserializerHint) -> Result + where + V: serde::de::Visitor<'de>, + { + let oid = ObjectId::from_reader(&mut self.bytes)?; + visitor.visit_map(ObjectIdAccess::new(oid, hint)) + } + /// Read a document from the underling BSON, whether it's an array or an actual document. /// /// If hinted to use raw BSON, the bytes themselves will be visited using a special newtype @@ -227,10 +241,7 @@ impl<'de> Deserializer<'de> { }, ElementType::Boolean => visitor.visit_bool(read_bool(&mut self.bytes)?), ElementType::Null => visitor.visit_unit(), - ElementType::ObjectId => { - let oid = ObjectId::from_reader(&mut self.bytes)?; - visitor.visit_map(ObjectIdAccess::new(oid, hint)) - } + ElementType::ObjectId => self.deserialize_objectid(visitor, hint), ElementType::EmbeddedDocument => { self.deserialize_document(visitor, hint, DocumentType::EmbeddedDocument) } @@ -298,7 +309,7 @@ impl<'de> Deserializer<'de> { visitor.visit_map(RegexAccess::new(&mut de)) } ElementType::DbPointer => { - let mut de = DbPointerDeserializer::new(&mut *self); + let mut de = DbPointerDeserializer::new(&mut *self, hint); visitor.visit_map(DbPointerAccess::new(&mut de)) } ElementType::JavaScriptCode => { @@ -317,9 +328,43 @@ impl<'de> Deserializer<'de> { } } ElementType::JavaScriptCodeWithScope => { - let _len = read_i32(&mut self.bytes)?; - let mut de = CodeWithScopeDeserializer::new(&mut *self, hint); - visitor.visit_map(CodeWithScopeAccess::new(&mut de)) + let len = read_i32(&mut self.bytes)?; + + if len < MIN_CODE_WITH_SCOPE_SIZE { + return Err(SerdeError::invalid_length( + len.try_into().unwrap_or(0), + &format!( + "CodeWithScope to be at least {} bytes", + MIN_CODE_WITH_SCOPE_SIZE + ) + .as_str(), + )); + } else if (self.bytes.bytes_remaining() as i32) < len - 4 { + return Err(SerdeError::invalid_length( + len.try_into().unwrap_or(0), + &format!( + "CodeWithScope 
to be at most {} bytes", + self.bytes.bytes_remaining() + ) + .as_str(), + )); + } + + let mut de = CodeWithScopeDeserializer::new(&mut *self, hint, len - 4); + let out = visitor.visit_map(CodeWithScopeAccess::new(&mut de)); + + if de.length_remaining != 0 { + return Err(SerdeError::invalid_length( + len.try_into().unwrap_or(0), + &format!( + "CodeWithScope length {} bytes greater than actual length", + de.length_remaining + ) + .as_str(), + )); + } + + out } ElementType::Symbol => { let utf8_lossy = self.bytes.utf8_lossy; @@ -1032,7 +1077,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut DateTimeDeserializer { /// A `MapAccess` providing access to a BSON binary being deserialized. /// /// If hinted to be raw BSON, this deserializes the serde data model equivalent -/// of { "$binary": { "subType": , "base64": } }. +/// of { "$binary": { "subType": , "bytes": } }. /// /// Otherwise, this deserializes the serde data model equivalent of /// { "$binary": { "subType": , "base64": } }. @@ -1058,9 +1103,17 @@ impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d, 'de> { field_name: "subType", }) .map(Some), + BinaryDeserializationStage::Bytes + if matches!(self.deserializer.binary, BinaryContent::Borrowed(_)) => + { + seed.deserialize(FieldDeserializer { + field_name: "bytes", + }) + .map(Some) + } BinaryDeserializationStage::Bytes => seed .deserialize(FieldDeserializer { - field_name: "bytes", + field_name: "base64", }) .map(Some), BinaryDeserializationStage::Done => Ok(None), @@ -1076,6 +1129,7 @@ impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d, 'de> { } /// Storage of possibly borrowed, possibly owned binary data. 
+#[derive(Debug)] enum BinaryContent<'a> { Borrowed(RawBinary<'a>), Owned(Binary), @@ -1207,16 +1261,37 @@ struct CodeWithScopeDeserializer<'de, 'a> { root_deserializer: &'a mut Deserializer<'de>, stage: CodeWithScopeDeserializationStage, hint: DeserializerHint, + length_remaining: i32, } impl<'de, 'a> CodeWithScopeDeserializer<'de, 'a> { - fn new(root_deserializer: &'a mut Deserializer<'de>, hint: DeserializerHint) -> Self { + fn new(root_deserializer: &'a mut Deserializer<'de>, hint: DeserializerHint, len: i32) -> Self { Self { root_deserializer, stage: CodeWithScopeDeserializationStage::Code, hint, + length_remaining: len, } } + + /// Executes a closure that reads from the BSON bytes and returns an error if the number of + /// bytes read exceeds length_remaining. + /// + /// A mutable reference to this `CodeWithScopeDeserializer` is passed into the closure. + fn read(&mut self, f: F) -> Result + where + F: FnOnce(&mut Self) -> Result, + { + let start_bytes = self.root_deserializer.bytes.bytes_read(); + let out = f(self); + let bytes_read = self.root_deserializer.bytes.bytes_read() - start_bytes; + self.length_remaining -= bytes_read as i32; + + if self.length_remaining < 0 { + return Err(Error::custom("length of CodeWithScope too short")); + } + out + } } impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut CodeWithScopeDeserializer<'de, 'a> { @@ -1229,18 +1304,20 @@ impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut CodeWithScopeDeserial match self.stage { CodeWithScopeDeserializationStage::Code => { self.stage = CodeWithScopeDeserializationStage::Scope; - match self.root_deserializer.deserialize_str()? { + match self.read(|s| s.root_deserializer.deserialize_str())? 
{ Cow::Borrowed(s) => visitor.visit_borrowed_str(s), Cow::Owned(s) => visitor.visit_string(s), } } CodeWithScopeDeserializationStage::Scope => { self.stage = CodeWithScopeDeserializationStage::Done; - self.root_deserializer.deserialize_document( - visitor, - self.hint, - DocumentType::EmbeddedDocument, - ) + self.read(|s| { + s.root_deserializer.deserialize_document( + visitor, + s.hint, + DocumentType::EmbeddedDocument, + ) + }) } CodeWithScopeDeserializationStage::Done => Err(Error::custom( "JavaScriptCodeWithScope fully deserialized already", @@ -1310,13 +1387,15 @@ impl<'de, 'd, 'a> serde::de::MapAccess<'de> for DbPointerAccess<'de, 'd, 'a> { struct DbPointerDeserializer<'de, 'a> { root_deserializer: &'a mut Deserializer<'de>, stage: DbPointerDeserializationStage, + hint: DeserializerHint, } impl<'de, 'a> DbPointerDeserializer<'de, 'a> { - fn new(root_deserializer: &'a mut Deserializer<'de>) -> Self { + fn new(root_deserializer: &'a mut Deserializer<'de>, hint: DeserializerHint) -> Self { Self { root_deserializer, stage: DbPointerDeserializationStage::TopLevel, + hint, } } } @@ -1342,7 +1421,8 @@ impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut DbPointerDeserializer } DbPointerDeserializationStage::Id => { self.stage = DbPointerDeserializationStage::Done; - visitor.visit_borrowed_bytes(self.root_deserializer.bytes.read_slice(12)?) + self.root_deserializer + .deserialize_objectid(visitor, self.hint) } DbPointerDeserializationStage::Done => { Err(Error::custom("DbPointer fully deserialized already")) @@ -1589,6 +1669,10 @@ impl<'a> BsonBuf<'a> { self.index } + fn bytes_remaining(&self) -> usize { + self.bytes.len() - self.bytes_read() + } + /// Verify the index has not run out of bounds. 
fn index_check(&self) -> std::io::Result<()> { if self.index >= self.bytes.len() { diff --git a/src/de/serde.rs b/src/de/serde.rs index 758a0bad..567b350a 100644 --- a/src/de/serde.rs +++ b/src/de/serde.rs @@ -369,10 +369,7 @@ impl<'de> Visitor<'de> for BsonVisitor { "$regularExpression" => { let re = visitor.next_value::()?; - return Ok(Bson::RegularExpression(Regex { - pattern: re.pattern, - options: re.options, - })); + return Ok(Bson::RegularExpression(Regex::new(re.pattern, re.options))); } "$dbPointer" => { diff --git a/src/raw/test/props.rs b/src/raw/test/props.rs index 850dcade..6f0157d2 100644 --- a/src/raw/test/props.rs +++ b/src/raw/test/props.rs @@ -22,11 +22,7 @@ pub(crate) fn arbitrary_bson() -> impl Strategy { any::().prop_map(Bson::Int32), any::().prop_map(Bson::Int64), any::<(String, String)>().prop_map(|(pattern, options)| { - let mut chars: Vec<_> = options.chars().collect(); - chars.sort_unstable(); - - let options: String = chars.into_iter().collect(); - Bson::RegularExpression(Regex { pattern, options }) + Bson::RegularExpression(Regex::new(pattern, options)) }), any::<[u8; 12]>().prop_map(|bytes| Bson::ObjectId(crate::oid::ObjectId::from_bytes(bytes))), (arbitrary_binary_subtype(), any::>()).prop_map(|(subtype, bytes)| { diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index 4d8a286a..99c7f6e7 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -256,9 +256,16 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { SerializationStep::Symbol | SerializationStep::DbPointerRef => { write_string(&mut self.root_serializer.bytes, v)?; } - SerializationStep::RegExPattern | SerializationStep::RegExOptions => { + SerializationStep::RegExPattern => { write_cstring(&mut self.root_serializer.bytes, v)?; } + SerializationStep::RegExOptions => { + let mut chars: Vec<_> = v.chars().collect(); + chars.sort_unstable(); + + let sorted = chars.into_iter().collect::(); + 
write_cstring(&mut self.root_serializer.bytes, sorted.as_str())?; + } SerializationStep::Code => { write_string(&mut self.root_serializer.bytes, v)?; } diff --git a/src/tests/spec/corpus.rs b/src/tests/spec/corpus.rs index 9c6844dd..f1906f73 100644 --- a/src/tests/spec/corpus.rs +++ b/src/tests/spec/corpus.rs @@ -3,7 +3,12 @@ use std::{ str::FromStr, }; -use crate::{raw::RawDocument, tests::LOCK, Bson, Document}; +use crate::{ + raw::{RawBson, RawDocument}, + tests::LOCK, + Bson, + Document, +}; use pretty_assertions::assert_eq; use serde::Deserialize; @@ -79,11 +84,19 @@ fn run_test(test: TestFile) { let todocument_documentfromreader_cb: Document = crate::to_document(&documentfromreader_cb).expect(&description); - let document_from_raw_document: Document = RawDocument::new(canonical_bson.as_slice()) + let canonical_raw_document = + RawDocument::new(canonical_bson.as_slice()).expect(&description); + let document_from_raw_document: Document = + canonical_raw_document.try_into().expect(&description); + + let canonical_raw_bson_from_slice = crate::from_slice::(canonical_bson.as_slice()) .expect(&description) - .try_into() + .as_document() .expect(&description); + let canonical_raw_document_from_slice = + crate::from_slice::<&RawDocument>(canonical_bson.as_slice()).expect(&description); + // These cover the ways to serialize those `Documents` back to BSON. let mut documenttowriter_documentfromreader_cb = Vec::new(); documentfromreader_cb @@ -113,6 +126,12 @@ fn run_test(test: TestFile) { .to_writer(&mut documenttowriter_document_from_raw_document) .expect(&description); + // Serialize the raw versions "back" to BSON also. 
+ let tovec_rawdocument = crate::to_vec(&canonical_raw_document).expect(&description); + let tovec_rawdocument_from_slice = + crate::to_vec(&canonical_raw_document_from_slice).expect(&description); + let tovec_rawbson = crate::to_vec(&canonical_raw_bson_from_slice).expect(&description); + // native_to_bson( bson_to_native(cB) ) = cB // now we ensure the hex for all 5 are equivalent to the canonical BSON provided by the @@ -159,6 +178,20 @@ fn run_test(test: TestFile) { description, ); + assert_eq!(tovec_rawdocument, tovec_rawbson, "{}", description); + assert_eq!( + tovec_rawdocument, tovec_rawdocument_from_slice, + "{}", + description + ); + + assert_eq!( + hex::encode(tovec_rawdocument).to_lowercase(), + valid.canonical_bson.to_lowercase(), + "{}", + description, + ); + // NaN == NaN is false, so we skip document comparisons that contain NaN if !description.to_ascii_lowercase().contains("nan") && !description.contains("decq541") { assert_eq!(documentfromreader_cb, fromreader_cb, "{}", description); From c3b6f7049b11475354da8745095adb44c8cb1831 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Fri, 29 Oct 2021 17:09:04 -0400 Subject: [PATCH 12/21] extend corpus tests, add serde-tests --- serde-tests/test.rs | 168 +++++++++++++++++++++++++++++++++++++++ src/de/mod.rs | 5 +- src/de/raw.rs | 48 +++++++++++ src/de/serde.rs | 91 ++++++++++++++++----- src/decimal128.rs | 9 ++- src/lib.rs | 14 ++-- src/raw/bson.rs | 9 ++- src/ser/raw/mod.rs | 4 + src/tests/spec/corpus.rs | 81 ++++++++++++++++++- 9 files changed, 397 insertions(+), 32 deletions(-) diff --git a/serde-tests/test.rs b/serde-tests/test.rs index 12b2356c..796c8513 100644 --- a/serde-tests/test.rs +++ b/serde-tests/test.rs @@ -24,8 +24,16 @@ use bson::{ Deserializer, Document, JavaScriptCodeWithScope, + RawArray, + RawBinary, + RawDbPointer, + RawDocument, + RawDocumentBuf, + RawJavaScriptCodeWithScope, + RawRegex, Regex, Timestamp, + Uuid, }; /// Verifies the following: @@ -112,6 +120,18 @@ where ); } +/// 
Verifies the following: +/// - Deserializing a `T` from the provided bytes does not error +/// - Serializing the `T` back to bytes produces the input. +fn run_raw_round_trip_test<'de, T>(bytes: &'de [u8], description: &str) +where + T: Deserialize<'de> + Serialize + std::fmt::Debug, +{ + let t: T = bson::from_slice(bytes).expect(description); + let vec = bson::to_vec(&t).expect(description); + assert_eq!(vec.as_slice(), bytes); +} + #[test] fn smoke() { #[derive(Serialize, Deserialize, PartialEq, Debug)] @@ -682,6 +702,154 @@ fn empty_array() { run_deserialize_test(&v, &doc, "empty_array"); } +#[test] +fn raw_doc_buf() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo { + d: RawDocumentBuf, + } + + let bytes = bson::to_vec(&doc! { + "d": { + "a": 12, + "b": 5.5, + "c": [1, true, "ok"], + "d": { "a": "b" }, + "e": ObjectId::new(), + } + }) + .expect("raw_doc_buf"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_doc_buf"); +} + +#[test] +fn raw_doc() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + d: &'a RawDocument, + } + + let bytes = bson::to_vec(&doc! { + "d": { + "a": 12, + "b": 5.5, + "c": [1, true, "ok"], + "d": { "a": "b" }, + "e": ObjectId::new(), + } + }) + .expect("raw doc"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_doc"); +} + +#[test] +fn raw_array() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + d: &'a RawArray, + } + + let bytes = bson::to_vec(&doc! { + "d": [1, true, { "ok": 1 }, [ "sub", "array" ], Uuid::new()] + }) + .expect("raw_array"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_array"); +} + +#[test] +fn raw_binary() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + generic: RawBinary<'a>, + + #[serde(borrow)] + old: RawBinary<'a>, + + #[serde(borrow)] + uuid: RawBinary<'a>, + + #[serde(borrow)] + other: RawBinary<'a>, + } + + let bytes = bson::to_vec(&doc! 
{ + "generic": Binary { + bytes: vec![1, 2, 3, 4, 5], + subtype: BinarySubtype::Generic, + }, + "old": Binary { + bytes: vec![1, 2, 3], + subtype: BinarySubtype::BinaryOld, + }, + "uuid": Uuid::new(), + "other": Binary { + bytes: vec![1u8; 100], + subtype: BinarySubtype::UserDefined(100), + } + }) + .expect("raw_binary"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_binary"); +} + +#[test] +fn raw_regex() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + r: RawRegex<'a>, + } + + let bytes = bson::to_vec(&doc! { + "r": Regex { + pattern: "a[b-c]d".to_string(), + options: "ab".to_string(), + }, + }) + .expect("raw_regex"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_regex"); +} + +#[test] +fn raw_code_w_scope() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + r: RawJavaScriptCodeWithScope<'a>, + } + + let bytes = bson::to_vec(&doc! { + "r": JavaScriptCodeWithScope { + code: "console.log(x)".to_string(), + scope: doc! 
{ "x": 1 }, + }, + }) + .expect("raw_code_w_scope"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_code_w_scope"); +} + +#[test] +fn raw_db_pointer() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + a: RawDbPointer<'a>, + } + + // From the "DBpointer" bson corpus test + let bytes = hex::decode("1A0000000C610002000000620056E1FC72E0C917E9C471416100").unwrap(); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_db_pointer"); +} + #[test] fn all_types() { #[derive(Debug, Deserialize, Serialize, PartialEq)] diff --git a/src/de/mod.rs b/src/de/mod.rs index 658a876d..a0e8fbfc 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -46,7 +46,10 @@ use ::serde::{ Deserialize, }; -pub(crate) use self::serde::{convert_unsigned_to_signed_raw, BsonVisitor}; +pub(crate) use self::{ + raw::Deserializer as RawDeserializer, + serde::{convert_unsigned_to_signed_raw, BsonVisitor}, +}; pub(crate) const MAX_BSON_SIZE: i32 = 16 * 1024 * 1024; pub(crate) const MIN_BSON_DOCUMENT_SIZE: i32 = 4 + 1; // 4 bytes for length, one byte for null terminator diff --git a/src/de/raw.rs b/src/de/raw.rs index 9a6afdd9..6d9074d8 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -643,6 +643,10 @@ impl<'d, 'de> serde::de::Deserializer<'de> for DocumentKeyDeserializer<'d, 'de> } } + fn is_human_readable(&self) -> bool { + false + } + forward_to_deserialize_any! { bool char str bytes byte_buf option unit unit_struct string identifier newtype_struct seq tuple tuple_struct struct map enum @@ -665,6 +669,10 @@ impl<'de> serde::de::Deserializer<'de> for FieldDeserializer { visitor.visit_borrowed_str(self.field_name) } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! 
{ bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -750,6 +758,10 @@ impl<'de> serde::de::Deserializer<'de> for RawDocumentDeserializer<'de> { visitor.visit_borrowed_bytes(self.raw_doc.as_bytes()) } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -818,6 +830,10 @@ impl<'de> serde::de::Deserializer<'de> for ObjectIdDeserializer { } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -876,6 +892,10 @@ impl<'de> serde::de::Deserializer<'de> for Decimal128Deserializer { visitor.visit_bytes(&self.0.bytes) } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -967,6 +987,10 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut TimestampDeserializer { } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -1067,6 +1091,10 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut DateTimeDeserializer { } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -1196,6 +1224,10 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer<'de> { } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! 
{ bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -1325,6 +1357,10 @@ impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut CodeWithScopeDeserial } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -1430,6 +1466,10 @@ impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut DbPointerDeserializer } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -1533,6 +1573,10 @@ impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut RegexDeserializer<'de } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -1630,6 +1674,10 @@ impl<'de, 'a> serde::de::Deserializer<'de> for RawBsonDeserializer<'de> { } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct diff --git a/src/de/serde.rs b/src/de/serde.rs index 567b350a..05ec7330 100644 --- a/src/de/serde.rs +++ b/src/de/serde.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, convert::{TryFrom, TryInto}, fmt, vec, @@ -17,7 +18,7 @@ use serde::de::{ VariantAccess, Visitor, }; -use serde_bytes::ByteBuf; +use serde_bytes::{ByteBuf, Bytes}; use crate::{ bson::{Binary, Bson, DbPointer, JavaScriptCodeWithScope, Regex, Timestamp}, @@ -265,15 +266,68 @@ impl<'de> Visitor<'de> for BsonVisitor { while let Some(k) = visitor.next_key::()? 
{ match k.as_str() { "$oid" => { - let hex: String = visitor.next_value()?; - return Ok(Bson::ObjectId(ObjectId::parse_str(hex.as_str()).map_err( - |_| { - V::Error::invalid_value( - Unexpected::Str(&hex), - &"24-character, big-endian hex string", - ) - }, - )?)); + enum BytesOrHex<'a> { + Bytes([u8; 12]), + Hex(Cow<'a, str>), + } + + impl<'a, 'de: 'a> Deserialize<'de> for BytesOrHex<'a> { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct BytesOrHexVisitor; + + impl<'de> Visitor<'de> for BytesOrHexVisitor { + type Value = BytesOrHex<'de>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "hexstring or byte array") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + Ok(BytesOrHex::Hex(Cow::Owned(v.to_string()))) + } + + fn visit_borrowed_str( + self, + v: &'de str, + ) -> Result + where + E: Error, + { + Ok(BytesOrHex::Hex(Cow::Borrowed(v))) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: Error, + { + Ok(BytesOrHex::Bytes(v.try_into().map_err(Error::custom)?)) + } + } + + deserializer.deserialize_any(BytesOrHexVisitor) + } + } + + let bytes_or_hex: BytesOrHex = visitor.next_value()?; + match bytes_or_hex { + BytesOrHex::Bytes(b) => return Ok(Bson::ObjectId(ObjectId::from_bytes(b))), + BytesOrHex::Hex(hex) => { + return Ok(Bson::ObjectId(ObjectId::parse_str(&hex).map_err( + |_| { + V::Error::invalid_value( + Unexpected::Str(&hex), + &"24-character, big-endian hex string", + ) + }, + )?)); + } + } } "$symbol" => { let string: String = visitor.next_value()?; @@ -419,16 +473,12 @@ impl<'de> Visitor<'de> for BsonVisitor { "$numberDecimalBytes" => { let bytes = visitor.next_value::()?; - let arr = bytes.into_vec().try_into().map_err(|v: Vec| { - Error::custom(format!( - "expected decimal128 as byte buffer, instead got buffer of length {}", - v.len() - )) - })?; - return Ok(Bson::Decimal128(Decimal128 { bytes: arr })); + return 
Ok(Bson::Decimal128(Decimal128::deserialize_from_slice( + &bytes, + )?)); } - _ => { + k => { let v = visitor.next_value::()?; doc.insert(k, v); } @@ -1046,7 +1096,10 @@ impl<'de> Deserialize<'de> for Decimal128 { { match Bson::deserialize(deserializer)? { Bson::Decimal128(d128) => Ok(d128), - _ => Err(D::Error::custom("expecting Decimal128")), + o => Err(D::Error::custom(format!( + "expecting Decimal128, got {:?}", + o + ))), } } } diff --git a/src/decimal128.rs b/src/decimal128.rs index c217bb5d..9ae0d39c 100644 --- a/src/decimal128.rs +++ b/src/decimal128.rs @@ -1,6 +1,6 @@ //! [BSON Decimal128](https://github.com/mongodb/specifications/blob/master/source/bson-decimal128/decimal128.rst) data type representation -use std::fmt; +use std::{array::TryFromSliceError, convert::TryInto, fmt}; /// Struct representing a BSON Decimal128 type. /// @@ -22,6 +22,13 @@ impl Decimal128 { pub fn bytes(&self) -> [u8; 128 / 8] { self.bytes } + + pub(crate) fn deserialize_from_slice( + bytes: &[u8], + ) -> std::result::Result { + let arr: [u8; 128 / 8] = bytes.try_into().map_err(E::custom)?; + return Ok(Decimal128 { bytes: arr }); + } } impl fmt::Debug for Decimal128 { diff --git a/src/lib.rs b/src/lib.rs index d9379dca..5c34fed5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -271,16 +271,14 @@ pub use self::{ bson::{Array, Binary, Bson, DbPointer, Document, JavaScriptCodeWithScope, Regex, Timestamp}, datetime::DateTime, de::{ - from_bson, - from_document, - from_reader, - from_reader_utf8_lossy, - from_slice, - from_slice_utf8_lossy, - Deserializer, + from_bson, from_document, from_reader, from_reader_utf8_lossy, from_slice, + from_slice_utf8_lossy, Deserializer, }, decimal128::Decimal128, - raw::{RawDocument, RawDocumentBuf, RawArray}, + raw::{ + RawArray, RawBinary, RawDbPointer, RawDocument, RawDocumentBuf, RawJavaScriptCodeWithScope, + RawRegex, + }, ser::{to_bson, to_document, to_vec, Serializer}, uuid::{Uuid, UuidRepresentation}, }; diff --git a/src/raw/bson.rs 
b/src/raw/bson.rs index 711f3b03..0247df0a 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -6,7 +6,7 @@ use serde::{ Deserialize, Serialize, }; -use serde_bytes::Bytes; +use serde_bytes::{ByteBuf, Bytes}; use super::{Error, RawArray, RawDocument, Result}; use crate::{ @@ -397,7 +397,12 @@ impl<'de: 'a, 'a> Deserialize<'de> for RawBson<'a> { let s: &str = map.next_value()?; Ok(RawBson::Symbol(s)) } - "$numberDecimalBytes" => Ok(RawBson::Decimal128(map.next_value()?)), + "$numberDecimalBytes" => { + let bytes = map.next_value::()?; + return Ok(RawBson::Decimal128(Decimal128::deserialize_from_slice( + &bytes, + )?)); + } "$regularExpression" => { #[derive(Debug, Deserialize)] struct BorrowedRegexBody<'a> { diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index 5c559721..b49dc18f 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -205,6 +205,10 @@ impl<'a> serde::Serializer for &'a mut Serializer { self.update_element_type(ElementType::EmbeddedDocument)?; self.bytes.write_all(v)?; } + SerializerHint::RawArray => { + self.update_element_type(ElementType::Array)?; + self.bytes.write_all(v)?; + } _ => { self.update_element_type(ElementType::Binary)?; diff --git a/src/tests/spec/corpus.rs b/src/tests/spec/corpus.rs index f1906f73..0fbba561 100644 --- a/src/tests/spec/corpus.rs +++ b/src/tests/spec/corpus.rs @@ -1,5 +1,6 @@ use std::{ convert::{TryFrom, TryInto}, + marker::PhantomData, str::FromStr, }; @@ -10,7 +11,7 @@ use crate::{ Document, }; use pretty_assertions::assert_eq; -use serde::Deserialize; +use serde::{de::DeserializeSeed, Deserialize, Deserializer}; use super::run_spec_test; @@ -64,6 +65,34 @@ struct ParseError { string: String, } +struct FieldVisitor<'a, T>(&'a str, PhantomData); + +impl<'de, 'a, T> serde::de::Visitor<'de> for FieldVisitor<'a, T> +where + T: Deserialize<'de>, +{ + type Value = T; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "expecting RawBson at field {}", self.0) + 
} + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + while let Some((k, v)) = map.next_entry::()? { + if k.as_str() == self.0 { + return Ok(v); + } + } + Err(serde::de::Error::custom(format!( + "missing field: {}", + self.0 + ))) + } +} + fn run_test(test: TestFile) { let _guard = LOCK.run_concurrently(); for valid in test.valid { @@ -132,6 +161,56 @@ fn run_test(test: TestFile) { crate::to_vec(&canonical_raw_document_from_slice).expect(&description); let tovec_rawbson = crate::to_vec(&canonical_raw_bson_from_slice).expect(&description); + // test Bson / RawBson field deserialization + if let Some(ref test_key) = test.test_key { + // skip regex tests that don't have the value at the test key + if !description.contains("$regex query operator") { + // deserialize the field from raw Bytes into a RawBson + let mut deserializer_raw = + crate::de::RawDeserializer::new(canonical_bson.as_slice(), false); + let raw_bson_field = deserializer_raw + .deserialize_any(FieldVisitor(test_key.as_str(), PhantomData::)) + .expect(&description); + // convert to an owned Bson and put into a Document + let bson: Bson = raw_bson_field.try_into().expect(&description); + let from_raw_doc = doc! { + test_key: bson + }; + + // deserialize the field from raw Bytes into a Bson + let mut deserializer_value = + crate::de::RawDeserializer::new(canonical_bson.as_slice(), false); + let bson_field = deserializer_value + .deserialize_any(FieldVisitor(test_key.as_str(), PhantomData::)) + .expect(&description); + // put into a Document + let from_value_doc = doc! { + test_key: bson_field, + }; + + // deserialize the field from a Bson into a Bson + let mut deserializer_value_value = + crate::Deserializer::new(Bson::Document(documentfromreader_cb.clone())); + let bson_field = deserializer_value_value + .deserialize_any(FieldVisitor(test_key.as_str(), PhantomData::)) + .expect(&description); + // put into a Document + let from_value_value_doc = doc! 
{ + test_key: bson_field, + }; + + // convert back into raw BSON for comparison with canonical BSON + let from_raw_vec = crate::to_vec(&from_raw_doc).expect(&description); + let from_value_vec = crate::to_vec(&from_value_doc).expect(&description); + let from_value_value_vec = + crate::to_vec(&from_value_value_doc).expect(&description); + + assert_eq!(from_raw_vec, canonical_bson, "{}", description); + assert_eq!(from_value_vec, canonical_bson, "{}", description); + assert_eq!(from_value_value_vec, canonical_bson, "{}", description); + } + } + // native_to_bson( bson_to_native(cB) ) = cB // now we ensure the hex for all 5 are equivalent to the canonical BSON provided by the From 9f036061e4e69aee62b4f2047d47124bbc81dc76 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Mon, 1 Nov 2021 14:00:00 -0400 Subject: [PATCH 13/21] simplify binary deserialization --- src/de/mod.rs | 17 +--------- src/de/raw.rs | 92 ++++++++++++++------------------------------------- 2 files changed, 25 insertions(+), 84 deletions(-) diff --git a/src/de/mod.rs b/src/de/mod.rs index a0e8fbfc..0e350df3 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -270,7 +270,7 @@ pub(crate) fn deserialize_bson_kvp( impl Binary { pub(crate) fn from_reader(mut reader: R) -> Result { - let len = read_i32(&mut reader)?; + let mut len = read_i32(&mut reader)?; if !(0..=MAX_BSON_SIZE).contains(&len) { return Err(Error::invalid_length( len as usize, @@ -278,21 +278,6 @@ impl Binary { )); } let subtype = BinarySubtype::from(read_u8(&mut reader)?); - Self::from_reader_with_len_and_payload(reader, len, subtype) - } - - // TODO: RUST-976: call through to the RawBinary version of this instead of duplicating code - pub(crate) fn from_reader_with_len_and_payload( - mut reader: R, - mut len: i32, - subtype: BinarySubtype, - ) -> Result { - if !(0..=MAX_BSON_SIZE).contains(&len) { - return Err(Error::invalid_length( - len as usize, - &format!("binary length must be between 0 and {}", MAX_BSON_SIZE).as_str(), - )); - } // 
Skip length data in old binary. if let BinarySubtype::BinaryOld = subtype { diff --git a/src/de/raw.rs b/src/de/raw.rs index 6d9074d8..d92c28a3 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -269,24 +269,13 @@ impl<'de> Deserializer<'de> { BinarySubtype::Generic => { visitor.visit_borrowed_bytes(self.bytes.read_slice(len as usize)?) } - _ if matches!(hint, DeserializerHint::RawBson) => { + _ => { let binary = RawBinary::from_slice_with_len_and_payload( self.bytes.read_slice(len as usize)?, len, subtype, )?; - let mut d = BinaryDeserializer::borrowed(binary); - visitor.visit_map(BinaryAccess { - deserializer: &mut d, - }) - } - _ => { - let binary = Binary::from_reader_with_len_and_payload( - &mut self.bytes, - len, - subtype, - )?; - let mut d = BinaryDeserializer::new(binary); + let mut d = BinaryDeserializer::new(binary, hint); visitor.visit_map(BinaryAccess { deserializer: &mut d, }) @@ -1120,32 +1109,17 @@ impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d, 'de> { where K: serde::de::DeserializeSeed<'de>, { - match self.deserializer.stage { - BinaryDeserializationStage::TopLevel => seed - .deserialize(FieldDeserializer { - field_name: "$binary", - }) - .map(Some), - BinaryDeserializationStage::Subtype => seed - .deserialize(FieldDeserializer { - field_name: "subType", - }) - .map(Some), - BinaryDeserializationStage::Bytes - if matches!(self.deserializer.binary, BinaryContent::Borrowed(_)) => - { - seed.deserialize(FieldDeserializer { - field_name: "bytes", - }) - .map(Some) - } - BinaryDeserializationStage::Bytes => seed - .deserialize(FieldDeserializer { - field_name: "base64", - }) - .map(Some), - BinaryDeserializationStage::Done => Ok(None), - } + let field_name = match self.deserializer.stage { + BinaryDeserializationStage::TopLevel => "$binary", + BinaryDeserializationStage::Subtype => "subType", + BinaryDeserializationStage::Bytes => match self.deserializer.hint { + DeserializerHint::RawBson => "bytes", + _ => "base64", + }, + 
BinaryDeserializationStage::Done => return Ok(None), + }; + + seed.deserialize(FieldDeserializer { field_name }).map(Some) } fn next_value_seed(&mut self, seed: V) -> Result @@ -1156,31 +1130,17 @@ impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d, 'de> { } } -/// Storage of possibly borrowed, possibly owned binary data. -#[derive(Debug)] -enum BinaryContent<'a> { - Borrowed(RawBinary<'a>), - Owned(Binary), -} - struct BinaryDeserializer<'a> { - binary: BinaryContent<'a>, + binary: RawBinary<'a>, + hint: DeserializerHint, stage: BinaryDeserializationStage, } -impl BinaryDeserializer<'static> { - fn new(binary: Binary) -> Self { - Self { - binary: BinaryContent::Owned(binary), - stage: BinaryDeserializationStage::TopLevel, - } - } -} - impl<'a> BinaryDeserializer<'a> { - fn borrowed(binary: RawBinary<'a>) -> Self { + fn new(binary: RawBinary<'a>, hint: DeserializerHint) -> Self { Self { - binary: BinaryContent::Borrowed(binary), + binary, + hint, stage: BinaryDeserializationStage::TopLevel, } } @@ -1202,20 +1162,16 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer<'de> { } BinaryDeserializationStage::Subtype => { self.stage = BinaryDeserializationStage::Bytes; - match self.binary { - BinaryContent::Owned(ref b) => { - visitor.visit_string(hex::encode([u8::from(b.subtype)])) - } - BinaryContent::Borrowed(b) => visitor.visit_u8(b.subtype().into()), + match self.hint { + DeserializerHint::RawBson => visitor.visit_u8(self.binary.subtype().into()), + _ => visitor.visit_string(hex::encode([u8::from(self.binary.subtype)])), } } BinaryDeserializationStage::Bytes => { self.stage = BinaryDeserializationStage::Done; - match self.binary { - BinaryContent::Owned(ref b) => { - visitor.visit_string(base64::encode(b.bytes.as_slice())) - } - BinaryContent::Borrowed(b) => visitor.visit_borrowed_bytes(b.as_bytes()), + match self.hint { + DeserializerHint::RawBson => visitor.visit_borrowed_bytes(self.binary.bytes), + _ => 
visitor.visit_string(base64::encode(self.binary.bytes)), } } BinaryDeserializationStage::Done => { From 162000b6e53c06ea59919f7017da969c7068b58e Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Mon, 1 Nov 2021 14:35:48 -0400 Subject: [PATCH 14/21] revert breaking changes to binary serialization output --- src/ser/raw/mod.rs | 1 - src/ser/serde.rs | 27 +++++++++++---------------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index b49dc18f..da0e1fcd 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -334,7 +334,6 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_struct(self, name: &'static str, _len: usize) -> Result { - // println!("struct {}", name); let value_type = match name { "$oid" => Some(ValueType::ObjectId), "$date" => Some(ValueType::DateTime), diff --git a/src/ser/serde.rs b/src/ser/serde.rs index b68add82..c371825b 100644 --- a/src/ser/serde.rs +++ b/src/ser/serde.rs @@ -574,22 +574,17 @@ impl Serialize for Binary { where S: ser::Serializer, { - // if let BinarySubtype::Generic = self.subtype { - // serializer.serialize_bytes(self.bytes.as_slice()) - // } else { - // let mut state = serializer.serialize_struct("$binary", 1)?; - // let body = extjson::models::BinaryBody { - // base64: base64::encode(self.bytes.as_slice()), - // subtype: hex::encode([self.subtype.into()]), - // }; - // state.serialize_field("$binary", &body)?; - // state.end() - // } - let raw_binary = RawBinary { - bytes: self.bytes.as_slice(), - subtype: self.subtype, - }; - raw_binary.serialize(serializer) + if let BinarySubtype::Generic = self.subtype { + serializer.serialize_bytes(self.bytes.as_slice()) + } else { + let mut state = serializer.serialize_struct("$binary", 1)?; + let body = extjson::models::BinaryBody { + base64: base64::encode(self.bytes.as_slice()), + subtype: hex::encode([self.subtype.into()]), + }; + state.serialize_field("$binary", &body)?; + state.end() + } } } From 
bec78c779a4899cab0899cce04a7f9cc3072759e Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Mon, 1 Nov 2021 14:53:03 -0400 Subject: [PATCH 15/21] fix clippy, simplify binary serialization --- src/de/mod.rs | 8 +++--- src/de/raw.rs | 8 ++---- src/de/serde.rs | 2 +- src/decimal128.rs | 4 +-- src/raw/bson.rs | 9 ++----- src/raw/document_buf.rs | 9 +++---- src/raw/mod.rs | 3 --- src/ser/raw/mod.rs | 2 +- src/ser/raw/value_serializer.rs | 47 +++++++++++++++++---------------- src/ser/serde.rs | 2 +- src/tests/spec/corpus.rs | 4 +-- 11 files changed, 42 insertions(+), 56 deletions(-) diff --git a/src/de/mod.rs b/src/de/mod.rs index 0e350df3..c577c891 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -46,10 +46,10 @@ use ::serde::{ Deserialize, }; -pub(crate) use self::{ - raw::Deserializer as RawDeserializer, - serde::{convert_unsigned_to_signed_raw, BsonVisitor}, -}; +pub(crate) use self::serde::{convert_unsigned_to_signed_raw, BsonVisitor}; + +#[cfg(test)] +pub(crate) use self::raw::Deserializer as RawDeserializer; pub(crate) const MAX_BSON_SIZE: i32 = 16 * 1024 * 1024; pub(crate) const MIN_BSON_DOCUMENT_SIZE: i32 = 4 + 1; // 4 bytes for length, one byte for null terminator diff --git a/src/de/raw.rs b/src/de/raw.rs index d92c28a3..45175945 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -13,17 +13,13 @@ use serde::{ use crate::{ oid::ObjectId, - raw::{RawBinary, RawBson, RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, + raw::{RawBinary, RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, spec::{BinarySubtype, ElementType}, uuid::UUID_NEWTYPE_NAME, - Binary, Bson, DateTime, - DbPointer, Decimal128, - JavaScriptCodeWithScope, RawDocument, - Regex, Timestamp, }; @@ -1507,7 +1503,7 @@ impl<'de, 'a> RegexDeserializer<'de, 'a> { impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut RegexDeserializer<'de, 'a> { type Error = Error; - fn deserialize_any(mut self, visitor: V) -> Result + fn deserialize_any(self, visitor: V) -> Result where V: 
serde::de::Visitor<'de>, { diff --git a/src/de/serde.rs b/src/de/serde.rs index 05ec7330..1c7c4e7d 100644 --- a/src/de/serde.rs +++ b/src/de/serde.rs @@ -18,7 +18,7 @@ use serde::de::{ VariantAccess, Visitor, }; -use serde_bytes::{ByteBuf, Bytes}; +use serde_bytes::ByteBuf; use crate::{ bson::{Binary, Bson, DbPointer, JavaScriptCodeWithScope, Regex, Timestamp}, diff --git a/src/decimal128.rs b/src/decimal128.rs index 9ae0d39c..533b10dd 100644 --- a/src/decimal128.rs +++ b/src/decimal128.rs @@ -1,6 +1,6 @@ //! [BSON Decimal128](https://github.com/mongodb/specifications/blob/master/source/bson-decimal128/decimal128.rst) data type representation -use std::{array::TryFromSliceError, convert::TryInto, fmt}; +use std::{convert::TryInto, fmt}; /// Struct representing a BSON Decimal128 type. /// @@ -27,7 +27,7 @@ impl Decimal128 { bytes: &[u8], ) -> std::result::Result { let arr: [u8; 128 / 8] = bytes.try_into().map_err(E::custom)?; - return Ok(Decimal128 { bytes: arr }); + Ok(Decimal128 { bytes: arr }) } } diff --git a/src/raw/bson.rs b/src/raw/bson.rs index 0247df0a..e42e22c5 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -1,18 +1,13 @@ use std::convert::{TryFrom, TryInto}; -use serde::{ - de::{MapAccess, Unexpected, Visitor}, - ser::SerializeStruct, - Deserialize, - Serialize, -}; +use serde::{de::Visitor, ser::SerializeStruct, Deserialize, Serialize}; use serde_bytes::{ByteBuf, Bytes}; use super::{Error, RawArray, RawDocument, Result}; use crate::{ extjson, oid::{self, ObjectId}, - raw::{RAW_ARRAY_NEWTYPE, RAW_BINARY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, + raw::{RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, spec::{BinarySubtype, ElementType}, Bson, DateTime, diff --git a/src/raw/document_buf.rs b/src/raw/document_buf.rs index a4bd12b6..f1c216d3 100644 --- a/src/raw/document_buf.rs +++ b/src/raw/document_buf.rs @@ -4,12 +4,9 @@ use std::{ ops::Deref, }; -use serde::{de::Visitor, Deserialize, Deserializer, Serialize}; +use 
serde::{Deserialize, Serialize}; -use crate::{ - raw::{RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, - Document, -}; +use crate::Document; use super::{Error, ErrorKind, Iter, RawBson, RawDocument, Result}; @@ -160,7 +157,7 @@ impl Serialize for RawDocumentBuf { where S: serde::Serializer, { - let doc: &RawDocument = &self; + let doc: &RawDocument = self.deref(); doc.serialize(serializer) } } diff --git a/src/raw/mod.rs b/src/raw/mod.rs index 4faa046a..48359110 100644 --- a/src/raw/mod.rs +++ b/src/raw/mod.rs @@ -143,9 +143,6 @@ pub(crate) const RAW_ARRAY_NEWTYPE: &str = "$__private__bson_RawArray"; /// Special newtype name indicating that the type being (de)serialized is a raw BSON value. pub(crate) const RAW_BSON_NEWTYPE: &str = "$__private__bson_RawBson"; -/// Special newtype name indicating that the type being (de)serialized is a raw BSON value. -pub(crate) const RAW_BINARY_NEWTYPE: &str = "$__private__bson_RawBinary"; - /// Given a u8 slice, return an i32 calculated from the first four bytes in /// little endian order. 
fn f64_from_slice(val: &[u8]) -> Result { diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index da0e1fcd..169848a7 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -12,7 +12,7 @@ use self::value_serializer::{ValueSerializer, ValueType}; use super::{write_binary, write_cstring, write_f64, write_i32, write_i64, write_string}; use crate::{ - raw::{RAW_ARRAY_NEWTYPE, RAW_BINARY_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, + raw::{RAW_ARRAY_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, ser::{Error, Result}, spec::{BinarySubtype, ElementType}, uuid::UUID_NEWTYPE_NAME, diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index 99c7f6e7..91ea8a09 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -7,6 +7,7 @@ use serde::{ use crate::{ oid::ObjectId, + raw::RAW_DOCUMENT_NEWTYPE, ser::{write_binary, write_cstring, write_i32, write_i64, write_string, Error, Result}, spec::{BinarySubtype, ElementType}, }; @@ -30,12 +31,15 @@ enum SerializationStep { DateTimeNumberLong, Binary, - BinaryBase64, - BinarySubType { base64: String }, - - RawBinary, - RawBinaryBytes, - RawBinarySubType { bytes: Vec }, + /// This step can either transition to the raw or base64 steps depending + /// on whether a string or bytes are serialized. 
+ BinaryBytes, + BinarySubType { + base64: String, + }, + RawBinarySubType { + bytes: Vec, + }, Symbol, @@ -45,7 +49,9 @@ enum SerializationStep { Timestamp, TimestampTime, - TimestampIncrement { time: i64 }, + TimestampIncrement { + time: i64, + }, DbPointer, DbPointerRef, @@ -54,7 +60,10 @@ enum SerializationStep { Code, CodeWithScopeCode, - CodeWithScopeScope { code: String, raw: bool }, + CodeWithScopeScope { + code: String, + raw: bool, + }, MinKey, @@ -110,7 +119,7 @@ impl<'a> ValueSerializer<'a> { pub(super) fn new(rs: &'a mut Serializer, value_type: ValueType) -> Self { let state = match value_type { ValueType::DateTime => SerializationStep::DateTime, - ValueType::Binary => SerializationStep::RawBinary, + ValueType::Binary => SerializationStep::Binary, ValueType::ObjectId => SerializationStep::Oid, ValueType::Symbol => SerializationStep::Symbol, ValueType::RegularExpression => SerializationStep::RegEx, @@ -240,7 +249,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { let oid = ObjectId::parse_str(v).map_err(Error::custom)?; self.root_serializer.bytes.write_all(&oid.bytes())?; } - SerializationStep::BinaryBase64 => { + SerializationStep::BinaryBytes => { self.state = SerializationStep::BinarySubType { base64: v.to_string(), }; @@ -292,7 +301,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { self.root_serializer.bytes.write_all(v)?; Ok(()) } - SerializationStep::RawBinaryBytes => { + SerializationStep::BinaryBytes => { self.state = SerializationStep::RawBinarySubType { bytes: v.to_vec() }; Ok(()) } @@ -349,7 +358,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { match (&mut self.state, name) { ( SerializationStep::CodeWithScopeScope { - ref code, + code: _, ref mut raw, }, RAW_DOCUMENT_NEWTYPE, @@ -457,11 +466,11 @@ impl<'a, 'b> SerializeStruct for &'b mut ValueSerializer<'a> { value.serialize(&mut **self)?; self.state = SerializationStep::Done; } - (SerializationStep::RawBinary, "$binary") => { - 
self.state = SerializationStep::RawBinaryBytes; + (SerializationStep::Binary, "$binary") => { + self.state = SerializationStep::BinaryBytes; value.serialize(&mut **self)?; } - (SerializationStep::RawBinaryBytes, "bytes") => { + (SerializationStep::BinaryBytes, "base64" | "bytes") => { // state is updated in serialize value.serialize(&mut **self)?; } @@ -469,14 +478,6 @@ impl<'a, 'b> SerializeStruct for &'b mut ValueSerializer<'a> { value.serialize(&mut **self)?; self.state = SerializationStep::Done; } - (SerializationStep::Binary, "$binary") => { - self.state = SerializationStep::BinaryBase64; - value.serialize(&mut **self)?; - } - (SerializationStep::BinaryBase64, "base64") => { - // state is updated in serialize - value.serialize(&mut **self)?; - } (SerializationStep::BinarySubType { .. }, "subType") => { value.serialize(&mut **self)?; self.state = SerializationStep::Done; diff --git a/src/ser/serde.rs b/src/ser/serde.rs index c371825b..046978d8 100644 --- a/src/ser/serde.rs +++ b/src/ser/serde.rs @@ -17,7 +17,7 @@ use crate::{ datetime::DateTime, extjson, oid::ObjectId, - raw::{RawBinary, RawDbPointer, RawRegex}, + raw::{RawDbPointer, RawRegex}, spec::BinarySubtype, uuid::UUID_NEWTYPE_NAME, Binary, diff --git a/src/tests/spec/corpus.rs b/src/tests/spec/corpus.rs index 0fbba561..b4329fb1 100644 --- a/src/tests/spec/corpus.rs +++ b/src/tests/spec/corpus.rs @@ -11,7 +11,7 @@ use crate::{ Document, }; use pretty_assertions::assert_eq; -use serde::{de::DeserializeSeed, Deserialize, Deserializer}; +use serde::{Deserialize, Deserializer}; use super::run_spec_test; @@ -189,7 +189,7 @@ fn run_test(test: TestFile) { }; // deserialize the field from a Bson into a Bson - let mut deserializer_value_value = + let deserializer_value_value = crate::Deserializer::new(Bson::Document(documentfromreader_cb.clone())); let bson_field = deserializer_value_value .deserialize_any(FieldVisitor(test_key.as_str(), PhantomData::)) From 8a53f7bf233d8c8aa497da65b38c7a80e16706d7 Mon Sep 17 
00:00:00 2001 From: Patrick Freed Date: Mon, 1 Nov 2021 19:50:38 -0400 Subject: [PATCH 16/21] ensure RawDocument and co can round trip through non-BSON --- serde-tests/Cargo.toml | 5 + serde-tests/test.rs | 358 ++++++++++++++++++++++----------- src/de/raw.rs | 20 ++ src/de/serde.rs | 7 +- src/lib.rs | 2 +- src/raw/array.rs | 15 +- src/raw/bson.rs | 440 ++++++++++++++++++++--------------------- src/raw/document.rs | 17 +- src/raw/mod.rs | 2 + 9 files changed, 519 insertions(+), 347 deletions(-) diff --git a/serde-tests/Cargo.toml b/serde-tests/Cargo.toml index d5bbe9e7..93da0365 100644 --- a/serde-tests/Cargo.toml +++ b/serde-tests/Cargo.toml @@ -13,6 +13,11 @@ serde = { version = "1.0", features = ["derive"] } pretty_assertions = "0.6.1" hex = "0.4.2" +[dev-dependencies] +serde_json = "1" +rmp-serde = "0.15" +base64 = "0.13.0" + [lib] name = "serde_tests" path = "lib.rs" diff --git a/serde-tests/test.rs b/serde-tests/test.rs index 796c8513..c3e0d0a4 100644 --- a/serde-tests/test.rs +++ b/serde-tests/test.rs @@ -21,11 +21,13 @@ use bson::{ Binary, Bson, DateTime, + Decimal128, Deserializer, Document, JavaScriptCodeWithScope, RawArray, RawBinary, + RawBson, RawDbPointer, RawDocument, RawDocumentBuf, @@ -850,136 +852,264 @@ fn raw_db_pointer() { run_raw_round_trip_test::(bytes.as_slice(), "raw_db_pointer"); } +#[derive(Debug, Deserialize, Serialize, PartialEq)] +struct SubDoc { + a: i32, + b: i32, +} + +#[derive(Debug, Deserialize, Serialize, PartialEq)] +struct AllTypes { + x: i32, + y: i64, + s: String, + array: Vec, + bson: Bson, + oid: ObjectId, + null: Option<()>, + subdoc: Document, + b: bool, + d: f64, + binary: Binary, + binary_old: Binary, + binary_other: Binary, + date: DateTime, + regex: Regex, + ts: Timestamp, + i: SubDoc, + undefined: Bson, + code: Bson, + code_w_scope: JavaScriptCodeWithScope, + decimal: Decimal128, + symbol: Bson, + min_key: Bson, + max_key: Bson, +} + +impl AllTypes { + fn fixtures() -> (Self, Document) { + let binary = Binary { + 
bytes: vec![36, 36, 36], + subtype: BinarySubtype::Generic, + }; + let binary_old = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::BinaryOld, + }; + let binary_other = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::UserDefined(0x81), + }; + let date = DateTime::now(); + let regex = Regex { + pattern: "hello".to_string(), + options: "x".to_string(), + }; + let timestamp = Timestamp { + time: 123, + increment: 456, + }; + let code = Bson::JavaScriptCode("console.log(1)".to_string()); + let code_w_scope = JavaScriptCodeWithScope { + code: "console.log(a)".to_string(), + scope: doc! { "a": 1 }, + }; + let oid = ObjectId::new(); + let subdoc = doc! { "k": true, "b": { "hello": "world" } }; + + let decimal = { + let bytes = hex::decode("18000000136400D0070000000000000000000000003A3000").unwrap(); + let d = Document::from_reader(bytes.as_slice()).unwrap(); + match d.get("d") { + Some(Bson::Decimal128(d)) => *d, + c => panic!("expected decimal128, got {:?}", c), + } + }; + + let doc = doc! { + "x": 1, + "y": 2_i64, + "s": "oke", + "array": [ true, "oke", { "12": 24 } ], + "bson": 1234.5, + "oid": oid, + "null": Bson::Null, + "subdoc": subdoc.clone(), + "b": true, + "d": 12.5, + "binary": binary.clone(), + "binary_old": binary_old.clone(), + "binary_other": binary_other.clone(), + "date": date, + "regex": regex.clone(), + "ts": timestamp, + "i": { "a": 300, "b": 12345 }, + "undefined": Bson::Undefined, + "code": code.clone(), + "code_w_scope": code_w_scope.clone(), + "decimal": Bson::Decimal128(decimal), + "symbol": Bson::Symbol("ok".to_string()), + "min_key": Bson::MinKey, + "max_key": Bson::MaxKey, + }; + + let v = AllTypes { + x: 1, + y: 2, + s: "oke".to_string(), + array: vec![ + Bson::Boolean(true), + Bson::String("oke".to_string()), + Bson::Document(doc! 
{ "12": 24 }), + ], + bson: Bson::Double(1234.5), + oid, + null: None, + subdoc, + b: true, + d: 12.5, + binary, + binary_old, + binary_other, + date, + regex, + ts: timestamp, + i: SubDoc { a: 300, b: 12345 }, + undefined: Bson::Undefined, + code, + code_w_scope, + decimal, + symbol: Bson::Symbol("ok".to_string()), + min_key: Bson::MinKey, + max_key: Bson::MaxKey, + }; + + (v, doc) + } +} + #[test] fn all_types() { - #[derive(Debug, Deserialize, Serialize, PartialEq)] - struct Bar { - a: i32, - b: i32, - } + let (v, doc) = AllTypes::fixtures(); - #[derive(Debug, Deserialize, Serialize, PartialEq)] - struct Foo { - x: i32, - y: i64, - s: String, - array: Vec, - bson: Bson, - oid: ObjectId, - null: Option<()>, - subdoc: Document, - b: bool, - d: f64, - binary: Binary, - binary_old: Binary, - binary_other: Binary, - date: DateTime, - regex: Regex, - ts: Timestamp, - i: Bar, - undefined: Bson, - code: Bson, - code_w_scope: JavaScriptCodeWithScope, - decimal: Bson, - symbol: Bson, - min_key: Bson, - max_key: Bson, - } + run_test(&v, &doc, "all types"); +} - let binary = Binary { - bytes: vec![36, 36, 36], - subtype: BinarySubtype::Generic, - }; - let binary_old = Binary { - bytes: vec![36, 36, 36], - subtype: BinarySubtype::BinaryOld, - }; - let binary_other = Binary { - bytes: vec![36, 36, 36], - subtype: BinarySubtype::UserDefined(0x81), - }; - let date = DateTime::now(); - let regex = Regex { - pattern: "hello".to_string(), - options: "x".to_string(), - }; - let timestamp = Timestamp { - time: 123, - increment: 456, - }; - let code = Bson::JavaScriptCode("console.log(1)".to_string()); - let code_w_scope = JavaScriptCodeWithScope { - code: "console.log(a)".to_string(), - scope: doc! { "a": 1 }, +#[test] +fn all_types_json() { + let (mut v, _) = AllTypes::fixtures(); + + let code = match v.code { + Bson::JavaScriptCode(ref c) => c.clone(), + c => panic!("expected code, found {:?}", c), }; - let oid = ObjectId::new(); - let subdoc = doc! 
{ "k": true, "b": { "hello": "world" } }; - let decimal = { - let bytes = hex::decode("18000000136400D0070000000000000000000000003A3000").unwrap(); - let d = Document::from_reader(bytes.as_slice()).unwrap(); - d.get("d").unwrap().clone() + let code_w_scope = JavaScriptCodeWithScope { + code: "hello world".to_string(), + scope: doc! { "x": 1 }, }; + let scope_json = serde_json::json!({ "x": 1 }); + v.code_w_scope = code_w_scope.clone(); - let doc = doc! { + let json = serde_json::json!({ "x": 1, - "y": 2_i64, + "y": 2, "s": "oke", - "array": [ true, "oke", { "12": 24 } ], + "array": vec![ + serde_json::json!(true), + serde_json::json!("oke".to_string()), + serde_json::json!({ "12": 24 }), + ], "bson": 1234.5, - "oid": oid, - "null": Bson::Null, - "subdoc": subdoc.clone(), + "oid": { "$oid": v.oid.to_hex() }, + "null": serde_json::Value::Null, + "subdoc": { "k": true, "b": { "hello": "world" } }, "b": true, "d": 12.5, - "binary": binary.clone(), - "binary_old": binary_old.clone(), - "binary_other": binary_other.clone(), - "date": date, - "regex": regex.clone(), - "ts": timestamp, - "i": { "a": 300, "b": 12345 }, - "undefined": Bson::Undefined, - "code": code.clone(), - "code_w_scope": code_w_scope.clone(), - "decimal": decimal.clone(), - "symbol": Bson::Symbol("ok".to_string()), - "min_key": Bson::MinKey, - "max_key": Bson::MaxKey, - }; + "binary": v.binary.bytes, + "binary_old": { "$binary": { "base64": base64::encode(&v.binary_old.bytes), "subType": "02" } }, + "binary_other": { "$binary": { "base64": base64::encode(&v.binary_old.bytes), "subType": "81" } }, + "date": { "$date": { "$numberLong": v.date.timestamp_millis().to_string() } }, + "regex": { "$regularExpression": { "pattern": v.regex.pattern, "options": v.regex.options } }, + "ts": { "$timestamp": { "t": 123, "i": 456 } }, + "i": { "a": v.i.a, "b": v.i.b }, + "undefined": { "$undefined": true }, + "code": { "$code": code }, + "code_w_scope": { "$code": code_w_scope.code, "$scope": scope_json }, + 
"decimal": { "$numberDecimalBytes": v.decimal.bytes() }, + "symbol": { "$symbol": "ok" }, + "min_key": { "$minKey": 1 }, + "max_key": { "$maxKey": 1 }, + }); + + assert_eq!(serde_json::to_value(&v).unwrap(), json); +} - let v = Foo { - x: 1, - y: 2, - s: "oke".to_string(), - array: vec![ - Bson::Boolean(true), - Bson::String("oke".to_string()), - Bson::Document(doc! { "12": 24 }), - ], - bson: Bson::Double(1234.5), - oid, - null: None, - subdoc, - b: true, - d: 12.5, - binary, - binary_old, - binary_other, - date, - regex, - ts: timestamp, - i: Bar { a: 300, b: 12345 }, - undefined: Bson::Undefined, - code, - code_w_scope, - decimal, - symbol: Bson::Symbol("ok".to_string()), - min_key: Bson::MinKey, - max_key: Bson::MaxKey, +#[test] +fn all_types_rmp() { + let (v, _) = AllTypes::fixtures(); + let serialized = rmp_serde::to_vec_named(&v).unwrap(); + let back: AllTypes = rmp_serde::from_slice(&serialized).unwrap(); + + assert_eq!(back, v); +} + +#[test] +fn all_raw_types_rmp() { + #[derive(Debug, Serialize, Deserialize, PartialEq)] + struct AllRawTypes<'a> { + #[serde(borrow)] + bson: RawBson<'a>, + #[serde(borrow)] + document: &'a RawDocument, + #[serde(borrow)] + array: &'a RawArray, + buf: RawDocumentBuf, + #[serde(borrow)] + binary: RawBinary<'a>, + #[serde(borrow)] + code_w_scope: RawJavaScriptCodeWithScope<'a>, + #[serde(borrow)] + regex: RawRegex<'a>, + } + + let doc_bytes = bson::to_vec(&doc! { + "bson": "some string", + "array": [1, 2, 3], + "binary": Binary { bytes: vec![1, 2, 3], subtype: BinarySubtype::Generic }, + "binary_old": Binary { bytes: vec![1, 2, 3], subtype: BinarySubtype::BinaryOld }, + "code_w_scope": JavaScriptCodeWithScope { + code: "ok".to_string(), + scope: doc! 
{ "x": 1 }, + }, + "regex": Regex { + pattern: "pattern".to_string(), + options: "opt".to_string() + } + }) + .unwrap(); + let doc_buf = RawDocumentBuf::new(doc_bytes).unwrap(); + let document = &doc_buf; + let array = document.get_array("array").unwrap(); + + let v = AllRawTypes { + bson: document.get("bson").unwrap().unwrap(), + array, + document, + buf: doc_buf.clone(), + binary: document.get_binary("binary").unwrap(), + code_w_scope: document + .get("code_w_scope") + .unwrap() + .unwrap() + .as_javascript_with_scope() + .unwrap(), + regex: document.get_regex("regex").unwrap(), }; + let serialized = rmp_serde::to_vec_named(&v).unwrap(); + let back: AllRawTypes = rmp_serde::from_slice(&serialized).unwrap(); - run_test(&v, &doc, "all types"); + assert_eq!(back, v); } #[test] diff --git a/src/de/raw.rs b/src/de/raw.rs index 45175945..478e00da 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -447,6 +447,26 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { DeserializerHint::BinarySubtype(BinarySubtype::Uuid), ), RAW_BSON_NEWTYPE => self.deserialize_next(visitor, DeserializerHint::RawBson), + RAW_DOCUMENT_NEWTYPE => { + if self.current_type != ElementType::EmbeddedDocument { + return Err(serde::de::Error::custom(format!( + "expected raw document, instead got {:?}", + self.current_type + ))); + } + + self.deserialize_next(visitor, DeserializerHint::RawBson) + } + RAW_ARRAY_NEWTYPE => { + if self.current_type != ElementType::Array { + return Err(serde::de::Error::custom(format!( + "expected raw array, instead got {:?}", + self.current_type + ))); + } + + self.deserialize_next(visitor, DeserializerHint::RawBson) + } _ => visitor.visit_newtype_struct(self), } } diff --git a/src/de/serde.rs b/src/de/serde.rs index 1c7c4e7d..23fc5e8d 100644 --- a/src/de/serde.rs +++ b/src/de/serde.rs @@ -1082,9 +1082,12 @@ impl<'de> Deserialize<'de> for Binary { where D: de::Deserializer<'de>, { - match Bson::deserialize(deserializer)? 
{ + match deserializer.deserialize_byte_buf(BsonVisitor)? { Bson::Binary(binary) => Ok(binary), - _ => Err(D::Error::custom("expecting Binary")), + d => Err(D::Error::custom(format!( + "expecting Binary but got {:?} instead", + d + ))), } } } diff --git a/src/lib.rs b/src/lib.rs index 5c34fed5..be043012 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -276,7 +276,7 @@ pub use self::{ }, decimal128::Decimal128, raw::{ - RawArray, RawBinary, RawDbPointer, RawDocument, RawDocumentBuf, RawJavaScriptCodeWithScope, + RawArray, RawBinary, RawBson, RawDbPointer, RawDocument, RawDocumentBuf, RawJavaScriptCodeWithScope, RawRegex, }, ser::{to_bson, to_document, to_vec, Serializer}, diff --git a/src/raw/array.rs b/src/raw/array.rs index 44fa903d..665c5632 100644 --- a/src/raw/array.rs +++ b/src/raw/array.rs @@ -12,7 +12,14 @@ use super::{ RawRegex, Result, }; -use crate::{oid::ObjectId, raw::RAW_ARRAY_NEWTYPE, spec::ElementType, Bson, DateTime, Timestamp}; +use crate::{ + oid::ObjectId, + raw::{RawBsonVisitor, RAW_ARRAY_NEWTYPE}, + spec::{BinarySubtype, ElementType}, + Bson, + DateTime, + Timestamp, +}; /// A slice of a BSON document containing a BSON array value (akin to [`std::str`]). This can be /// retrieved from a [`RawDocument`] via [`RawDocument::get`]. @@ -248,8 +255,12 @@ impl<'de: 'a, 'a> Deserialize<'de> for &'a RawArray { where D: serde::Deserializer<'de>, { - match RawBson::deserialize(deserializer)? { + match deserializer.deserialize_newtype_struct(RAW_ARRAY_NEWTYPE, RawBsonVisitor)? 
{ RawBson::Array(d) => Ok(d), + RawBson::Binary(b) if b.subtype == BinarySubtype::Generic => { + let doc = RawDocument::new(b.bytes).map_err(serde::de::Error::custom)?; + Ok(RawArray::from_doc(doc)) + } b => Err(serde::de::Error::custom(format!( "expected raw array reference, instead got {:?}", b diff --git a/src/raw/bson.rs b/src/raw/bson.rs index e42e22c5..b35909ff 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -245,260 +245,252 @@ impl<'a> RawBson<'a> { } } -impl<'de: 'a, 'a> Deserialize<'de> for RawBson<'a> { - fn deserialize(deserializer: D) -> std::result::Result +/// A visitor used to deserialize types backed by raw BSON. +pub(crate) struct RawBsonVisitor; + +impl<'de> Visitor<'de> for RawBsonVisitor { + type Value = RawBson<'de>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a raw BSON reference") + } + + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result where - D: serde::Deserializer<'de>, + E: serde::de::Error, { - use serde::de::Error as SerdeError; + Ok(RawBson::String(v)) + } - struct RawBsonVisitor; + fn visit_borrowed_bytes(self, bytes: &'de [u8]) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Binary(RawBinary { + bytes, + subtype: BinarySubtype::Generic, + })) + } - impl<'de> Visitor<'de> for RawBsonVisitor { - type Value = RawBson<'de>; + fn visit_i8(self, v: i8) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int32(v.into())) + } - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "a raw BSON reference") - } + fn visit_i16(self, v: i16) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int32(v.into())) + } - fn visit_borrowed_str(self, v: &'de str) -> std::result::Result - where - E: serde::de::Error, - { - Ok(RawBson::String(v)) - } + fn visit_i32(self, v: i32) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int32(v)) + } - fn 
visit_borrowed_bytes( - self, - bytes: &'de [u8], - ) -> std::result::Result - where - E: SerdeError, - { - Ok(RawBson::Binary(RawBinary { - bytes, - subtype: BinarySubtype::Generic, - })) - } + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int64(v)) + } - fn visit_i8(self, v: i8) -> std::result::Result - where - E: SerdeError, - { - Ok(RawBson::Int32(v.into())) - } + fn visit_u8(self, value: u8) -> std::result::Result + where + E: serde::de::Error, + { + crate::de::convert_unsigned_to_signed_raw(value.into()) + } - fn visit_i16(self, v: i16) -> std::result::Result - where - E: SerdeError, - { - Ok(RawBson::Int32(v.into())) - } + fn visit_u16(self, value: u16) -> std::result::Result + where + E: serde::de::Error, + { + crate::de::convert_unsigned_to_signed_raw(value.into()) + } - fn visit_i32(self, v: i32) -> std::result::Result - where - E: serde::de::Error, - { - Ok(RawBson::Int32(v)) - } + fn visit_u32(self, value: u32) -> std::result::Result + where + E: serde::de::Error, + { + crate::de::convert_unsigned_to_signed_raw(value.into()) + } - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - Ok(RawBson::Int64(v)) - } + fn visit_u64(self, value: u64) -> std::result::Result + where + E: serde::de::Error, + { + crate::de::convert_unsigned_to_signed_raw(value) + } - fn visit_u8(self, value: u8) -> std::result::Result - where - E: serde::de::Error, - { - crate::de::convert_unsigned_to_signed_raw(value.into()) - } + fn visit_bool(self, v: bool) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Boolean(v)) + } - fn visit_u16(self, value: u16) -> std::result::Result - where - E: serde::de::Error, - { - crate::de::convert_unsigned_to_signed_raw(value.into()) - } + fn visit_f64(self, v: f64) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Double(v)) + } - fn visit_u32(self, value: u32) -> std::result::Result - where - E: serde::de::Error, - 
{ - crate::de::convert_unsigned_to_signed_raw(value.into()) - } + fn visit_none(self) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Null) + } - fn visit_u64(self, value: u64) -> std::result::Result - where - E: serde::de::Error, - { - crate::de::convert_unsigned_to_signed_raw(value) - } + fn visit_unit(self) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Null) + } - fn visit_bool(self, v: bool) -> std::result::Result - where - E: serde::de::Error, - { - Ok(RawBson::Boolean(v)) - } + fn visit_newtype_struct(self, deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } - fn visit_f64(self, v: f64) -> std::result::Result - where - E: serde::de::Error, - { - Ok(RawBson::Double(v)) + fn visit_map(self, mut map: A) -> std::result::Result + where + A: serde::de::MapAccess<'de>, + { + let k = map + .next_key::<&str>()? + .ok_or_else(|| serde::de::Error::custom("expected a key when deserializing RawBson"))?; + match k { + "$oid" => { + let oid: ObjectId = map.next_value()?; + Ok(RawBson::ObjectId(oid)) } - - fn visit_none(self) -> std::result::Result - where - E: serde::de::Error, - { - Ok(RawBson::Null) + "$symbol" => { + let s: &str = map.next_value()?; + Ok(RawBson::Symbol(s)) } - - fn visit_unit(self) -> std::result::Result - where - E: serde::de::Error, - { - Ok(RawBson::Null) + "$numberDecimalBytes" => { + let bytes = map.next_value::()?; + return Ok(RawBson::Decimal128(Decimal128::deserialize_from_slice( + &bytes, + )?)); } + "$regularExpression" => { + #[derive(Debug, Deserialize)] + struct BorrowedRegexBody<'a> { + pattern: &'a str, - fn visit_newtype_struct( - self, - deserializer: D, - ) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(self) + options: &'a str, + } + let body: BorrowedRegexBody = map.next_value()?; + Ok(RawBson::RegularExpression(RawRegex { + pattern: body.pattern, + options: 
body.options, + })) } + "$undefined" => { + let _: bool = map.next_value()?; + Ok(RawBson::Undefined) + } + "$binary" => { + #[derive(Debug, Deserialize)] + struct BorrowedBinaryBody<'a> { + bytes: &'a [u8], - // use extjson for: ObjectId, datetime, timestamp, symbol, minkey, maxkey - fn visit_map(self, mut map: A) -> std::result::Result - where - A: serde::de::MapAccess<'de>, - { - let k = map.next_key::<&str>()?.ok_or_else(|| { - SerdeError::custom("expected a key when deserializing RawBson") - })?; - match k { - "$oid" => { - let oid: ObjectId = map.next_value()?; - Ok(RawBson::ObjectId(oid)) - } - "$symbol" => { - let s: &str = map.next_value()?; - Ok(RawBson::Symbol(s)) - } - "$numberDecimalBytes" => { - let bytes = map.next_value::()?; - return Ok(RawBson::Decimal128(Decimal128::deserialize_from_slice( - &bytes, - )?)); - } - "$regularExpression" => { - #[derive(Debug, Deserialize)] - struct BorrowedRegexBody<'a> { - pattern: &'a str, - - options: &'a str, - } - let body: BorrowedRegexBody = map.next_value()?; - Ok(RawBson::RegularExpression(RawRegex { - pattern: body.pattern, - options: body.options, - })) - } - "$undefined" => { - let _: bool = map.next_value()?; - Ok(RawBson::Undefined) - } - "$binary" => { - #[derive(Debug, Deserialize)] - struct BorrowedBinaryBody<'a> { - bytes: &'a [u8], - - #[serde(rename = "subType")] - subtype: u8, - } + #[serde(rename = "subType")] + subtype: u8, + } - let v = map.next_value::()?; + let v = map.next_value::()?; - Ok(RawBson::Binary(RawBinary { - bytes: v.bytes, - subtype: v.subtype.into(), - })) - } - "$date" => { - let v = map.next_value::()?; - Ok(RawBson::DateTime(DateTime::from_millis(v))) - } - "$timestamp" => { - let v = map.next_value::()?; - Ok(RawBson::Timestamp(Timestamp { - time: v.t, - increment: v.i, - })) - } - "$minKey" => { - let _ = map.next_value::()?; - Ok(RawBson::MinKey) - } - "$maxKey" => { - let _ = map.next_value::()?; - Ok(RawBson::MaxKey) - } - "$code" => { - let code = 
map.next_value::<&str>()?; - if let Some(key) = map.next_key::<&str>()? { - if key == "$scope" { - let scope = map.next_value::<&RawDocument>()?; - Ok(RawBson::JavaScriptCodeWithScope( - RawJavaScriptCodeWithScope { code, scope }, - )) - } else { - Err(SerdeError::unknown_field(key, &["$scope"])) - } - } else { - Ok(RawBson::JavaScriptCode(code)) - } - } - "$dbPointer" => { - #[derive(Deserialize)] - struct BorrowedDbPointerBody<'a> { - #[serde(rename = "$ref")] - ns: &'a str, - - #[serde(rename = "$id")] - id: ObjectId, - } - - let body: BorrowedDbPointerBody = map.next_value()?; - Ok(RawBson::DbPointer(RawDbPointer { - namespace: body.ns, - id: body.id, - })) - } - RAW_DOCUMENT_NEWTYPE => { - let bson = map.next_value::<&[u8]>()?; - let doc = RawDocument::new(bson).map_err(SerdeError::custom)?; - Ok(RawBson::Document(doc)) - } - RAW_ARRAY_NEWTYPE => { - let bson = map.next_value::<&[u8]>()?; - let doc = RawDocument::new(bson).map_err(SerdeError::custom)?; - Ok(RawBson::Array(RawArray::from_doc(doc))) + Ok(RawBson::Binary(RawBinary { + bytes: v.bytes, + subtype: v.subtype.into(), + })) + } + "$date" => { + let v = map.next_value::()?; + Ok(RawBson::DateTime(DateTime::from_millis(v))) + } + "$timestamp" => { + let v = map.next_value::()?; + Ok(RawBson::Timestamp(Timestamp { + time: v.t, + increment: v.i, + })) + } + "$minKey" => { + let _ = map.next_value::()?; + Ok(RawBson::MinKey) + } + "$maxKey" => { + let _ = map.next_value::()?; + Ok(RawBson::MaxKey) + } + "$code" => { + let code = map.next_value::<&str>()?; + if let Some(key) = map.next_key::<&str>()? 
{ + if key == "$scope" { + let scope = map.next_value::<&RawDocument>()?; + Ok(RawBson::JavaScriptCodeWithScope( + RawJavaScriptCodeWithScope { code, scope }, + )) + } else { + Err(serde::de::Error::unknown_field(key, &["$scope"])) } - k => Err(SerdeError::custom(format!( - "can't deserialize RawBson from map, key={}", - k - ))), + } else { + Ok(RawBson::JavaScriptCode(code)) } } + "$dbPointer" => { + #[derive(Deserialize)] + struct BorrowedDbPointerBody<'a> { + #[serde(rename = "$ref")] + ns: &'a str, + + #[serde(rename = "$id")] + id: ObjectId, + } + + let body: BorrowedDbPointerBody = map.next_value()?; + Ok(RawBson::DbPointer(RawDbPointer { + namespace: body.ns, + id: body.id, + })) + } + RAW_DOCUMENT_NEWTYPE => { + let bson = map.next_value::<&[u8]>()?; + let doc = RawDocument::new(bson).map_err(serde::de::Error::custom)?; + Ok(RawBson::Document(doc)) + } + RAW_ARRAY_NEWTYPE => { + let bson = map.next_value::<&[u8]>()?; + let doc = RawDocument::new(bson).map_err(serde::de::Error::custom)?; + Ok(RawBson::Array(RawArray::from_doc(doc))) + } + k => Err(serde::de::Error::custom(format!( + "can't deserialize RawBson from map, key={}", + k + ))), } + } +} +impl<'de: 'a, 'a> Deserialize<'de> for RawBson<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { deserializer.deserialize_newtype_struct(RAW_BSON_NEWTYPE, RawBsonVisitor) } } diff --git a/src/raw/document.rs b/src/raw/document.rs index d803b17f..20367817 100644 --- a/src/raw/document.rs +++ b/src/raw/document.rs @@ -6,7 +6,8 @@ use std::{ use serde::{ser::SerializeMap, Deserialize, Serialize}; use crate::{ - raw::{error::ErrorKind, RAW_DOCUMENT_NEWTYPE}, + raw::{error::ErrorKind, RawBsonVisitor, RAW_DOCUMENT_NEWTYPE}, + spec::BinarySubtype, DateTime, Timestamp, }; @@ -493,11 +494,19 @@ impl<'de: 'a, 'a> Deserialize<'de> for &'a RawDocument { where D: serde::Deserializer<'de>, { - match RawBson::deserialize(deserializer)? 
{ + match deserializer.deserialize_newtype_struct(RAW_DOCUMENT_NEWTYPE, RawBsonVisitor)? { RawBson::Document(d) => Ok(d), - b => Err(serde::de::Error::custom(format!( + + // For non-BSON formats, RawDocument gets serialized as bytes, so we need to deserialize + // from them here too. For BSON, the deserialzier will return an error if it + // sees the RAW_DOCUMENT_NEWTYPE but the next type isn't a document. + RawBson::Binary(b) if b.subtype == BinarySubtype::Generic => { + RawDocument::new(b.bytes).map_err(serde::de::Error::custom) + } + + o => Err(serde::de::Error::custom(format!( "expected raw document reference, instead got {:?}", - b + o ))), } } diff --git a/src/raw/mod.rs b/src/raw/mod.rs index 48359110..89c16ea1 100644 --- a/src/raw/mod.rs +++ b/src/raw/mod.rs @@ -134,6 +134,8 @@ pub use self::{ iter::Iter, }; +pub(crate) use self::bson::RawBsonVisitor; + /// Special newtype name indicating that the type being (de)serialized is a raw BSON document. pub(crate) const RAW_DOCUMENT_NEWTYPE: &str = "$__private__bson_RawDocument"; From efb8c5a7fe85ac1b887f9222b5f2189abef74108 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Mon, 1 Nov 2021 20:31:27 -0400 Subject: [PATCH 17/21] ensure hint is cleared after use --- serde-tests/test.rs | 32 ++++++++++++++++++++++++++++++++ src/ser/raw/mod.rs | 12 ++++++------ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/serde-tests/test.rs b/serde-tests/test.rs index c3e0d0a4..2db16982 100644 --- a/serde-tests/test.rs +++ b/serde-tests/test.rs @@ -1208,3 +1208,35 @@ fn u2i() { bson::to_document(&v).unwrap_err(); bson::to_vec(&v).unwrap_err(); } + +#[test] +fn hint_cleared() { + #[derive(Debug, Serialize, Deserialize)] + struct Foo<'a> { + #[serde(borrow)] + doc: &'a RawDocument, + #[serde(borrow)] + binary: RawBinary<'a>, + } + + let binary_value = Binary { + bytes: vec![1, 2, 3, 4], + subtype: BinarySubtype::Generic, + }; + + let doc_value = doc! 
{ + "binary": binary_value.clone() + }; + + let bytes = bson::to_vec(&doc_value).unwrap(); + + let doc = RawDocument::new(&bytes).unwrap(); + let binary = doc.get_binary("binary").unwrap(); + + let f = Foo { doc, binary }; + + let serialized_bytes = bson::to_vec(&f).unwrap(); + let round_doc: Document = bson::from_slice(&serialized_bytes).unwrap(); + + assert_eq!(round_doc, doc! { "doc": doc_value, "binary": binary_value }); +} diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index 169848a7..bc178f76 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -28,12 +28,12 @@ pub(crate) struct Serializer { /// but in serde, the serializer learns of the type after serializing the key. type_index: usize, - // /// Whether the binary value about to be serialized is a UUID or not. - // /// This is indicated by serializing a newtype with name UUID_NEWTYPE_NAME; - // is_uuid: bool, + /// Hint provided by the type being serialized. hint: SerializerHint, } +/// Various bits of information that the serialized type can provide to the serializer to +/// inform the purpose of the next serialization step. 
#[derive(Debug, Clone, Copy)] enum SerializerHint { None, @@ -200,7 +200,7 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_bytes(self, v: &[u8]) -> Result { - match self.hint { + match self.hint.take() { SerializerHint::RawDocument => { self.update_element_type(ElementType::EmbeddedDocument)?; self.bytes.write_all(v)?; @@ -209,10 +209,10 @@ impl<'a> serde::Serializer for &'a mut Serializer { self.update_element_type(ElementType::Array)?; self.bytes.write_all(v)?; } - _ => { + hint => { self.update_element_type(ElementType::Binary)?; - let subtype = if matches!(self.hint.take(), SerializerHint::Uuid) { + let subtype = if matches!(hint, SerializerHint::Uuid) { BinarySubtype::Uuid } else { BinarySubtype::Generic From bfb0dd8f17a918e0c32b8f86eaceac48da9f5104 Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Mon, 1 Nov 2021 21:08:13 -0400 Subject: [PATCH 18/21] fix compilation on 1.48 --- src/ser/raw/value_serializer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index 91ea8a09..49880a85 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -470,7 +470,7 @@ impl<'a, 'b> SerializeStruct for &'b mut ValueSerializer<'a> { self.state = SerializationStep::BinaryBytes; value.serialize(&mut **self)?; } - (SerializationStep::BinaryBytes, "base64" | "bytes") => { + (SerializationStep::BinaryBytes, key) if key == "bytes" || key == "base64" => { // state is updated in serialize value.serialize(&mut **self)?; } From b05569bd99a31593f3e7cdf61efc5fc42a8627ed Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Mon, 8 Nov 2021 14:15:05 -0500 Subject: [PATCH 19/21] typo fix --- src/de/raw.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/de/raw.rs b/src/de/raw.rs index 478e00da..48a20de4 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -693,7 +693,7 @@ struct RawDocumentAccess<'d> { /// Whether the first key has 
been deserialized yet or not. deserialized_first: bool, - /// Whether or not this document being deserialized is for anarray or not. + /// Whether or not this document being deserialized is for an array or not. array: bool, } From 6da757495bfafda58ad79b260ffa4732cb1a7e8a Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Tue, 9 Nov 2021 11:52:15 -0500 Subject: [PATCH 20/21] add use --- src/raw/bson.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/raw/bson.rs b/src/raw/bson.rs index b35909ff..9529a1e7 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -5,6 +5,7 @@ use serde_bytes::{ByteBuf, Bytes}; use super::{Error, RawArray, RawDocument, Result}; use crate::{ + de::convert_unsigned_to_signed_raw, extjson, oid::{self, ObjectId}, raw::{RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, @@ -304,28 +305,28 @@ impl<'de> Visitor<'de> for RawBsonVisitor { where E: serde::de::Error, { - crate::de::convert_unsigned_to_signed_raw(value.into()) + convert_unsigned_to_signed_raw(value.into()) } fn visit_u16(self, value: u16) -> std::result::Result where E: serde::de::Error, { - crate::de::convert_unsigned_to_signed_raw(value.into()) + convert_unsigned_to_signed_raw(value.into()) } fn visit_u32(self, value: u32) -> std::result::Result where E: serde::de::Error, { - crate::de::convert_unsigned_to_signed_raw(value.into()) + convert_unsigned_to_signed_raw(value.into()) } fn visit_u64(self, value: u64) -> std::result::Result where E: serde::de::Error, { - crate::de::convert_unsigned_to_signed_raw(value) + convert_unsigned_to_signed_raw(value) } fn visit_bool(self, v: bool) -> std::result::Result From 2e56b8149873060671ced21a1b9f59fa2c6a3e0c Mon Sep 17 00:00:00 2001 From: Patrick Freed Date: Tue, 9 Nov 2021 11:53:00 -0500 Subject: [PATCH 21/21] fix typo --- src/raw/document.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/raw/document.rs b/src/raw/document.rs index 20367817..053b64d0 100644 --- 
a/src/raw/document.rs +++ b/src/raw/document.rs @@ -498,7 +498,7 @@ impl<'de: 'a, 'a> Deserialize<'de> for &'a RawDocument { RawBson::Document(d) => Ok(d), // For non-BSON formats, RawDocument gets serialized as bytes, so we need to deserialize - // from them here too. For BSON, the deserialzier will return an error if it + // from them here too. For BSON, the deserializer will return an error if it // sees the RAW_DOCUMENT_NEWTYPE but the next type isn't a document. RawBson::Binary(b) if b.subtype == BinarySubtype::Generic => { RawDocument::new(b.bytes).map_err(serde::de::Error::custom)