diff --git a/rmp-serde/src/decode.rs b/rmp-serde/src/decode.rs index 4f3bee3..9bd0056 100644 --- a/rmp-serde/src/decode.rs +++ b/rmp-serde/src/decode.rs @@ -64,6 +64,16 @@ macro_rules! depth_count( } ); +/// Inspired by serde. We have our own error type and can save some compile time by avoiding `?`. +macro_rules! tri { + ($expr:expr) => { + match $expr { + Ok(val) => val, + Err(err) => return Err(Error::from(err)), + } + }; +} + impl error::Error for Error { #[cold] fn source(&self) -> Option<&(dyn error::Error + 'static)> { @@ -340,60 +350,169 @@ fn read_128_buf<'de, R: ReadSlice<'de>>(rd: &mut R, len: u8) -> Result buf, Reference::Copied(buf) => buf, }; Ok(i128::from_be_bytes(buf.try_into().map_err(|_| Error::LengthMismatch(16))?)) } -fn read_str_data<'de, V, R>(rd: &mut R, len: u32, visitor: V) -> Result - where V: Visitor<'de>, R: ReadSlice<'de> +fn read_str_len<'de, R>(rd: &mut R, marker: Marker) -> Result + where R: ReadSlice<'de>{ + match marker { + Marker::FixStr(len) => Ok(len.into()), + Marker::Str8 => read_u8(rd).map(u32::from), + Marker::Str16 => read_u16(rd).map(u32::from), + Marker::Str32 => read_u32(rd), + _ => Err(Error::TypeMismatch(Marker::Reserved)), + } +} + +enum StrData<'de, 'r> { + Str(&'r str), + StrError(Utf8Error, &'r [u8]), + BorrowedStr(&'de str), + BorrowedStrError(Utf8Error, &'de [u8]), +} + +fn read_str_data<'de, 'r, R>(rd: &'r mut R, marker: Marker) -> Result, Error> + where R: ReadSlice<'de> { - match read_bin_data(rd, len)? { + let len = tri!(read_str_len(rd, marker)); + match read_bin_content(rd, len)? { Reference::Borrowed(buf) => { match str::from_utf8(buf) { - Ok(s) => visitor.visit_borrowed_str(s), + Ok(s) => Ok(StrData::BorrowedStr(s)), Err(err) => { - // Allow to unpack invalid UTF-8 bytes into a byte array. - match visitor.visit_borrowed_bytes::(buf) { - Ok(buf) => Ok(buf), - Err(..) => Err(Error::Utf8Error(err)), - } + Ok(StrData::BorrowedStrError(err, buf)) } } } Reference::Copied(buf) => { match str::from_utf8(buf) { - Ok(s) => visitor.visit_str(s), + Ok(s) => Ok(StrData::Str(s)), Err(err) => { - // Allow to unpack invalid UTF-8 bytes into a byte array. - match visitor.visit_bytes::(buf) { - Ok(buf) => Ok(buf), - Err(..) => Err(Error::Utf8Error(err)), - } + Ok(StrData::StrError(err, buf)) } } } } } -fn read_bin_data<'a, 'de, R: ReadSlice<'de>>(rd: &'a mut R, len: u32) -> Result, Error> { - rd.read_slice(len as usize).map_err(Error::InvalidDataRead) +fn visit_str_data<'de, V>(visitor: V, data: StrData<'de, '_>) -> Result where V: Visitor<'de> { + match data { + StrData::Str(s) => visitor.visit_str(s), + StrData::StrError(err, buf) => { + // Allow to unpack invalid UTF-8 bytes into a byte array. + match visitor.visit_bytes::(buf) { + Ok(buf) => Ok(buf), + Err(..) => Err(Error::Utf8Error(err)), + } + }, + StrData::BorrowedStr(s) => visitor.visit_borrowed_str(s), + StrData::BorrowedStrError(err, buf) => { + // Allow to unpack invalid UTF-8 bytes into a byte array. + match visitor.visit_borrowed_bytes::(buf) { + Ok(buf) => Ok(buf), + Err(..) => Err(Error::Utf8Error(err)), + } + }, + } +} + +fn read_array_len<'de, R>(rd: &mut R, marker: Marker) -> Result + where R: ReadSlice<'de> { + match marker { + Marker::FixArray(len) => Ok(len.into()), + Marker::Array16 => read_u16(rd).map(u32::from), + Marker::Array32 => read_u32(rd), + _ => Err(Error::TypeMismatch(Marker::Reserved)), + } +} + +fn read_bin_len<'de, R>(rd: &mut R, marker: Marker) -> Result + where R: ReadSlice<'de>{ + match marker { + Marker::Bin8 => read_u8(rd).map(u32::from), + Marker::Bin16 => read_u16(rd).map(u32::from), + Marker::Bin32 => read_u32(rd), + _ => Err(Error::TypeMismatch(Marker::Reserved)), + } +} + +fn read_map_len<'de, R>(rd: &mut R, marker: Marker) -> Result + where R: ReadSlice<'de> { + match marker { + Marker::FixMap(len) => Ok(len.into()), + Marker::Map16 => read_u16(rd).map(u32::from), + Marker::Map32 => read_u32(rd), + _ => Err(Error::TypeMismatch(Marker::Reserved)), + } +} + +#[inline(never)] +fn read_bin_data<'a, 'de, R: ReadSlice<'de>>(rd: &'a mut R, marker: Marker) -> Result, Error> { + let len = tri!(read_bin_len(rd, marker)); + read_bin_content(rd, len) +} + +fn read_bin_content<'a, 'de, R: ReadSlice<'de>>(rd: &'a mut R, len: u32) -> Result, Error> { + match rd.read_slice(len as usize) { + Ok(b) => Ok(b), + Err(e) => Err(Error::InvalidDataRead(e)), + } } fn read_u8(rd: &mut R) -> Result { - byteorder::ReadBytesExt::read_u8(rd).map_err(Error::InvalidDataRead) + match byteorder::ReadBytesExt::read_u8(rd) { + Ok(v) => Ok(v), + Err(e) => Err(Error::InvalidDataRead(e)), + } } fn read_u16(rd: &mut R) -> Result { - rd.read_u16::() - .map_err(Error::InvalidDataRead) + match rd.read_u16::() { + Ok(v) => Ok(v), + Err(e) => Err(Error::InvalidDataRead(e)), + } } fn read_u32(rd: &mut R) -> Result { - rd.read_u32::() - .map_err(Error::InvalidDataRead) + match rd.read_u32::() { + Ok(v) => Ok(v), + Err(e) => Err(Error::InvalidDataRead(e)), + } +} + +enum AnyNumber { + U8(u8), + U16(u16), + U32(u32), + U64(u64), + I8(i8), + I16(i16), + I32(i32), + I64(i64), + F32(f32), + F64(f64), +} + +#[inline(never)] +fn read_any_number<'de, R>(rd: &mut R, marker: Marker) -> Result + where R: ReadSlice<'de> { + match marker { + Marker::U8 => Ok(AnyNumber::U8(rd.read_data_u8()?)), + Marker::U16 => Ok(AnyNumber::U16(rd.read_data_u16()?)), + Marker::U32 => Ok(AnyNumber::U32(rd.read_data_u32()?)), + Marker::U64 => Ok(AnyNumber::U64(rd.read_data_u64()?)), + Marker::I8 => Ok(AnyNumber::I8(rd.read_data_i8()?)), + Marker::I16 => Ok(AnyNumber::I16(rd.read_data_i16()?)), + Marker::I32 => Ok(AnyNumber::I32(rd.read_data_i32()?)), + Marker::I64 => Ok(AnyNumber::I64(rd.read_data_i64()?)), + Marker::F32 => Ok(AnyNumber::F32(rd.read_data_f32()?)), + Marker::F64 => Ok(AnyNumber::F64(rd.read_data_f64()?)), + other_marker => Err(Error::TypeMismatch(other_marker)), + } } fn ext_len(rd: &mut R, marker: Marker) -> Result { @@ -511,23 +630,36 @@ fn any_num<'de, R: ReadSlice<'de>, V: Visitor<'de>>(rd: &mut R, visitor: V, mark Marker::False => visitor.visit_bool(marker == Marker::True), Marker::FixPos(val) => visitor.visit_u8(val), Marker::FixNeg(val) => visitor.visit_i8(val), - Marker::U8 => visitor.visit_u8(rd.read_data_u8()?), - Marker::U16 => visitor.visit_u16(rd.read_data_u16()?), - Marker::U32 => visitor.visit_u32(rd.read_data_u32()?), - Marker::U64 => visitor.visit_u64(rd.read_data_u64()?), - Marker::I8 => visitor.visit_i8(rd.read_data_i8()?), - Marker::I16 => visitor.visit_i16(rd.read_data_i16()?), - Marker::I32 => visitor.visit_i32(rd.read_data_i32()?), - Marker::I64 => visitor.visit_i64(rd.read_data_i64()?), - Marker::F32 => visitor.visit_f32(rd.read_data_f32()?), - Marker::F64 => visitor.visit_f64(rd.read_data_f64()?), + Marker::U8 | + Marker::U16 | + Marker::U32 | + Marker::U64 | + Marker::I8 | + Marker::I16 | + Marker::I32 | + Marker::I64 | + Marker::F32 | + Marker::F64 => { + match tri!(read_any_number(rd, marker)) { + AnyNumber::U8(n) => visitor.visit_u8(n), + AnyNumber::U16(n) => visitor.visit_u16(n), + AnyNumber::U32(n) => visitor.visit_u32(n), + AnyNumber::U64(n) => visitor.visit_u64(n), + AnyNumber::I8(n) => visitor.visit_i8(n), + AnyNumber::I16(n) => visitor.visit_i16(n), + AnyNumber::I32(n) => visitor.visit_i32(n), + AnyNumber::I64(n) => visitor.visit_i64(n), + AnyNumber::F32(n) => visitor.visit_f32(n), + AnyNumber::F64(n) => visitor.visit_f64(n), + } + } other_marker => Err(Error::TypeMismatch(other_marker)), } } impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer { fn any_inner>(&mut self, visitor: V, allow_bytes: bool) -> Result { - let marker = self.take_or_read_marker()?; + let marker = tri!(self.take_or_read_marker()); match marker { Marker::Null | Marker::True | @@ -545,28 +677,16 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer { Marker::F32 | Marker::F64 => any_num(&mut self.rd, visitor, marker), Marker::FixStr(_) | Marker::Str8 | Marker::Str16 | Marker::Str32 => { - let len = match marker { - Marker::FixStr(len) => Ok(len.into()), - Marker::Str8 => read_u8(&mut self.rd).map(u32::from), - Marker::Str16 => read_u16(&mut self.rd).map(u32::from), - Marker::Str32 => read_u32(&mut self.rd), - _ => return Err(Error::TypeMismatch(Marker::Reserved)), - }?; - read_str_data(&mut self.rd, len, visitor) + let data = tri!(read_str_data(&mut self.rd, marker)); + visit_str_data(visitor, data) } Marker::FixArray(_) | Marker::Array16 | Marker::Array32 => { - let len = match marker { - Marker::FixArray(len) => len.into(), - Marker::Array16 => read_u16(&mut self.rd)?.into(), - Marker::Array32 => read_u32(&mut self.rd)?, - _ => return Err(Error::TypeMismatch(Marker::Reserved)), - }; - + let len = tri!(read_array_len(&mut self.rd, marker)); depth_count!(self.depth, { let mut seq = SeqAccess::new(self, len); - let res = visitor.visit_seq(&mut seq)?; + let res = tri!(visitor.visit_seq(&mut seq)); match seq.left { 0 => Ok(res), excess => Err(Error::LengthMismatch(len - excess)), @@ -576,16 +696,10 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer { Marker::FixMap(_) | Marker::Map16 | Marker::Map32 => { - let len = match marker { - Marker::FixMap(len) => len.into(), - Marker::Map16 => read_u16(&mut self.rd)?.into(), - Marker::Map32 => read_u32(&mut self.rd)?, - _ => return Err(Error::TypeMismatch(Marker::Reserved)), - }; - + let len = tri!(read_map_len(&mut self.rd, marker)); depth_count!(self.depth, { let mut seq = MapAccess::new(self, len); - let res = visitor.visit_map(&mut seq)?; + let res = tri!(visitor.visit_map(&mut seq)); match seq.left { 0 => Ok(res), excess => Err(Error::LengthMismatch(len - excess)), @@ -593,13 +707,7 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer { }) } Marker::Bin8 | Marker::Bin16 | Marker::Bin32 => { - let len = match marker { - Marker::Bin8 => read_u8(&mut self.rd).map(u32::from), - Marker::Bin16 => read_u16(&mut self.rd).map(u32::from), - Marker::Bin32 => read_u32(&mut self.rd), - _ => return Err(Error::TypeMismatch(Marker::Reserved)), - }?; - match read_bin_data(&mut self.rd, len)? { + match tri!(read_bin_data(&mut self.rd, marker)) { Reference::Borrowed(buf) if allow_bytes => visitor.visit_borrowed_bytes(buf), Reference::Copied(buf) if allow_bytes => visitor.visit_bytes(buf), Reference::Borrowed(buf) | Reference::Copied(buf) => { @@ -615,7 +723,7 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer { Marker::Ext8 | Marker::Ext16 | Marker::Ext32 => { - let len = ext_len(&mut self.rd, marker)?; + let len = tri!(ext_len(&mut self.rd, marker)); depth_count!(self.depth, visitor.visit_newtype_struct(ExtDeserializer::new(self, len))) } Marker::Reserved => Err(Error::TypeMismatch(Marker::Reserved)), @@ -833,7 +941,7 @@ impl<'de, 'a, R: ReadSlice<'de> + 'a, C: SerializerConfig> de::SeqAccess<'de> fo { if self.left > 0 { self.left -= 1; - Ok(Some(seed.deserialize(&mut *self.de)?)) + Ok(Some(tri!(seed.deserialize(&mut *self.de)))) } else { Ok(None) } @@ -906,7 +1014,7 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> de::EnumAccess<'de> where V: de::DeserializeSeed<'de>, { - let variant = seed.deserialize(&mut *self.de)?; + let variant = tri!(seed.deserialize(&mut *self.de)); Ok((variant, self)) } } @@ -973,7 +1081,7 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> de::EnumAccess<'de> for Varian fn variant_seed(self, seed: V) -> Result<(V::Value, Self), Error> where V: de::DeserializeSeed<'de>, { - Ok((seed.deserialize(&mut *self.de)?, self)) + Ok((tri!(seed.deserialize(&mut *self.de)), self)) } } @@ -982,7 +1090,7 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> de::VariantAccess<'de> for Var #[inline] fn unit_variant(self) -> Result<(), Error> { - decode::read_nil(&mut self.de.rd)?; + tri!(decode::read_nil(&mut self.de.rd)); Ok(()) } diff --git a/rmp/src/decode/mod.rs b/rmp/src/decode/mod.rs index 8b5d813..27f2dbd 100644 --- a/rmp/src/decode/mod.rs +++ b/rmp/src/decode/mod.rs @@ -109,7 +109,10 @@ pub trait RmpRead: sealed::Sealed { #[inline] #[doc(hidden)] fn read_data_u8(&mut self) -> Result> { - self.read_u8().map_err(ValueReadError::InvalidDataRead) + match self.read_u8() { + Ok(v) => Ok(v), + Err(e) => Err(ValueReadError::InvalidDataRead(e)), + } } /// Read a single (signed) byte from this stream. #[inline]