diff --git a/parquet/src/parquet_thrift.rs b/parquet/src/parquet_thrift.rs index 221532ea8332..8ee018ef95db 100644 --- a/parquet/src/parquet_thrift.rs +++ b/parquet/src/parquet_thrift.rs @@ -35,6 +35,66 @@ use crate::{ errors::{ParquetError, Result}, write_thrift_field, }; +use std::io::Error; +use std::str::Utf8Error; + +#[derive(Debug)] +pub(crate) enum ThriftProtocolError { + Eof, + IO(Error), + InvalidFieldType(u8), + InvalidElementType(u8), + FieldDeltaOverflow { field_delta: u8, last_field_id: i16 }, + InvalidBoolean(u8), + Utf8Error, + SkipDepth(FieldType), + SkipUnsupportedType(FieldType), +} + +impl From for ParquetError { + #[inline(never)] + fn from(e: ThriftProtocolError) -> Self { + match e { + ThriftProtocolError::Eof => eof_err!("Unexpected EOF"), + ThriftProtocolError::IO(e) => e.into(), + ThriftProtocolError::InvalidFieldType(value) => { + general_err!("Unexpected struct field type {}", value) + } + ThriftProtocolError::InvalidElementType(value) => { + general_err!("Unexpected list/set element type{}", value) + } + ThriftProtocolError::FieldDeltaOverflow { + field_delta, + last_field_id, + } => general_err!("cannot add {} to {}", field_delta, last_field_id), + ThriftProtocolError::InvalidBoolean(value) => { + general_err!("cannot convert {} into bool", value) + } + ThriftProtocolError::Utf8Error => general_err!("invalid utf8"), + ThriftProtocolError::SkipDepth(field_type) => { + general_err!("cannot parse past {:?}", field_type) + } + ThriftProtocolError::SkipUnsupportedType(field_type) => { + general_err!("cannot skip field type {:?}", field_type) + } + } + } +} + +impl From for ThriftProtocolError { + fn from(_: Utf8Error) -> Self { + // ignore error payload to reduce the size of ThriftProtocolError + Self::Utf8Error + } +} + +impl From for ThriftProtocolError { + fn from(e: Error) -> Self { + Self::IO(e) + } +} + +pub type ThriftProtocolResult = Result; /// Wrapper for thrift `double` fields. This is used to provide /// an implementation of `Eq` for floats. This implementation @@ -87,8 +147,8 @@ pub(crate) enum FieldType { } impl TryFrom for FieldType { - type Error = ParquetError; - fn try_from(value: u8) -> Result { + type Error = ThriftProtocolError; + fn try_from(value: u8) -> ThriftProtocolResult { match value { 0 => Ok(Self::Stop), 1 => Ok(Self::BooleanTrue), @@ -103,13 +163,13 @@ impl TryFrom for FieldType { 10 => Ok(Self::Set), 11 => Ok(Self::Map), 12 => Ok(Self::Struct), - _ => Err(general_err!("Unexpected struct field type{}", value)), + _ => Err(ThriftProtocolError::InvalidFieldType(value)), } } } impl TryFrom for FieldType { - type Error = ParquetError; + type Error = ThriftProtocolError; fn try_from(value: ElementType) -> std::result::Result { match value { ElementType::Bool => Ok(Self::BooleanTrue), @@ -121,7 +181,7 @@ impl TryFrom for FieldType { ElementType::Binary => Ok(Self::Binary), ElementType::List => Ok(Self::List), ElementType::Struct => Ok(Self::Struct), - _ => Err(general_err!("Unexpected list element type{:?}", value)), + _ => Err(ThriftProtocolError::InvalidFieldType(value as u8)), } } } @@ -143,8 +203,8 @@ pub(crate) enum ElementType { } impl TryFrom for ElementType { - type Error = ParquetError; - fn try_from(value: u8) -> Result { + type Error = ThriftProtocolError; + fn try_from(value: u8) -> ThriftProtocolResult { match value { // For historical and compatibility reasons, a reader should be capable to deal with both cases. // The only valid value in the original spec was 2, but due to an widespread implementation bug @@ -162,7 +222,7 @@ impl TryFrom for ElementType { 10 => Ok(Self::Set), 11 => Ok(Self::Map), 12 => Ok(Self::Struct), - _ => Err(general_err!("Unexpected list/set element type{}", value)), + _ => Err(ThriftProtocolError::InvalidElementType(value)), } } } @@ -202,20 +262,20 @@ pub(crate) struct ListIdentifier { /// [compact]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md pub(crate) trait ThriftCompactInputProtocol<'a> { /// Read a single byte from the input. - fn read_byte(&mut self) -> Result; + fn read_byte(&mut self) -> ThriftProtocolResult; /// Read a Thrift encoded [binary] from the input. /// /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding - fn read_bytes(&mut self) -> Result<&'a [u8]>; + fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]>; - fn read_bytes_owned(&mut self) -> Result>; + fn read_bytes_owned(&mut self) -> ThriftProtocolResult>; /// Skip the next `n` bytes of input. - fn skip_bytes(&mut self, n: usize) -> Result<()>; + fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()>; /// Read a ULEB128 encoded unsigned varint from the input. - fn read_vlq(&mut self) -> Result { + fn read_vlq(&mut self) -> ThriftProtocolResult { let mut in_progress = 0; let mut shift = 0; loop { @@ -229,13 +289,13 @@ pub(crate) trait ThriftCompactInputProtocol<'a> { } /// Read a zig-zag encoded signed varint from the input. - fn read_zig_zag(&mut self) -> Result { + fn read_zig_zag(&mut self) -> ThriftProtocolResult { let val = self.read_vlq()?; Ok((val >> 1) as i64 ^ -((val & 1) as i64)) } /// Read the [`ListIdentifier`] for a Thrift encoded list. - fn read_list_begin(&mut self) -> Result { + fn read_list_begin(&mut self) -> ThriftProtocolResult { let header = self.read_byte()?; let element_type = ElementType::try_from(header & 0x0f)?; @@ -253,8 +313,16 @@ pub(crate) trait ThriftCompactInputProtocol<'a> { }) } + // Full field ids are uncommon. + // Not inlining this method reduces the code size of `read_field_begin`, which then ideally gets + // inlined everywhere. + #[cold] + fn read_full_field_id(&mut self) -> ThriftProtocolResult { + self.read_i16() + } + /// Read the [`FieldIdentifier`] for a field in a Thrift encoded struct. - fn read_field_begin(&mut self, last_field_id: i16) -> Result { + fn read_field_begin(&mut self, last_field_id: i16) -> ThriftProtocolResult { // we can read at least one byte, which is: // - the type // - the field delta and the type @@ -277,17 +345,14 @@ pub(crate) trait ThriftCompactInputProtocol<'a> { bool_val = Some(true); } let field_id = if field_delta != 0 { - last_field_id.checked_add(field_delta as i16).map_or_else( - || { - Err(general_err!(format!( - "cannot add {} to {}", - field_delta, last_field_id - ))) + last_field_id.checked_add(field_delta as i16).ok_or( + ThriftProtocolError::FieldDeltaOverflow { + field_delta, + last_field_id, }, - Ok, )? } else { - self.read_i16()? + self.read_full_field_id()? }; Ok(FieldIdentifier { @@ -305,7 +370,7 @@ pub(crate) trait ThriftCompactInputProtocol<'a> { /// This also skips validation of the field type. /// /// Returns a tuple of `(field_type, field_delta)`. - fn read_field_header(&mut self) -> Result<(u8, u8)> { + fn read_field_header(&mut self) -> ThriftProtocolResult<(u8, u8)> { let field_type = self.read_byte()?; let field_delta = (field_type & 0xf0) >> 4; let field_type = field_type & 0xf; @@ -314,7 +379,7 @@ pub(crate) trait ThriftCompactInputProtocol<'a> { /// Read a boolean list element. This should not be used for struct fields. For the latter, /// use the [`FieldIdentifier::bool_val`] field. - fn read_bool(&mut self) -> Result { + fn read_bool(&mut self) -> ThriftProtocolResult { let b = self.read_byte()?; // Previous versions of the thrift specification said to use 0 and 1 inside collections, // but that differed from existing implementations. @@ -323,43 +388,43 @@ pub(crate) trait ThriftCompactInputProtocol<'a> { match b { 0x01 => Ok(true), 0x00 | 0x02 => Ok(false), - unkn => Err(general_err!(format!("cannot convert {unkn} into bool"))), + _ => Err(ThriftProtocolError::InvalidBoolean(b)), } } /// Read a Thrift [binary] as a UTF-8 encoded string. /// /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding - fn read_string(&mut self) -> Result<&'a str> { + fn read_string(&mut self) -> ThriftProtocolResult<&'a str> { let slice = self.read_bytes()?; Ok(std::str::from_utf8(slice)?) } /// Read an `i8`. - fn read_i8(&mut self) -> Result { + fn read_i8(&mut self) -> ThriftProtocolResult { Ok(self.read_byte()? as _) } /// Read an `i16`. - fn read_i16(&mut self) -> Result { + fn read_i16(&mut self) -> ThriftProtocolResult { Ok(self.read_zig_zag()? as _) } /// Read an `i32`. - fn read_i32(&mut self) -> Result { + fn read_i32(&mut self) -> ThriftProtocolResult { Ok(self.read_zig_zag()? as _) } /// Read an `i64`. - fn read_i64(&mut self) -> Result { + fn read_i64(&mut self) -> ThriftProtocolResult { self.read_zig_zag() } /// Read a Thrift `double` as `f64`. - fn read_double(&mut self) -> Result; + fn read_double(&mut self) -> ThriftProtocolResult; /// Skip a ULEB128 encoded varint. - fn skip_vlq(&mut self) -> Result<()> { + fn skip_vlq(&mut self) -> ThriftProtocolResult<()> { loop { let byte = self.read_byte()?; if byte & 0x80 == 0 { @@ -371,14 +436,14 @@ pub(crate) trait ThriftCompactInputProtocol<'a> { /// Skip a thrift [binary]. /// /// [binary]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding - fn skip_binary(&mut self) -> Result<()> { + fn skip_binary(&mut self) -> ThriftProtocolResult<()> { let len = self.read_vlq()? as usize; self.skip_bytes(len) } /// Skip a field with type `field_type` recursively until the default /// maximum skip depth (currently 64) is reached. - fn skip(&mut self, field_type: FieldType) -> Result<()> { + fn skip(&mut self, field_type: FieldType) -> ThriftProtocolResult<()> { const DEFAULT_SKIP_DEPTH: i8 = 64; self.skip_till_depth(field_type, DEFAULT_SKIP_DEPTH) } @@ -396,9 +461,9 @@ pub(crate) trait ThriftCompactInputProtocol<'a> { } /// Skip a field with type `field_type` recursively up to `depth` levels. - fn skip_till_depth(&mut self, field_type: FieldType, depth: i8) -> Result<()> { + fn skip_till_depth(&mut self, field_type: FieldType, depth: i8) -> ThriftProtocolResult<()> { if depth == 0 { - return Err(general_err!(format!("cannot parse past {:?}", field_type))); + return Err(ThriftProtocolError::SkipDepth(field_type)); } match field_type { @@ -431,7 +496,7 @@ pub(crate) trait ThriftCompactInputProtocol<'a> { Ok(()) } // no list or map types in parquet format - u => Err(general_err!(format!("cannot skip field type {:?}", &u))), + _ => Err(ThriftProtocolError::SkipUnsupportedType(field_type)), } } } @@ -455,44 +520,40 @@ impl<'a> ThriftSliceInputProtocol<'a> { impl<'b, 'a: 'b> ThriftCompactInputProtocol<'b> for ThriftSliceInputProtocol<'a> { #[inline] - fn read_byte(&mut self) -> Result { - let ret = *self.buf.first().ok_or_else(eof_error)?; + fn read_byte(&mut self) -> ThriftProtocolResult { + let ret = *self.buf.first().ok_or(ThriftProtocolError::Eof)?; self.buf = &self.buf[1..]; Ok(ret) } - fn read_bytes(&mut self) -> Result<&'b [u8]> { + fn read_bytes(&mut self) -> ThriftProtocolResult<&'b [u8]> { let len = self.read_vlq()? as usize; - let ret = self.buf.get(..len).ok_or_else(eof_error)?; + let ret = self.buf.get(..len).ok_or(ThriftProtocolError::Eof)?; self.buf = &self.buf[len..]; Ok(ret) } - fn read_bytes_owned(&mut self) -> Result> { + fn read_bytes_owned(&mut self) -> ThriftProtocolResult> { Ok(self.read_bytes()?.to_vec()) } #[inline] - fn skip_bytes(&mut self, n: usize) -> Result<()> { - self.buf.get(..n).ok_or_else(eof_error)?; + fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> { + self.buf.get(..n).ok_or(ThriftProtocolError::Eof)?; self.buf = &self.buf[n..]; Ok(()) } - fn read_double(&mut self) -> Result { - let slice = self.buf.get(..8).ok_or_else(eof_error)?; + fn read_double(&mut self) -> ThriftProtocolResult { + let slice = self.buf.get(..8).ok_or(ThriftProtocolError::Eof)?; self.buf = &self.buf[8..]; match slice.try_into() { Ok(slice) => Ok(f64::from_le_bytes(slice)), - Err(_) => Err(general_err!("Unexpected error converting slice")), + Err(_) => unreachable!(), } } } -fn eof_error() -> ParquetError { - eof_err!("Unexpected EOF") -} - /// A Thrift input protocol that wraps a [`Read`] object. /// /// Note that this is only intended for use in reading Parquet page headers. This will panic @@ -509,24 +570,24 @@ impl ThriftReadInputProtocol { impl<'a, R: Read> ThriftCompactInputProtocol<'a> for ThriftReadInputProtocol { #[inline] - fn read_byte(&mut self) -> Result { + fn read_byte(&mut self) -> ThriftProtocolResult { let mut buf = [0_u8; 1]; self.reader.read_exact(&mut buf)?; Ok(buf[0]) } - fn read_bytes(&mut self) -> Result<&'a [u8]> { + fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]> { unimplemented!() } - fn read_bytes_owned(&mut self) -> Result> { + fn read_bytes_owned(&mut self) -> ThriftProtocolResult> { let len = self.read_vlq()? as usize; let mut v = Vec::with_capacity(len); std::io::copy(&mut self.reader.by_ref().take(len as u64), &mut v)?; Ok(v) } - fn skip_bytes(&mut self, n: usize) -> Result<()> { + fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> { std::io::copy( &mut self.reader.by_ref().take(n as u64), &mut std::io::sink(), @@ -534,7 +595,7 @@ impl<'a, R: Read> ThriftCompactInputProtocol<'a> for ThriftReadInputProtocol Ok(()) } - fn read_double(&mut self) -> Result { + fn read_double(&mut self) -> ThriftProtocolResult { let mut buf = [0_u8; 8]; self.reader.read_exact(&mut buf)?; Ok(f64::from_le_bytes(buf)) @@ -552,31 +613,31 @@ pub(crate) trait ReadThrift<'a, R: ThriftCompactInputProtocol<'a>> { impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for bool { fn read_thrift(prot: &mut R) -> Result { - prot.read_bool() + Ok(prot.read_bool()?) } } impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i8 { fn read_thrift(prot: &mut R) -> Result { - prot.read_i8() + Ok(prot.read_i8()?) } } impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i16 { fn read_thrift(prot: &mut R) -> Result { - prot.read_i16() + Ok(prot.read_i16()?) } } impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i32 { fn read_thrift(prot: &mut R) -> Result { - prot.read_i32() + Ok(prot.read_i32()?) } } impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i64 { fn read_thrift(prot: &mut R) -> Result { - prot.read_i64() + Ok(prot.read_i64()?) } } @@ -588,7 +649,7 @@ impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for OrderedF64 { impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a str { fn read_thrift(prot: &mut R) -> Result { - prot.read_string() + Ok(prot.read_string()?) } } @@ -600,7 +661,7 @@ impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for String { impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a [u8] { fn read_thrift(prot: &mut R) -> Result { - prot.read_bytes() + Ok(prot.read_bytes()?) } }