From 298946e2fc9caec4f1280917c73bc9ad4053b50c Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Wed, 6 Mar 2024 16:06:04 +0800 Subject: [PATCH 1/6] feat: Implement binding expression --- Cargo.toml | 1 + crates/iceberg/Cargo.toml | 1 + crates/iceberg/src/expr/mod.rs | 10 ++ crates/iceberg/src/expr/predicate.rs | 231 ++++++++++++++++++++++++++- crates/iceberg/src/expr/term.rs | 124 +++++++++++++- crates/iceberg/src/spec/datatypes.rs | 9 ++ crates/iceberg/src/spec/schema.rs | 52 ++++++ crates/iceberg/src/spec/values.rs | 15 ++ 8 files changed, 439 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dccc6bdf19..697317c1d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ chrono = "0.4" derive_builder = "0.20.0" either = "1" env_logger = "0.11.0" +fnv = "1" futures = "0.3" iceberg = { version = "0.2.0", path = "./crates/iceberg" } iceberg-catalog-rest = { version = "0.2.0", path = "./crates/catalog/rest" } diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml index 32288ee815..ee6b2ff4f5 100644 --- a/crates/iceberg/Cargo.toml +++ b/crates/iceberg/Cargo.toml @@ -40,6 +40,7 @@ bitvec = { workspace = true } chrono = { workspace = true } derive_builder = { workspace = true } either = { workspace = true } +fnv = { workspace = true } futures = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs index c08c836c32..9b5282ba79 100644 --- a/crates/iceberg/src/expr/mod.rs +++ b/crates/iceberg/src/expr/mod.rs @@ -23,6 +23,8 @@ use std::fmt::{Display, Formatter}; pub use term::*; mod predicate; + +use crate::spec::SchemaRef; pub use predicate::*; /// Predicate operators used in expressions. @@ -147,6 +149,14 @@ impl PredicateOperator { } } +/// Bind expression to a schema. +pub trait Bind { + /// The type of the bounded result. + type Bound; + /// Bind an expression to a schema. + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> crate::Result; +} + #[cfg(test)] mod tests { use crate::expr::PredicateOperator; diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 66a395624b..45be5d6c43 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -23,7 +23,14 @@ use crate::expr::{BoundReference, PredicateOperator, Reference}; use crate::spec::Datum; use itertools::Itertools; use std::collections::HashSet; +use crate::error::Result; +use crate::expr::{Bind, BoundReference, PredicateOperator, Reference}; +use crate::spec::{Datum, SchemaRef}; +use crate::{Error, ErrorKind}; +use fnv::FnvHashSet; + use std::fmt::{Debug, Display, Formatter}; +use std::mem::MaybeUninit; use std::ops::Not; /// Logical expression, such as `AND`, `OR`, `NOT`. @@ -55,6 +62,29 @@ impl LogicalExpression { } } +impl Bind for LogicalExpression { + type Bound = LogicalExpression; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + let mut bound_inputs = MaybeUninit::<[Box; N]>::uninit(); + for (i, input) in self.inputs.into_iter().enumerate() { + let input = input.bind(schema.clone(), case_sensitive)?; + // SAFETY: The pointer is valid from [`MaybeUninit`]. + unsafe { + bound_inputs + .as_mut_ptr() + .cast::>() + .add(i) + .write(Box::new(input)); + } + } + + // SAFETY: We have initialized all elements of the array. + let bound_inputs = unsafe { bound_inputs.assume_init() }; + Ok(LogicalExpression::new(bound_inputs)) + } +} + /// Unary predicate, for example, `a IS NULL`. #[derive(PartialEq)] pub struct UnaryExpression { @@ -79,6 +109,15 @@ impl Display for UnaryExpression { } } +impl Bind for UnaryExpression { + type Bound = UnaryExpression; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + let bound_term = self.term.bind(schema, case_sensitive)?; + Ok(UnaryExpression::new(self.op, bound_term)) + } +} + impl UnaryExpression { pub(crate) fn new(op: PredicateOperator, term: T) -> Self { debug_assert!(op.is_unary()); @@ -120,6 +159,15 @@ impl Display for BinaryExpression { } } +impl Bind for BinaryExpression { + type Bound = BinaryExpression; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + let bound_term = self.term.bind(schema.clone(), case_sensitive)?; + Ok(BinaryExpression::new(self.op, bound_term, self.literal)) + } +} + /// Set predicates, for example, `a in (1, 2, 3)`. #[derive(PartialEq)] pub struct SetExpression { @@ -128,7 +176,7 @@ pub struct SetExpression { /// Term of this predicate, for example, `a` in `a in (1, 2, 3)`. term: T, /// Literals of this predicate, for example, `(1, 2, 3)` in `a in (1, 2, 3)`. - literals: HashSet, + literals: FnvHashSet, } impl Debug for SetExpression { @@ -141,12 +189,22 @@ impl Debug for SetExpression { } } -impl SetExpression { - pub(crate) fn new(op: PredicateOperator, term: T, literals: HashSet) -> Self { +impl SetExpression { + pub(crate) fn new(op: PredicateOperator, term: T, literals: FnvHashSet) -> Self { + debug_assert!(op.is_set()); Self { op, term, literals } } } +impl Bind for SetExpression { + type Bound = SetExpression; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + let bound_term = self.term.bind(schema.clone(), case_sensitive)?; + Ok(SetExpression::new(self.op, bound_term, self.literals)) + } +} + impl Display for SetExpression { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut literal_strs = self.literals.iter().map(|l| format!("{}", l)); @@ -172,6 +230,146 @@ pub enum Predicate { Set(SetExpression), } +impl Bind for Predicate { + type Bound = BoundPredicate; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + match self { + Predicate::And(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + + let [left, right] = bound_expr.inputs; + Ok(match (left, right) { + (_, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => { + BoundPredicate::AlwaysFalse + } + (l, _) if matches!(&*l, &BoundPredicate::AlwaysFalse) => { + BoundPredicate::AlwaysFalse + } + (left, r) if matches!(&*r, &BoundPredicate::AlwaysTrue) => *left, + (l, right) if matches!(&*l, &BoundPredicate::AlwaysTrue) => *right, + (left, right) => BoundPredicate::And(LogicalExpression::new([left, right])), + }) + } + Predicate::Not(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let [inner] = bound_expr.inputs; + Ok(match inner { + e if matches!(&*e, &BoundPredicate::AlwaysTrue) => BoundPredicate::AlwaysFalse, + e if matches!(&*e, &BoundPredicate::AlwaysFalse) => BoundPredicate::AlwaysTrue, + e => BoundPredicate::Not(LogicalExpression::new([e])), + }) + } + Predicate::Or(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let [left, right] = bound_expr.inputs; + Ok(match (left, right) { + (_, r) if matches!(&*r, &BoundPredicate::AlwaysTrue) => { + BoundPredicate::AlwaysTrue + } + (l, _) if matches!(&*l, &BoundPredicate::AlwaysTrue) => { + BoundPredicate::AlwaysTrue + } + (left, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => *left, + (l, right) if matches!(&*l, &BoundPredicate::AlwaysFalse) => *right, + (left, right) => BoundPredicate::Or(LogicalExpression::new([left, right])), + }) + } + Predicate::Unary(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + + match &bound_expr.op { + &PredicateOperator::IsNull => { + if bound_expr.term.field().required { + return Ok(BoundPredicate::AlwaysFalse); + } + } + &PredicateOperator::NotNull => { + if bound_expr.term.field().required { + return Ok(BoundPredicate::AlwaysTrue); + } + } + &PredicateOperator::IsNan | &PredicateOperator::NotNan => { + if !bound_expr.term.field().field_type.is_floating_type() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Expecting floating point type, but found {}", + bound_expr.term.field().field_type + ), + )); + } + } + op => { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Expecting unary operator,but found {op}"), + )) + } + } + + Ok(BoundPredicate::Unary(bound_expr)) + } + Predicate::Binary(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let bound_literal = bound_expr.literal.to(&bound_expr.term.field().field_type)?; + Ok(BoundPredicate::Binary(BinaryExpression::new( + bound_expr.op, + bound_expr.term, + bound_literal, + ))) + } + Predicate::Set(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let bound_literals = bound_expr + .literals + .into_iter() + .map(|l| l.to(&bound_expr.term.field().field_type)) + .collect::>>()?; + + match &bound_expr.op { + &PredicateOperator::In => { + if bound_literals.is_empty() { + return Ok(BoundPredicate::AlwaysFalse); + } + if bound_literals.len() == 1 { + return Ok(BoundPredicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + bound_expr.term, + bound_literals.into_iter().next().unwrap(), + ))); + } + } + &PredicateOperator::NotIn => { + if bound_literals.is_empty() { + return Ok(BoundPredicate::AlwaysTrue); + } + if bound_literals.len() == 1 { + return Ok(BoundPredicate::Binary(BinaryExpression::new( + PredicateOperator::NotEq, + bound_expr.term, + bound_literals.into_iter().next().unwrap(), + ))); + } + } + op => { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Expecting unary operator,but found {op}"), + )) + } + } + + Ok(BoundPredicate::Set(SetExpression::new( + bound_expr.op, + bound_expr.term, + bound_literals, + ))) + } + } + } +} + impl Display for Predicate { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -415,4 +613,31 @@ mod tests { assert_eq!(result, expected); } + + use crate::expr::{Bind, Reference}; + use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; + use std::sync::Arc; + + fn table_schema_simple() -> SchemaRef { + Arc::new( + Schema::builder() + .with_schema_id(1) + .with_identifier_field_ids(vec![2]) + .with_fields(vec![ + NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(), + ]) + .build() + .unwrap(), + ) + } + + #[test] + fn test_bind_is_null() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "foo IS NULL"); + } } diff --git a/crates/iceberg/src/expr/term.rs b/crates/iceberg/src/expr/term.rs index 6be502fff1..522c347156 100644 --- a/crates/iceberg/src/expr/term.rs +++ b/crates/iceberg/src/expr/term.rs @@ -17,6 +17,10 @@ //! Term definition. +use crate::expr::{BinaryExpression, Predicate, PredicateOperator, SetExpression}; +use crate::expr::{Bind, UnaryExpression}; +use crate::spec::{Datum, NestedField, NestedFieldRef, SchemaRef}; +use crate::{Error, ErrorKind}; use crate::expr::{BinaryExpression, Predicate, PredicateOperator, SetExpression, UnaryExpression}; use crate::spec::{Datum, NestedField, NestedFieldRef}; use std::collections::HashSet; @@ -65,6 +69,41 @@ impl Reference { )) } + /// Creates an is null expression. For example, `a IS NULL`. + /// + /// # Example + /// + /// ```rust + /// + /// use iceberg::expr::Reference; + /// let expr = Reference::new("a").is_null(); + /// + /// assert_eq!(&format!("{expr}"), "a IS NULL"); + /// ``` + pub fn is_null(self) -> Predicate { + Predicate::Unary(UnaryExpression::new(PredicateOperator::IsNull, self)) + } + + /// Creates an in expression. For example, `a IN (1, 2, 3)`. + /// + /// # Example + /// + /// ```rust + /// + /// use iceberg::expr::Reference; + /// use iceberg::spec::Datum; + /// let expr = Reference::new("a").r#in(vec![Datum::long(1), Datum::long(2), Datum::long(3)]); + /// + /// assert_eq!(&format!("{expr}"), "a IN (1, 3, 2)"); + /// ``` + pub fn r#in(self, values: impl IntoIterator) -> Predicate { + Predicate::Set(SetExpression::new( + PredicateOperator::In, + self, + values.into_iter().collect(), + )) + } + /// Creates a greater-than-or-equal-to than expression. For example, `a >= 10`. /// /// # Example @@ -160,8 +199,28 @@ impl Display for Reference { } } +impl Bind for Reference { + type Bound = BoundReference; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> crate::Result { + let field = if case_sensitive { + schema.field_by_name(&self.name) + } else { + schema.field_by_name_case_insensitive(&self.name) + }; + + let field = field.ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Field {} not found in schema", self.name), + ) + })?; + Ok(BoundReference::new(self.name, field.clone())) + } +} + /// A named reference in a bound expression after binding to a schema. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct BoundReference { // This maybe different from [`name`] filed in [`NestedField`] since this contains full path. // For example, if the field is `a.b.c`, then `field.name` is `c`, but `original_name` is `a.b.c`. @@ -192,3 +251,66 @@ impl Display for BoundReference { /// Bound term after binding to a schema. pub type BoundTerm = BoundReference; + +#[cfg(test)] +mod tests { + use crate::expr::{Bind, BoundReference, Reference}; + use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; + use std::sync::Arc; + + fn table_schema_simple() -> SchemaRef { + Arc::new( + Schema::builder() + .with_schema_id(1) + .with_identifier_field_ids(vec![2]) + .with_fields(vec![ + NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(), + ]) + .build() + .unwrap(), + ) + } + + #[test] + fn test_bind_reference() { + let schema = table_schema_simple(); + let reference = Reference::new("bar").bind(schema, true).unwrap(); + + let expected_ref = BoundReference::new( + "bar", + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + ); + + assert_eq!(expected_ref, reference); + } + + #[test] + fn test_bind_reference_case_insensitive() { + let schema = table_schema_simple(); + let reference = Reference::new("BaR").bind(schema, false).unwrap(); + + let expected_ref = BoundReference::new( + "BaR", + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + ); + + assert_eq!(expected_ref, reference); + } + + #[test] + fn test_bind_reference_failure() { + let schema = table_schema_simple(); + let result = Reference::new("bar_not_eix").bind(schema, true); + + assert!(result.is_err()); + } + + #[test] + fn test_bind_reference_case_insensitive_failure() { + let schema = table_schema_simple(); + let result = Reference::new("BaR_non").bind(schema, false); + assert!(result.is_err()); + } +} diff --git a/crates/iceberg/src/spec/datatypes.rs b/crates/iceberg/src/spec/datatypes.rs index 8f404e928a..6ea4175e51 100644 --- a/crates/iceberg/src/spec/datatypes.rs +++ b/crates/iceberg/src/spec/datatypes.rs @@ -135,6 +135,15 @@ impl Type { ensure_data_valid!(precision > 0 && precision <= MAX_DECIMAL_PRECISION, "Decimals with precision larger than {MAX_DECIMAL_PRECISION} are not supported: {precision}",); Ok(Type::Primitive(PrimitiveType::Decimal { precision, scale })) } + + /// Check if it's float or double type. + #[inline(always)] + pub fn is_floating_type(&self) -> bool { + matches!( + self, + Type::Primitive(PrimitiveType::Float) | Type::Primitive(PrimitiveType::Double) + ) + } } impl From for Type { diff --git a/crates/iceberg/src/spec/schema.rs b/crates/iceberg/src/spec/schema.rs index 34e383f651..ecd7c702dc 100644 --- a/crates/iceberg/src/spec/schema.rs +++ b/crates/iceberg/src/spec/schema.rs @@ -51,6 +51,7 @@ pub struct Schema { id_to_field: HashMap, name_to_id: HashMap, + lowercase_name_to_id: HashMap, id_to_name: HashMap, } @@ -117,6 +118,11 @@ impl SchemaBuilder { index.indexes() }; + let lowercase_name_to_id = name_to_id + .iter() + .map(|(k, v)| (k.to_lowercase(), *v)) + .collect(); + Ok(Schema { r#struct, schema_id: self.schema_id, @@ -127,6 +133,7 @@ impl SchemaBuilder { id_to_field, name_to_id, + lowercase_name_to_id, id_to_name, }) } @@ -212,6 +219,15 @@ impl Schema { .and_then(|id| self.field_by_id(*id)) } + /// Get field by field name, but in case-insensitive way. + /// + /// Both full name and short name could work here. + pub fn field_by_name_case_insensitive(&self, field_name: &str) -> Option<&NestedFieldRef> { + self.lowercase_name_to_id + .get(&field_name.to_lowercase()) + .and_then(|id| self.field_by_id(*id)) + } + /// Get field by alias. pub fn field_by_alias(&self, alias: &str) -> Option<&NestedFieldRef> { self.alias_to_id @@ -1032,6 +1048,42 @@ table { assert_eq!(&expected_name_to_id, &schema.name_to_id); } + #[test] + fn test_schema_index_by_name_case_insensitive() { + let expected_name_to_id = HashMap::from( + [ + ("fOo", 1), + ("Bar", 2), + ("BAz", 3), + ("quX", 4), + ("quX.ELEment", 5), + ("qUUx", 6), + ("QUUX.KEY", 7), + ("QUUX.Value", 8), + ("qUUX.valUE.Key", 9), + ("qUux.vALUE.Value", 10), + ("lOCAtION", 11), + ("LOCAtioN.ELeMENt", 12), + ("LoCATion.element.LATitude", 13), + ("locatION.ElemeNT.LONgitude", 14), + ("LOCAtiON.LATITUDE", 13), + ("LOCATION.LONGITUDE", 14), + ("PERSon", 15), + ("PERSON.Name", 16), + ("peRSON.AGe", 17), + ] + .map(|e| (e.0.to_string(), e.1)), + ); + + let schema = table_schema_nested(); + for (name, id) in expected_name_to_id { + assert_eq!( + Some(id), + schema.field_by_name_case_insensitive(&name).map(|f| f.id) + ); + } + } + #[test] fn test_schema_find_column_name() { let expected_column_name = HashMap::from([ diff --git a/crates/iceberg/src/spec/values.rs b/crates/iceberg/src/spec/values.rs index 113620f1c3..c0d3f09bdb 100644 --- a/crates/iceberg/src/spec/values.rs +++ b/crates/iceberg/src/spec/values.rs @@ -658,6 +658,21 @@ impl Datum { unreachable!("Decimal type must be primitive.") } } + + /// Convert the datum to `target_type`. + pub fn to(self, target_type: &Type) -> Result { + // TODO: We should allow more type conversions + match target_type { + Type::Primitive(typ) if typ == &self.r#type => Ok(self), + _ => Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Can't convert datum from {} type to {} type.", + self.r#type, target_type + ), + )), + } + } } /// Values present in iceberg type From 09b6ea3435c19fa4f990cb1f7787ed8457e46ae1 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Thu, 7 Mar 2024 16:56:59 +0800 Subject: [PATCH 2/6] Add tests --- crates/iceberg/src/expr/predicate.rs | 271 +++++++++++++++++++++++++-- crates/iceberg/src/expr/term.rs | 80 +++----- crates/iceberg/src/spec/schema.rs | 4 +- crates/iceberg/src/spec/values.rs | 2 +- 4 files changed, 283 insertions(+), 74 deletions(-) diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 45be5d6c43..fb443a87eb 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -19,19 +19,17 @@ //! Predicate expressions are used to filter data, and evaluates to a boolean value. For example, //! `a > 10` is a predicate expression, and it evaluates to `true` if `a` is greater than `10`, -use crate::expr::{BoundReference, PredicateOperator, Reference}; -use crate::spec::Datum; +use std::fmt::{Debug, Display, Formatter}; +use std::mem::MaybeUninit; +use std::ops::Not; + +use fnv::FnvHashSet; use itertools::Itertools; -use std::collections::HashSet; + use crate::error::Result; use crate::expr::{Bind, BoundReference, PredicateOperator, Reference}; use crate::spec::{Datum, SchemaRef}; use crate::{Error, ErrorKind}; -use fnv::FnvHashSet; - -use std::fmt::{Debug, Display, Formatter}; -use std::mem::MaybeUninit; -use std::ops::Not; /// Logical expression, such as `AND`, `OR`, `NOT`. #[derive(PartialEq)] @@ -490,7 +488,11 @@ impl Predicate { impl Not for Predicate { type Output = Predicate; - /// Create a predicate which is the reverse of this predicate. For example: `NOT (a > 10)` + /// Create a predicate which is the reverse of this predicate. For example: `NOT (a > 10)`. + /// + /// This is different from [`Predicate::negate()`] since it doesn't rewrite expression, but + /// just adds a `NOT` operator. + /// /// # Example /// ///```rust @@ -530,12 +532,46 @@ pub enum BoundPredicate { Set(SetExpression), } +impl Display for BoundPredicate { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + BoundPredicate::AlwaysTrue => { + write!(f, "True") + } + BoundPredicate::AlwaysFalse => { + write!(f, "False") + } + BoundPredicate::And(expr) => { + write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1]) + } + BoundPredicate::Or(expr) => { + write!(f, "({}) OR ({})", expr.inputs()[0], expr.inputs()[1]) + } + BoundPredicate::Not(expr) => { + write!(f, "NOT ({})", expr.inputs()[0]) + } + BoundPredicate::Unary(expr) => { + write!(f, "{}", expr) + } + BoundPredicate::Binary(expr) => { + write!(f, "{}", expr) + } + BoundPredicate::Set(expr) => { + write!(f, "{}", expr) + } + } + } +} + #[cfg(test)] mod tests { + use std::ops::Not; + use std::sync::Arc; + + use crate::expr::Bind; use crate::expr::Reference; use crate::spec::Datum; - use std::collections::HashSet; - use std::ops::Not; + use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; #[test] fn test_predicate_negate_and() { @@ -604,20 +640,15 @@ mod tests { #[test] fn test_predicate_negate_set() { - let expression = Reference::new("a").is_in(HashSet::from([Datum::long(5), Datum::long(6)])); + let expression = Reference::new("a").is_in([Datum::long(5), Datum::long(6)]); - let expected = - Reference::new("a").is_not_in(HashSet::from([Datum::long(5), Datum::long(6)])); + let expected = Reference::new("a").is_not_in([Datum::long(5), Datum::long(6)]); let result = expression.negate(); assert_eq!(result, expected); } - use crate::expr::{Bind, Reference}; - use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; - use std::sync::Arc; - fn table_schema_simple() -> SchemaRef { Arc::new( Schema::builder() @@ -640,4 +671,208 @@ mod tests { let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "foo IS NULL"); } + + #[test] + fn test_bind_is_null_required() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + } + + #[test] + fn test_bind_is_not_null() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").is_not_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "foo IS NOT NULL"); + } + + #[test] + fn test_bind_is_not_null_required() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "True"); + } + + #[test] + fn test_bind_less_than() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").less_than(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar < 10"); + } + + #[test] + fn test_bind_less_than_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").less_than(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_greater_than_or_eq() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").greater_than_or_equal_to(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar >= 10"); + } + + #[test] + fn test_bind_greater_than_or_eq_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").greater_than_or_equal_to(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_in() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in([Datum::int(10), Datum::int(20)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar IN (20, 10)"); + } + + #[test] + fn test_bind_in_empty() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in(vec![]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + } + + #[test] + fn test_bind_in_one_literal() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in(vec![Datum::int(10)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar = 10"); + } + + #[test] + fn test_bind_in_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in(vec![Datum::int(10), Datum::string("abcd")]); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_not_in() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::int(20)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar NOT IN (20, 10)"); + } + + #[test] + fn test_bind_not_in_empty() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in(vec![]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "True"); + } + + #[test] + fn test_bind_not_in_one_literal() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in(vec![Datum::int(10)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar != 10"); + } + + #[test] + fn test_bind_not_in_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::string("abcd")]); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_and() { + let schema = table_schema_simple(); + let expr = Reference::new("bar") + .less_than(Datum::int(10)) + .and(Reference::new("foo").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "(bar < 10) AND (foo IS NULL)"); + } + + #[test] + fn test_bind_and_always_false() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .and(Reference::new("bar").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + } + + #[test] + fn test_bind_and_always_true() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .and(Reference::new("bar").is_not_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#); + } + + #[test] + fn test_bind_or() { + let schema = table_schema_simple(); + let expr = Reference::new("bar") + .less_than(Datum::int(10)) + .or(Reference::new("foo").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "(bar < 10) OR (foo IS NULL)"); + } + + #[test] + fn test_bind_or_always_true() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .or(Reference::new("bar").is_not_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "True"); + } + + #[test] + fn test_bind_or_always_false() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .or(Reference::new("bar").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#); + } + + #[test] + fn test_bind_not() { + let schema = table_schema_simple(); + let expr = !Reference::new("bar").less_than(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "NOT (bar < 10)"); + } + + #[test] + fn test_bind_not_always_true() { + let schema = table_schema_simple(); + let expr = !Reference::new("bar").is_not_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + } + + #[test] + fn test_bind_not_always_false() { + let schema = table_schema_simple(); + let expr = !Reference::new("bar").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"True"#); + } } diff --git a/crates/iceberg/src/expr/term.rs b/crates/iceberg/src/expr/term.rs index 522c347156..e39c97e0f6 100644 --- a/crates/iceberg/src/expr/term.rs +++ b/crates/iceberg/src/expr/term.rs @@ -17,14 +17,14 @@ //! Term definition. -use crate::expr::{BinaryExpression, Predicate, PredicateOperator, SetExpression}; -use crate::expr::{Bind, UnaryExpression}; +use std::fmt::{Display, Formatter}; + +use fnv::FnvHashSet; + +use crate::expr::Bind; +use crate::expr::{BinaryExpression, Predicate, PredicateOperator, SetExpression, UnaryExpression}; use crate::spec::{Datum, NestedField, NestedFieldRef, SchemaRef}; use crate::{Error, ErrorKind}; -use crate::expr::{BinaryExpression, Predicate, PredicateOperator, SetExpression, UnaryExpression}; -use crate::spec::{Datum, NestedField, NestedFieldRef}; -use std::collections::HashSet; -use std::fmt::{Display, Formatter}; /// Unbound term before binding to a schema. pub type Term = Reference; @@ -69,41 +69,6 @@ impl Reference { )) } - /// Creates an is null expression. For example, `a IS NULL`. - /// - /// # Example - /// - /// ```rust - /// - /// use iceberg::expr::Reference; - /// let expr = Reference::new("a").is_null(); - /// - /// assert_eq!(&format!("{expr}"), "a IS NULL"); - /// ``` - pub fn is_null(self) -> Predicate { - Predicate::Unary(UnaryExpression::new(PredicateOperator::IsNull, self)) - } - - /// Creates an in expression. For example, `a IN (1, 2, 3)`. - /// - /// # Example - /// - /// ```rust - /// - /// use iceberg::expr::Reference; - /// use iceberg::spec::Datum; - /// let expr = Reference::new("a").r#in(vec![Datum::long(1), Datum::long(2), Datum::long(3)]); - /// - /// assert_eq!(&format!("{expr}"), "a IN (1, 3, 2)"); - /// ``` - pub fn r#in(self, values: impl IntoIterator) -> Predicate { - Predicate::Set(SetExpression::new( - PredicateOperator::In, - self, - values.into_iter().collect(), - )) - } - /// Creates a greater-than-or-equal-to than expression. For example, `a >= 10`. /// /// # Example @@ -162,16 +127,20 @@ impl Reference { /// /// ```rust /// - /// use std::collections::HashSet; + /// use fnv::FnvHashSet; /// use iceberg::expr::Reference; /// use iceberg::spec::Datum; - /// let expr = Reference::new("a").is_in(HashSet::from([Datum::long(5), Datum::long(6)])); + /// let expr = Reference::new("a").is_in([Datum::long(5), Datum::long(6)]); /// /// let as_string = format!("{expr}"); /// assert!(&as_string == "a IN (5, 6)" || &as_string == "a IN (6, 5)"); /// ``` - pub fn is_in(self, literals: HashSet) -> Predicate { - Predicate::Set(SetExpression::new(PredicateOperator::In, self, literals)) + pub fn is_in(self, literals: impl IntoIterator) -> Predicate { + Predicate::Set(SetExpression::new( + PredicateOperator::In, + self, + FnvHashSet::from_iter(literals), + )) } /// Creates an is-not-in expression. For example, `a IS NOT IN (5, 6)`. @@ -180,16 +149,20 @@ impl Reference { /// /// ```rust /// - /// use std::collections::HashSet; + /// use fnv::FnvHashSet; /// use iceberg::expr::Reference; /// use iceberg::spec::Datum; - /// let expr = Reference::new("a").is_not_in(HashSet::from([Datum::long(5), Datum::long(6)])); + /// let expr = Reference::new("a").is_not_in([Datum::long(5), Datum::long(6)]); /// /// let as_string = format!("{expr}"); /// assert!(&as_string == "a NOT IN (5, 6)" || &as_string == "a NOT IN (6, 5)"); /// ``` - pub fn is_not_in(self, literals: HashSet) -> Predicate { - Predicate::Set(SetExpression::new(PredicateOperator::NotIn, self, literals)) + pub fn is_not_in(self, literals: impl IntoIterator) -> Predicate { + Predicate::Set(SetExpression::new( + PredicateOperator::NotIn, + self, + FnvHashSet::from_iter(literals), + )) } } @@ -254,9 +227,10 @@ pub type BoundTerm = BoundReference; #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::expr::{Bind, BoundReference, Reference}; use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; - use std::sync::Arc; fn table_schema_simple() -> SchemaRef { Arc::new( @@ -289,10 +263,10 @@ mod tests { #[test] fn test_bind_reference_case_insensitive() { let schema = table_schema_simple(); - let reference = Reference::new("BaR").bind(schema, false).unwrap(); + let reference = Reference::new("BAR").bind(schema, false).unwrap(); let expected_ref = BoundReference::new( - "BaR", + "BAR", NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), ); @@ -310,7 +284,7 @@ mod tests { #[test] fn test_bind_reference_case_insensitive_failure() { let schema = table_schema_simple(); - let result = Reference::new("BaR_non").bind(schema, false); + let result = Reference::new("bar_non_exist").bind(schema, false); assert!(result.is_err()); } } diff --git a/crates/iceberg/src/spec/schema.rs b/crates/iceberg/src/spec/schema.rs index ecd7c702dc..975a2a9ef7 100644 --- a/crates/iceberg/src/spec/schema.rs +++ b/crates/iceberg/src/spec/schema.rs @@ -1060,8 +1060,8 @@ table { ("qUUx", 6), ("QUUX.KEY", 7), ("QUUX.Value", 8), - ("qUUX.valUE.Key", 9), - ("qUux.vALUE.Value", 10), + ("qUUX.VALUE.Key", 9), + ("qUux.VaLue.Value", 10), ("lOCAtION", 11), ("LOCAtioN.ELeMENt", 12), ("LoCATion.element.LATitude", 13), diff --git a/crates/iceberg/src/spec/values.rs b/crates/iceberg/src/spec/values.rs index c0d3f09bdb..0595773eb3 100644 --- a/crates/iceberg/src/spec/values.rs +++ b/crates/iceberg/src/spec/values.rs @@ -106,7 +106,7 @@ impl Display for Datum { (_, PrimitiveLiteral::TimestampTZ(val)) => { write!(f, "{}", microseconds_to_datetimetz(*val)) } - (_, PrimitiveLiteral::String(val)) => write!(f, "{}", val), + (_, PrimitiveLiteral::String(val)) => write!(f, r#""{}""#, val), (_, PrimitiveLiteral::UUID(val)) => write!(f, "{}", val), (_, PrimitiveLiteral::Fixed(val)) => display_bytes(val, f), (_, PrimitiveLiteral::Binary(val)) => display_bytes(val, f), From 8fd070234feb44a68ae9c4f3c387123e558420f0 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Thu, 7 Mar 2024 17:37:40 +0800 Subject: [PATCH 3/6] Fix doc test --- crates/iceberg/src/spec/values.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/iceberg/src/spec/values.rs b/crates/iceberg/src/spec/values.rs index 0595773eb3..00f2e57d2b 100644 --- a/crates/iceberg/src/spec/values.rs +++ b/crates/iceberg/src/spec/values.rs @@ -529,7 +529,7 @@ impl Datum { /// use iceberg::spec::Datum; /// let t = Datum::string("ss"); /// - /// assert_eq!(&format!("{t}"), "ss"); + /// assert_eq!(&format!("{t}"), r#""ss""#); /// ``` pub fn string(s: S) -> Self { Self { From 0422f07ce5bfe4f458e652e30f175ae9fa7a7f1f Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Sat, 9 Mar 2024 19:56:32 +0800 Subject: [PATCH 4/6] Fix comments --- Cargo.toml | 1 + crates/iceberg/Cargo.toml | 1 + crates/iceberg/src/expr/mod.rs | 2 +- crates/iceberg/src/expr/predicate.rs | 34 ++++++++++------------------ 4 files changed, 15 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 697317c1d9..c6a9e15e45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ rust-version = "1.75.0" [workspace.dependencies] anyhow = "1.0.72" apache-avro = "0.16" +array-init = "2" arrow-arith = { version = ">=46" } arrow-array = { version = ">=46" } arrow-schema = { version = ">=46" } diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml index ee6b2ff4f5..eb60412bb1 100644 --- a/crates/iceberg/Cargo.toml +++ b/crates/iceberg/Cargo.toml @@ -31,6 +31,7 @@ keywords = ["iceberg"] [dependencies] anyhow = { workspace = true } apache-avro = { workspace = true } +array-init = { workspace = true } arrow-arith = { workspace = true } arrow-array = { workspace = true } arrow-schema = { workspace = true } diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs index 9b5282ba79..567cf7e913 100644 --- a/crates/iceberg/src/expr/mod.rs +++ b/crates/iceberg/src/expr/mod.rs @@ -151,7 +151,7 @@ impl PredicateOperator { /// Bind expression to a schema. pub trait Bind { - /// The type of the bounded result. + /// The type of the bound result. type Bound; /// Bind an expression to a schema. fn bind(self, schema: SchemaRef, case_sensitive: bool) -> crate::Result; diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index fb443a87eb..dd3e08017d 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -20,9 +20,9 @@ //! `a > 10` is a predicate expression, and it evaluates to `true` if `a` is greater than `10`, use std::fmt::{Debug, Display, Formatter}; -use std::mem::MaybeUninit; use std::ops::Not; +use array_init::array_init; use fnv::FnvHashSet; use itertools::Itertools; @@ -60,25 +60,20 @@ impl LogicalExpression { } } -impl Bind for LogicalExpression { +impl Bind for LogicalExpression +where + T::Bound: Sized, +{ type Bound = LogicalExpression; fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { - let mut bound_inputs = MaybeUninit::<[Box; N]>::uninit(); + let mut outputs: [Option>; N] = array_init(|_| None); for (i, input) in self.inputs.into_iter().enumerate() { - let input = input.bind(schema.clone(), case_sensitive)?; - // SAFETY: The pointer is valid from [`MaybeUninit`]. - unsafe { - bound_inputs - .as_mut_ptr() - .cast::>() - .add(i) - .write(Box::new(input)); - } + outputs[i] = Some(Box::new(input.bind(schema.clone(), case_sensitive)?)); } - // SAFETY: We have initialized all elements of the array. - let bound_inputs = unsafe { bound_inputs.assume_init() }; + // It's safe to use `unwrap` here since they are all `Some`. + let bound_inputs = array_init::from_iter(outputs.into_iter().map(Option::unwrap)).unwrap(); Ok(LogicalExpression::new(bound_inputs)) } } @@ -250,13 +245,8 @@ impl Bind for Predicate { }) } Predicate::Not(expr) => { - let bound_expr = expr.bind(schema, case_sensitive)?; - let [inner] = bound_expr.inputs; - Ok(match inner { - e if matches!(&*e, &BoundPredicate::AlwaysTrue) => BoundPredicate::AlwaysFalse, - e if matches!(&*e, &BoundPredicate::AlwaysFalse) => BoundPredicate::AlwaysTrue, - e => BoundPredicate::Not(LogicalExpression::new([e])), - }) + let [inner] = expr.inputs; + inner.negate().bind(schema, case_sensitive) } Predicate::Or(expr) => { let bound_expr = expr.bind(schema, case_sensitive)?; @@ -857,7 +847,7 @@ mod tests { let schema = table_schema_simple(); let expr = !Reference::new("bar").less_than(Datum::int(10)); let bound_expr = expr.bind(schema, true).unwrap(); - assert_eq!(&format!("{bound_expr}"), "NOT (bar < 10)"); + assert_eq!(&format!("{bound_expr}"), "bar >= 10"); } #[test] From 665f47a65ee7785794a86d322e647a8526143302 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Mon, 11 Mar 2024 11:24:54 +0800 Subject: [PATCH 5/6] Revert not rewrite in binding --- crates/iceberg/src/expr/predicate.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index dd3e08017d..7c8da24019 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -245,8 +245,13 @@ impl Bind for Predicate { }) } Predicate::Not(expr) => { - let [inner] = expr.inputs; - inner.negate().bind(schema, case_sensitive) + let bound_expr = expr.bind(schema, case_sensitive)?; + let [inner] = bound_expr.inputs; + Ok(match inner { + e if matches!(&*e, &BoundPredicate::AlwaysTrue) => BoundPredicate::AlwaysFalse, + e if matches!(&*e, &BoundPredicate::AlwaysFalse) => BoundPredicate::AlwaysTrue, + e => BoundPredicate::Not(LogicalExpression::new([e])), + }) } Predicate::Or(expr) => { let bound_expr = expr.bind(schema, case_sensitive)?; @@ -847,7 +852,7 @@ mod tests { let schema = table_schema_simple(); let expr = !Reference::new("bar").less_than(Datum::int(10)); let bound_expr = expr.bind(schema, true).unwrap(); - assert_eq!(&format!("{bound_expr}"), "bar >= 10"); + assert_eq!(&format!("{bound_expr}"), "NOT (bar < 10)"); } #[test] From 775f5e8dc8874364e7c7b55f905488c422d84c68 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Tue, 12 Mar 2024 11:29:12 +0800 Subject: [PATCH 6/6] Fix comments --- crates/iceberg/src/expr/predicate.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 7c8da24019..4ab9aae30f 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -257,10 +257,10 @@ impl Bind for Predicate { let bound_expr = expr.bind(schema, case_sensitive)?; let [left, right] = bound_expr.inputs; Ok(match (left, right) { - (_, r) if matches!(&*r, &BoundPredicate::AlwaysTrue) => { - BoundPredicate::AlwaysTrue - } - (l, _) if matches!(&*l, &BoundPredicate::AlwaysTrue) => { + (l, r) + if matches!(&*r, &BoundPredicate::AlwaysTrue) + || matches!(&*l, &BoundPredicate::AlwaysTrue) => + { BoundPredicate::AlwaysTrue } (left, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => *left, @@ -296,7 +296,7 @@ impl Bind for Predicate { op => { return Err(Error::new( ErrorKind::Unexpected, - format!("Expecting unary operator,but found {op}"), + format!("Expecting unary operator, but found {op}"), )) } }