From c86c7687c7ea62e45e495225d3ed995c20aa21b2 Mon Sep 17 00:00:00 2001 From: Scott Donnelly Date: Thu, 4 Apr 2024 20:49:02 +0100 Subject: [PATCH] feat: add BoundPredicateVisitor. Add AlwaysTrue and AlwaysFalse to Predicate --- crates/iceberg/src/expr/mod.rs | 8 +- crates/iceberg/src/expr/predicate.rs | 30 ++ .../expr/visitors/bound_predicate_visitor.rs | 317 ++++++++++++++++++ crates/iceberg/src/expr/visitors/mod.rs | 18 + 4 files changed, 369 insertions(+), 4 deletions(-) create mode 100644 crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs create mode 100644 crates/iceberg/src/expr/visitors/mod.rs diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs index dccafb79af..3d77c4df88 100644 --- a/crates/iceberg/src/expr/mod.rs +++ b/crates/iceberg/src/expr/mod.rs @@ -18,15 +18,14 @@ //! This module contains expressions. mod term; - -use std::fmt::{Display, Formatter}; - pub use term::*; pub(crate) mod accessor; mod predicate; +pub(crate) mod visitors; +pub use predicate::*; use crate::spec::SchemaRef; -pub use predicate::*; +use std::fmt::{Display, Formatter}; /// Predicate operators used in expressions. /// @@ -34,6 +33,7 @@ pub use predicate::*; /// [`PredicateOperator::is_unary`], [`PredicateOperator::is_binary`], [`PredicateOperator::is_set`] #[allow(missing_docs)] #[derive(Debug, Clone, Copy, PartialEq)] +#[non_exhaustive] #[repr(u16)] pub enum PredicateOperator { // Unary operators diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 530923f15c..1163f3bf06 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -119,6 +119,10 @@ impl UnaryExpression { pub(crate) fn op(&self) -> PredicateOperator { self.op } + + pub(crate) fn term(&self) -> &T { + &self.term + } } /// Binary predicate, for example, `a > 10`. 
@@ -147,12 +151,17 @@ impl BinaryExpression { debug_assert!(op.is_binary()); Self { op, term, literal } } + pub(crate) fn op(&self) -> PredicateOperator { self.op } pub(crate) fn literal(&self) -> &Datum { &self.literal } + + pub(crate) fn term(&self) -> &T { + &self.term + } } impl Display for BinaryExpression { @@ -200,12 +209,17 @@ impl SetExpression { debug_assert!(op.is_set()); Self { op, term, literals } } + pub(crate) fn op(&self) -> PredicateOperator { self.op } pub(crate) fn literals(&self) -> &FnvHashSet { &self.literals } + + pub(crate) fn term(&self) -> &T { + &self.term + } } impl Bind for SetExpression { @@ -232,6 +246,10 @@ impl Display for SetExpression { /// Unbound predicate expression before binding to a schema. #[derive(Debug, PartialEq)] pub enum Predicate { + /// AlwaysTrue predicate, for example, `TRUE`. + AlwaysTrue, + /// AlwaysFalse predicate, for example, `FALSE`. + AlwaysFalse, /// And predicate, for example, `a > 10 AND b < 20`. And(LogicalExpression), /// Or predicate, for example, `a > 10 OR b < 20`. 
@@ -382,6 +400,8 @@ impl Bind for Predicate { bound_literals, ))) } + Predicate::AlwaysTrue => Ok(BoundPredicate::AlwaysTrue), + Predicate::AlwaysFalse => Ok(BoundPredicate::AlwaysFalse), } } } @@ -389,6 +409,12 @@ impl Bind for Predicate { impl Display for Predicate { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { + Predicate::AlwaysTrue => { + write!(f, "TRUE") + } + Predicate::AlwaysFalse => { + write!(f, "FALSE") + } Predicate::And(expr) => { write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1]) } @@ -476,6 +502,8 @@ impl Predicate { /// ``` pub fn negate(self) -> Predicate { match self { + Predicate::AlwaysTrue => Predicate::AlwaysFalse, + Predicate::AlwaysFalse => Predicate::AlwaysTrue, Predicate::And(expr) => Predicate::Or(LogicalExpression::new( expr.inputs.map(|expr| Box::new(expr.negate())), )), @@ -540,6 +568,8 @@ impl Predicate { Predicate::Unary(expr) => Predicate::Unary(expr), Predicate::Binary(expr) => Predicate::Binary(expr), Predicate::Set(expr) => Predicate::Set(expr), + Predicate::AlwaysTrue => Predicate::AlwaysTrue, + Predicate::AlwaysFalse => Predicate::AlwaysFalse, } } } diff --git a/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs b/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs new file mode 100644 index 0000000000..6d36fa146a --- /dev/null +++ b/crates/iceberg/src/expr/visitors/bound_predicate_visitor.rs @@ -0,0 +1,317 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expr::{BoundPredicate, BoundReference, PredicateOperator}; +use crate::spec::Datum; +use crate::Result; +use fnv::FnvHashSet; + +pub(crate) enum OpLiteral<'a> { + Single(&'a Datum), + Set(&'a FnvHashSet), +} + +/// A visitor for [`BoundPredicate`]s. Visits in post-order. +pub trait BoundPredicateVisitor { + /// The return type of this visitor + type T; + + /// Called after an `AlwaysTrue` predicate is visited + fn always_true(&mut self) -> Result; + + /// Called after an `AlwaysFalse` predicate is visited + fn always_false(&mut self) -> Result; + + /// Called after an `And` predicate is visited + fn and(&mut self, lhs: Self::T, rhs: Self::T) -> Result; + + /// Called after an `Or` predicate is visited + fn or(&mut self, lhs: Self::T, rhs: Self::T) -> Result; + + /// Called after a `Not` predicate is visited + fn not(&mut self, inner: Self::T) -> Result; + + /// Called after visiting a UnaryPredicate, BinaryPredicate, + /// or SetPredicate. Always passes the predicate's operator and term; + /// the literal is `None` for unary predicates, a single datum for + /// binary predicates, and a set of datums for set predicates. 
+ fn op( + &mut self, + op: PredicateOperator, + reference: &BoundReference, + literal: Option, + predicate: &BoundPredicate, + ) -> Result; +} + +/// Visits a [`BoundPredicate`] with the provided visitor, +/// in post-order +pub(crate) fn visit( + visitor: &mut V, + predicate: &BoundPredicate, +) -> Result { + match predicate { + BoundPredicate::AlwaysTrue => visitor.always_true(), + BoundPredicate::AlwaysFalse => visitor.always_false(), + BoundPredicate::And(expr) => { + let [left_pred, right_pred] = expr.inputs(); + + let left_result = visit(visitor, left_pred)?; + let right_result = visit(visitor, right_pred)?; + + visitor.and(left_result, right_result) + } + BoundPredicate::Or(expr) => { + let [left_pred, right_pred] = expr.inputs(); + + let left_result = visit(visitor, left_pred)?; + let right_result = visit(visitor, right_pred)?; + + visitor.or(left_result, right_result) + } + BoundPredicate::Not(expr) => { + let [inner_pred] = expr.inputs(); + + let inner_result = visit(visitor, inner_pred)?; + + visitor.not(inner_result) + } + BoundPredicate::Unary(expr) => visitor.op(expr.op(), expr.term(), None, predicate), + BoundPredicate::Binary(expr) => visitor.op( + expr.op(), + expr.term(), + Some(OpLiteral::Single(expr.literal())), + predicate, + ), + BoundPredicate::Set(expr) => visitor.op( + expr.op(), + expr.term(), + Some(OpLiteral::Set(expr.literals())), + predicate, + ), + } +} + +#[cfg(test)] +mod tests { + use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor, OpLiteral}; + use crate::expr::{ + BinaryExpression, Bind, BoundPredicate, BoundReference, Predicate, PredicateOperator, + Reference, + }; + use crate::spec::{Datum, NestedField, PrimitiveType, Schema, SchemaRef, Type}; + use std::ops::Not; + use std::sync::Arc; + + struct TestEvaluator {} + impl BoundPredicateVisitor for TestEvaluator { + type T = bool; + + fn always_true(&mut self) -> crate::Result { + Ok(true) + } + + fn always_false(&mut self) -> crate::Result { + 
Ok(false) + } + + fn and(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result { + Ok(lhs && rhs) + } + + fn or(&mut self, lhs: Self::T, rhs: Self::T) -> crate::Result { + Ok(lhs || rhs) + } + + fn not(&mut self, inner: Self::T) -> crate::Result { + Ok(!inner) + } + + fn op( + &mut self, + op: PredicateOperator, + _reference: &BoundReference, + _literal: Option, + _predicate: &BoundPredicate, + ) -> crate::Result { + Ok(match op { + PredicateOperator::IsNull => true, + PredicateOperator::NotNull => false, + PredicateOperator::IsNan => true, + PredicateOperator::NotNan => false, + PredicateOperator::LessThan => true, + PredicateOperator::LessThanOrEq => false, + PredicateOperator::GreaterThan => true, + PredicateOperator::GreaterThanOrEq => false, + PredicateOperator::Eq => true, + PredicateOperator::NotEq => false, + PredicateOperator::StartsWith => true, + PredicateOperator::NotStartsWith => false, + PredicateOperator::In => false, + PredicateOperator::NotIn => true, + }) + } + } + + fn create_test_schema() -> SchemaRef { + let schema = Schema::builder() + .with_fields(vec![Arc::new(NestedField::required( + 1, + "a", + Type::Primitive(PrimitiveType::Int), + ))]) + .build() + .unwrap(); + + let schema_arc = Arc::new(schema); + schema_arc.clone() + } + + #[test] + fn test_always_true() { + let predicate = Predicate::AlwaysTrue; + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_always_false() { + let predicate = Predicate::AlwaysFalse; + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_logical_and() { + let predicate = Predicate::AlwaysTrue.and(Predicate::AlwaysFalse); + let 
bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + + let predicate = Predicate::AlwaysFalse.and(Predicate::AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + + let predicate = Predicate::AlwaysTrue.and(Predicate::AlwaysTrue); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_logical_or() { + let predicate = Predicate::AlwaysTrue.or(Predicate::AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + + let predicate = Predicate::AlwaysFalse.or(Predicate::AlwaysFalse); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + + let predicate = Predicate::AlwaysTrue.or(Predicate::AlwaysTrue); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + } + + #[test] + fn test_not() { + let predicate = Predicate::AlwaysFalse.not(); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + + 
let predicate = Predicate::AlwaysTrue.not(); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } + + #[test] + fn test_op() { + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(result.unwrap()); + + let predicate = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThanOrEq, + Reference::new("a"), + Datum::int(10), + )); + let bound_predicate = predicate.bind(create_test_schema(), false).unwrap(); + + let mut test_evaluator = TestEvaluator {}; + + let result = visit(&mut test_evaluator, &bound_predicate); + + assert!(!result.unwrap()); + } +} diff --git a/crates/iceberg/src/expr/visitors/mod.rs b/crates/iceberg/src/expr/visitors/mod.rs new file mode 100644 index 0000000000..242a55c4b0 --- /dev/null +++ b/crates/iceberg/src/expr/visitors/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(crate) mod bound_predicate_visitor;