diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs index dccafb79af..3d77c4df88 100644 --- a/crates/iceberg/src/expr/mod.rs +++ b/crates/iceberg/src/expr/mod.rs @@ -18,15 +18,14 @@ //! This module contains expressions. mod term; - -use std::fmt::{Display, Formatter}; - pub use term::*; pub(crate) mod accessor; mod predicate; +pub(crate) mod visitors; +pub use predicate::*; use crate::spec::SchemaRef; -pub use predicate::*; +use std::fmt::{Display, Formatter}; /// Predicate operators used in expressions. /// @@ -34,6 +33,7 @@ pub use predicate::*; /// [`PredicateOperator::is_unary`], [`PredicateOperator::is_binary`], [`PredicateOperator::is_set`] #[allow(missing_docs)] #[derive(Debug, Clone, Copy, PartialEq)] +#[non_exhaustive] #[repr(u16)] pub enum PredicateOperator { // Unary operators diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 530923f15c..1163f3bf06 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -119,6 +119,10 @@ impl UnaryExpression { pub(crate) fn op(&self) -> PredicateOperator { self.op } + + pub(crate) fn term(&self) -> &T { + &self.term + } } /// Binary predicate, for example, `a > 10`. @@ -147,12 +151,17 @@ impl BinaryExpression { debug_assert!(op.is_binary()); Self { op, term, literal } } + pub(crate) fn op(&self) -> PredicateOperator { self.op } pub(crate) fn literal(&self) -> &Datum { &self.literal } + + pub(crate) fn term(&self) -> &T { + &self.term + } } impl Display for BinaryExpression { @@ -200,12 +209,17 @@ impl SetExpression { debug_assert!(op.is_set()); Self { op, term, literals } } + pub(crate) fn op(&self) -> PredicateOperator { self.op } pub(crate) fn literals(&self) -> &FnvHashSet { &self.literals } + + pub(crate) fn term(&self) -> &T { + &self.term + } } impl Bind for SetExpression { @@ -232,6 +246,10 @@ impl Display for SetExpression { /// Unbound predicate expression before binding to a schema. 
#[derive(Debug, PartialEq)] pub enum Predicate { + /// AlwaysTrue predicate, for example, `TRUE`. + AlwaysTrue, + /// AlwaysFalse predicate, for example, `FALSE`. + AlwaysFalse, /// And predicate, for example, `a > 10 AND b < 20`. And(LogicalExpression), /// Or predicate, for example, `a > 10 OR b < 20`. @@ -382,6 +400,8 @@ impl Bind for Predicate { bound_literals, ))) } + Predicate::AlwaysTrue => Ok(BoundPredicate::AlwaysTrue), + Predicate::AlwaysFalse => Ok(BoundPredicate::AlwaysFalse), } } } @@ -389,6 +409,12 @@ impl Bind for Predicate { impl Display for Predicate { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { + Predicate::AlwaysTrue => { + write!(f, "TRUE") + } + Predicate::AlwaysFalse => { + write!(f, "FALSE") + } Predicate::And(expr) => { write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1]) } @@ -476,6 +502,8 @@ impl Predicate { /// ``` pub fn negate(self) -> Predicate { match self { + Predicate::AlwaysTrue => Predicate::AlwaysFalse, + Predicate::AlwaysFalse => Predicate::AlwaysTrue, Predicate::And(expr) => Predicate::Or(LogicalExpression::new( expr.inputs.map(|expr| Box::new(expr.negate())), )), @@ -540,6 +568,8 @@ impl Predicate { Predicate::Unary(expr) => Predicate::Unary(expr), Predicate::Binary(expr) => Predicate::Binary(expr), Predicate::Set(expr) => Predicate::Set(expr), + Predicate::AlwaysTrue => Predicate::AlwaysTrue, + Predicate::AlwaysFalse => Predicate::AlwaysFalse, } } } diff --git a/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs b/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs new file mode 100644 index 0000000000..f29afbf619 --- /dev/null +++ b/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs @@ -0,0 +1,644 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::expr::{BoundPredicate, BoundReference, PredicateOperator};
+use crate::spec::Datum;
+use crate::Result;
+use fnv::FnvHashSet;
+
+/// A visitor for [`BoundPredicate`]s; [`visit`] drives it in post-order (children first).
+pub trait BoundPredicateVisitor {
+    /// The value produced for each visited predicate node
+    type T;
+
+    /// Called when an `AlwaysTrue` predicate is visited
+    fn always_true(&mut self) -> Result<Self::T>;
+
+    /// Called when an `AlwaysFalse` predicate is visited
+    fn always_false(&mut self) -> Result<Self::T>;
+
+    /// Called for an `And` predicate with the results of its already-visited children
+    fn and(&mut self, lhs: Self::T, rhs: Self::T) -> Result<Self::T>;
+
+    /// Called for an `Or` predicate with the results of its already-visited children
+    fn or(&mut self, lhs: Self::T, rhs: Self::T) -> Result<Self::T>;
+
+    /// Called for a `Not` predicate with the result of its already-visited child
+    fn not(&mut self, inner: Self::T) -> Result<Self::T>;
+
+    /// Called for a unary predicate with an `IsNull` operator
+    fn is_null(&mut self, reference: &BoundReference) -> Result<Self::T>;
+
+    /// Called for a unary predicate with a `NotNull` operator
+    fn not_null(&mut self, reference: &BoundReference) -> Result<Self::T>;
+
+    /// Called for a unary predicate with an `IsNan` operator
+    fn is_nan(&mut self, reference: &BoundReference) -> Result<Self::T>;
+
+    /// Called for a unary predicate with a `NotNan` operator
+    fn not_nan(&mut self, reference: &BoundReference) -> Result<Self::T>;
+
+    /// Called for a binary predicate with a `LessThan` operator
+    fn less_than(&mut self, reference: &BoundReference, literal: &Datum) -> Result<Self::T>;
+
+    /// Called for a binary predicate with a `LessThanOrEq` operator
+    fn less_than_or_eq(&mut self, reference: &BoundReference, literal: &Datum) -> Result<Self::T>;
+
+    /// Called for a binary predicate with a `GreaterThan` operator
+    fn greater_than(&mut self, reference: &BoundReference, literal: &Datum) -> Result<Self::T>;
+
+    /// Called for a binary predicate with a `GreaterThanOrEq` operator
+    fn greater_than_or_eq(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+    ) -> Result<Self::T>;
+
+    /// Called for a binary predicate with an `Eq` operator
+    fn eq(&mut self, reference: &BoundReference, literal: &Datum) -> Result<Self::T>;
+
+    /// Called for a binary predicate with a `NotEq` operator
+    fn not_eq(&mut self, reference: &BoundReference, literal: &Datum) -> Result<Self::T>;
+
+    /// Called for a binary predicate with a `StartsWith` operator
+    fn starts_with(&mut self, reference: &BoundReference, literal: &Datum) -> Result<Self::T>;
+
+    /// Called for a binary predicate with a `NotStartsWith` operator
+    fn not_starts_with(&mut self, reference: &BoundReference, literal: &Datum) -> Result<Self::T>;
+
+    /// Called for a set predicate with an `In` operator
+    fn r#in(&mut self, reference: &BoundReference, literals: &FnvHashSet<Datum>)
+        -> Result<Self::T>;
+
+    /// Called for a set predicate with a `NotIn` operator
+    fn not_in(
+        &mut self,
+        reference: &BoundReference,
+        literals: &FnvHashSet<Datum>,
+    ) -> Result<Self::T>;
+}
+
+/// Walks `predicate` in post-order and returns what `visitor` yields for the root node;
+/// panics if an expression's operator does not belong to its unary/binary/set family.
+pub(crate) fn visit<V: BoundPredicateVisitor>(
+    visitor: &mut V,
+    predicate: &BoundPredicate,
+) -> Result<V::T> {
+    match predicate {
+        BoundPredicate::AlwaysTrue => visitor.always_true(),
+        BoundPredicate::AlwaysFalse => visitor.always_false(),
+        BoundPredicate::And(expr) => {
+            let [left_pred, right_pred] = expr.inputs();
+
+            let left_result = visit(visitor, left_pred)?;
+            let right_result = visit(visitor, right_pred)?;
+
+            visitor.and(left_result, right_result)
+        }
+        BoundPredicate::Or(expr) => {
+            let [left_pred, right_pred] = expr.inputs();
+
+            let left_result = visit(visitor, left_pred)?;
+            let right_result = visit(visitor, right_pred)?;
+
+            visitor.or(left_result, right_result)
+        }
+        BoundPredicate::Not(expr) => {
+            let [inner_pred] = expr.inputs();
+
+            let inner_result = visit(visitor, inner_pred)?;
+
+            visitor.not(inner_result)
+        }
+        BoundPredicate::Unary(expr) => match expr.op() {
+            PredicateOperator::IsNull => visitor.is_null(expr.term()),
+            PredicateOperator::NotNull => visitor.not_null(expr.term()),
+            PredicateOperator::IsNan => visitor.is_nan(expr.term()),
+            PredicateOperator::NotNan => visitor.not_nan(expr.term()),
+            op => {
+                panic!("Unexpected op for unary predicate: {}", op)
+            }
+        },
+        BoundPredicate::Binary(expr) => {
+            let reference = expr.term();
+            let literal = expr.literal();
+            match expr.op() {
+                PredicateOperator::LessThan => visitor.less_than(reference, literal),
+                PredicateOperator::LessThanOrEq => visitor.less_than_or_eq(reference, literal),
+                PredicateOperator::GreaterThan => visitor.greater_than(reference, literal),
+                PredicateOperator::GreaterThanOrEq => {
+                    visitor.greater_than_or_eq(reference, literal)
+                }
+                PredicateOperator::Eq => visitor.eq(reference, literal),
+                PredicateOperator::NotEq => visitor.not_eq(reference, literal),
+                PredicateOperator::StartsWith => visitor.starts_with(reference, literal),
+                PredicateOperator::NotStartsWith => visitor.not_starts_with(reference, literal),
+                op => {
+                    panic!("Unexpected op for binary predicate: {}", op)
+                }
+            }
+        }
+        BoundPredicate::Set(expr) => {
+            let reference = expr.term();
+            let literals = expr.literals();
+            match expr.op() {
+                PredicateOperator::In => visitor.r#in(reference, literals),
+                PredicateOperator::NotIn => visitor.not_in(reference, literals),
+                op => {
+                    panic!("Unexpected op for set predicate: {}", op)
+                }
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use
crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor};
+    use crate::expr::{
+        BinaryExpression, Bind, BoundReference, Predicate, PredicateOperator, Reference,
+        SetExpression, UnaryExpression,
+    };
+    use crate::spec::{Datum, NestedField, PrimitiveType, Schema, SchemaRef, Type};
+    use fnv::FnvHashSet;
+    use std::ops::Not;
+    use std::sync::Arc;
+
+    /// Visitor stub: each leaf returns a fixed bool; logical nodes combine their child results.
+    struct TestEvaluator {}
+    impl BoundPredicateVisitor for TestEvaluator {
+        type T = bool;
+
+        fn always_true(&mut self) -> crate::Result<bool> {
+            Ok(true)
+        }
+
+        fn always_false(&mut self) -> crate::Result<bool> {
+            Ok(false)
+        }
+
+        fn and(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result<bool> {
+            Ok(lhs && rhs)
+        }
+
+        fn or(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result<bool> {
+            Ok(lhs || rhs)
+        }
+
+        fn not(&mut self, inner: Self::T) -> crate::Result<bool> {
+            Ok(!inner)
+        }
+
+        fn is_null(&mut self, _reference: &BoundReference) -> crate::Result<bool> {
+            Ok(true)
+        }
+
+        fn not_null(&mut self, _reference: &BoundReference) -> crate::Result<bool> {
+            Ok(false)
+        }
+
+        fn is_nan(&mut self, _reference: &BoundReference) -> crate::Result<bool> {
+            Ok(true)
+        }
+
+        fn not_nan(&mut self, _reference: &BoundReference) -> crate::Result<bool> {
+            Ok(false)
+        }
+
+        fn less_than(
+            &mut self,
+            _reference: &BoundReference,
+            _literal: &Datum,
+        ) -> crate::Result<bool> {
+            Ok(true)
+        }
+
+        fn less_than_or_eq(
+            &mut self,
+            _reference: &BoundReference,
+            _literal: &Datum,
+        ) -> crate::Result<bool> {
+            Ok(false)
+        }
+
+        fn greater_than(
+            &mut self,
+            _reference: &BoundReference,
+            _literal: &Datum,
+        ) -> crate::Result<bool> {
+            Ok(true)
+        }
+
+        fn greater_than_or_eq(
+            &mut self,
+            _reference: &BoundReference,
+            _literal: &Datum,
+        ) -> crate::Result<bool> {
+            Ok(false)
+        }
+
+        fn eq(&mut self, _reference: &BoundReference, _literal: &Datum) -> crate::Result<bool> {
+            Ok(true)
+        }
+
+        fn not_eq(&mut self, _reference: &BoundReference, _literal: &Datum) -> crate::Result<bool> {
+            Ok(false)
+        }
+
+        fn starts_with(
+            &mut self,
+            _reference: &BoundReference,
+            _literal: &Datum,
+        ) -> crate::Result<bool> {
+            Ok(true)
+        }
+
+        fn not_starts_with(
+            &mut self,
+            _reference: &BoundReference,
+            _literal: &Datum,
+        ) -> crate::Result<bool> {
+            Ok(false)
+        }
+
+        fn r#in(
+            &mut self,
+            _reference: &BoundReference,
+            _literals: &FnvHashSet<Datum>,
+        ) -> crate::Result<bool> {
+            Ok(true)
+        }
+
+        fn not_in(
+            &mut self,
+            _reference: &BoundReference,
+            _literals: &FnvHashSet<Datum>,
+        ) -> crate::Result<bool> {
+            Ok(false)
+        }
+    }
+
+    /// Builds the shared schema: required int `a`, required float `b`, optional float `c`.
+    fn create_test_schema() -> SchemaRef {
+        let schema = Schema::builder()
+            .with_fields(vec![
+                Arc::new(NestedField::required(
+                    1,
+                    "a",
+                    Type::Primitive(PrimitiveType::Int),
+                )),
+                Arc::new(NestedField::required(
+                    2,
+                    "b",
+                    Type::Primitive(PrimitiveType::Float),
+                )),
+                Arc::new(NestedField::optional(
+                    3,
+                    "c",
+                    Type::Primitive(PrimitiveType::Float),
+                )),
+            ])
+            .build()
+            .unwrap();
+        Arc::new(schema)
+    }
+
+    #[test]
+    fn test_always_true() {
+        let predicate = Predicate::AlwaysTrue;
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_always_false() {
+        let predicate = Predicate::AlwaysFalse;
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+    }
+
+    #[test]
+    fn test_logical_and() {
+        let predicate = Predicate::AlwaysTrue.and(Predicate::AlwaysFalse);
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+
+        let predicate = Predicate::AlwaysFalse.and(Predicate::AlwaysFalse);
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+
+        let predicate = Predicate::AlwaysTrue.and(Predicate::AlwaysTrue);
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_logical_or() {
+        let predicate = Predicate::AlwaysTrue.or(Predicate::AlwaysFalse);
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+
+        let predicate = Predicate::AlwaysFalse.or(Predicate::AlwaysFalse);
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+
+        let predicate = Predicate::AlwaysTrue.or(Predicate::AlwaysTrue);
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_not() {
+        let predicate = Predicate::AlwaysFalse.not();
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+
+        let predicate = Predicate::AlwaysTrue.not();
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+    }
+
+    #[test]
+    fn test_is_null() {
+        let predicate = Predicate::Unary(UnaryExpression::new(
+            PredicateOperator::IsNull,
+            Reference::new("c"),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_not_null() {
+        let predicate = Predicate::Unary(UnaryExpression::new(
+            PredicateOperator::NotNull,
+            Reference::new("a"),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+        // Presumably `NotNull` on required `a` binds to `AlwaysTrue`; `not_null` would give false.
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_is_nan() {
+        let predicate = Predicate::Unary(UnaryExpression::new(
+            PredicateOperator::IsNan,
+            Reference::new("b"),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_not_nan() {
+        let predicate = Predicate::Unary(UnaryExpression::new(
+            PredicateOperator::NotNan,
+            Reference::new("b"),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+    }
+
+    #[test]
+    fn test_less_than() {
+        let predicate = Predicate::Binary(BinaryExpression::new(
+            PredicateOperator::LessThan,
+            Reference::new("a"),
+            Datum::int(10),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_less_than_or_eq() {
+        let predicate = Predicate::Binary(BinaryExpression::new(
+            PredicateOperator::LessThanOrEq,
+            Reference::new("a"),
+            Datum::int(10),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+    }
+
+    #[test]
+    fn test_greater_than() {
+        let predicate = Predicate::Binary(BinaryExpression::new(
+            PredicateOperator::GreaterThan,
+            Reference::new("a"),
+            Datum::int(10),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_greater_than_or_eq() {
+        let predicate = Predicate::Binary(BinaryExpression::new(
+            PredicateOperator::GreaterThanOrEq,
+            Reference::new("a"),
+            Datum::int(10),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+    }
+
+    #[test]
+    fn test_eq() {
+        let predicate = Predicate::Binary(BinaryExpression::new(
+            PredicateOperator::Eq,
+            Reference::new("a"),
+            Datum::int(10),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_not_eq() {
+        let predicate = Predicate::Binary(BinaryExpression::new(
+            PredicateOperator::NotEq,
+            Reference::new("a"),
+            Datum::int(10),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+    }
+
+    #[test]
+    fn test_starts_with() {
+        let predicate = Predicate::Binary(BinaryExpression::new(
+            PredicateOperator::StartsWith,
+            Reference::new("a"),
+            Datum::int(10),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_not_starts_with() {
+        let predicate = Predicate::Binary(BinaryExpression::new(
+            PredicateOperator::NotStartsWith,
+            Reference::new("a"),
+            Datum::int(10),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+    }
+
+    #[test]
+    fn test_in() {
+        let predicate = Predicate::Set(SetExpression::new(
+            PredicateOperator::In,
+            Reference::new("a"),
+            FnvHashSet::from_iter(vec![Datum::int(1)]),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(result.unwrap());
+    }
+
+    #[test]
+    fn test_not_in() {
+        let predicate = Predicate::Set(SetExpression::new(
+            PredicateOperator::NotIn,
+            Reference::new("a"),
+            FnvHashSet::from_iter(vec![Datum::int(1)]),
+        ));
+        let bound_predicate = predicate.bind(create_test_schema(), false).unwrap();
+
+        let mut test_evaluator = TestEvaluator {};
+
+        let result = visit(&mut test_evaluator, &bound_predicate);
+
+        assert!(!result.unwrap());
+    }
+}
diff --git a/crates/iceberg/src/expr/visitors/mod.rs b/crates/iceberg/src/expr/visitors/mod.rs
new file mode 100644
index 0000000000..242a55c4b0
--- /dev/null
+++ b/crates/iceberg/src/expr/visitors/mod.rs
@@ -0,0 +1,18 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(crate) mod bound_predicate_visitor;