From 408cee1ccd544fa6e7d17b52582129bb86c362e6 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sun, 28 Sep 2025 16:49:41 +0200 Subject: [PATCH 01/23] #17801 Improve nullability reporting of case expressions --- datafusion/core/tests/tpcds_planning.rs | 3 +- datafusion/expr/src/expr_fn.rs | 5 + datafusion/expr/src/expr_schema.rs | 199 +++++++++++++- .../physical-expr/src/expressions/case.rs | 251 +++++++++++++++++- 4 files changed, 449 insertions(+), 9 deletions(-) diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 252d76d0f9d9..bee3b48a574b 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1052,9 +1052,10 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { for sql in &sql { let df = ctx.sql(sql).await?; let (state, plan) = df.into_parts(); - let plan = state.optimize(&plan)?; if create_physical { let _ = state.create_physical_plan(&plan).await?; + } else { + let _ = state.optimize(&plan)?; } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4d8b94ba27ff..08ffab8e426b 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -340,6 +340,11 @@ pub fn is_null(expr: Expr) -> Expr { Expr::IsNull(Box::new(expr)) } +/// Create is not null expression +pub fn is_not_null(expr: Expr) -> Expr { + Expr::IsNotNull(Box::new(expr)) +} + /// Create is true expression pub fn is_true(expr: Expr) -> Expr { Expr::IsTrue(Box::new(expr)) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e803e3534130..553882619252 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -32,6 +32,7 @@ use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, }; +use datafusion_expr_common::operator::Operator; use 
datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::sync::Arc; @@ -283,6 +284,11 @@ impl ExprSchemable for Expr { let then_nullable = case .when_then_expr .iter() + .filter(|(w, t)| { + // Disregard branches where we can determine statically that the predicate + // is always false when the then expression would evaluate to null + const_result_when_value_is_null(w, t).unwrap_or(true) + }) .map(|(_, t)| t.nullable(input_schema)) .collect::>>()?; if then_nullable.contains(&true) { @@ -647,6 +653,50 @@ impl ExprSchemable for Expr { } } +/// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`. +/// Returns a `Some` value containing the const result if so; otherwise returns `None`. +fn const_result_when_value_is_null(predicate: &Expr, value: &Expr) -> Option { + match predicate { + Expr::IsNotNull(e) => { + if e.as_ref().eq(value) { + Some(false) + } else { + None + } + } + Expr::IsNull(e) => { + if e.as_ref().eq(value) { + Some(true) + } else { + None + } + } + Expr::Not(e) => const_result_when_value_is_null(e, value).map(|b| !b), + Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Operator::And => { + let l = const_result_when_value_is_null(left, value); + let r = const_result_when_value_is_null(right, value); + match (l, r) { + (Some(l), Some(r)) => Some(l && r), + (Some(l), None) => Some(l), + (None, Some(r)) => Some(r), + _ => None, + } + } + Operator::Or => { + let l = const_result_when_value_is_null(left, value); + let r = const_result_when_value_is_null(right, value); + match (l, r) { + (Some(l), Some(r)) => Some(l || r), + _ => None, + } + } + _ => None, + }, + _ => None, + } +} + impl Expr { /// Common method for window functions that applies type coercion /// to all arguments of the window function to check if it matches @@ -777,7 +827,10 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> 
Result MockExprSchema, + ) -> Result<()> { + assert_eq!( + expr.nullable(&get_schema(true))?, + nullable, + "Nullability of '{expr}' should be {nullable} when column is nullable" + ); + assert!( + !expr.nullable(&get_schema(false))?, + "Nullability of '{expr}' should be false when column is not nullable" + ); + Ok(()) + } + + #[test] + fn test_case_expression_nullability() -> Result<()> { + let get_schema = |nullable| { + MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(nullable) + }; + + check_nullability( + when(is_not_null(col("foo")), col("foo")).otherwise(lit(0))?, + false, + get_schema, + )?; + + check_nullability( + when(not(is_null(col("foo"))), col("foo")).otherwise(lit(0))?, + false, + get_schema, + )?; + + check_nullability( + when(binary_expr(col("foo"), Operator::Eq, lit(5)), col("foo")) + .otherwise(lit(0))?, + true, + get_schema, + )?; + + check_nullability( + when( + and( + is_not_null(col("foo")), + binary_expr(col("foo"), Operator::Eq, lit(5)), + ), + col("foo"), + ) + .otherwise(lit(0))?, + false, + get_schema, + )?; + + check_nullability( + when( + and( + binary_expr(col("foo"), Operator::Eq, lit(5)), + is_not_null(col("foo")), + ), + col("foo"), + ) + .otherwise(lit(0))?, + false, + get_schema, + )?; + + check_nullability( + when( + or( + is_not_null(col("foo")), + binary_expr(col("foo"), Operator::Eq, lit(5)), + ), + col("foo"), + ) + .otherwise(lit(0))?, + true, + get_schema, + )?; + + check_nullability( + when( + or( + binary_expr(col("foo"), Operator::Eq, lit(5)), + is_not_null(col("foo")), + ), + col("foo"), + ) + .otherwise(lit(0))?, + true, + get_schema, + )?; + + check_nullability( + when( + or( + is_not_null(col("foo")), + binary_expr(col("foo"), Operator::Eq, lit(5)), + ), + col("foo"), + ) + .otherwise(lit(0))?, + true, + get_schema, + )?; + + check_nullability( + when( + or( + binary_expr(col("foo"), Operator::Eq, lit(5)), + is_not_null(col("foo")), + ), + col("foo"), + ) + .otherwise(lit(0))?, + true, + 
get_schema, + )?; + + check_nullability( + when( + or( + and( + binary_expr(col("foo"), Operator::Eq, lit(5)), + is_not_null(col("foo")), + ), + and( + binary_expr(col("foo"), Operator::Eq, col("bar")), + is_not_null(col("foo")), + ), + ), + col("foo"), + ) + .otherwise(lit(0))?, + false, + get_schema, + )?; + + Ok(()) + } + #[test] fn test_inlist_nullability() { let get_schema = |nullable| { diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 5409cfe8e7e4..b9da8fddb4bf 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,12 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::try_cast; +use crate::expressions::{try_cast, BinaryExpr, IsNotNullExpr, IsNullExpr, NotExpr}; use crate::PhysicalExpr; -use std::borrow::Cow; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; @@ -30,8 +26,12 @@ use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; +use std::borrow::Cow; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; use super::{Column, Literal}; +use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; @@ -481,6 +481,11 @@ impl PhysicalExpr for CaseExpr { let then_nullable = self .when_then_expr .iter() + .filter(|(w, t)| { + // Disregard branches where we can determine statically that the predicate + // is always false when the then expression would evaluate to null + const_result_when_value_is_null(w.as_ref(), t.as_ref()).unwrap_or(true) + }) .map(|(_, t)| t.nullable(input_schema)) .collect::>>()?; if then_nullable.contains(&true) { @@ -588,6 
+593,54 @@ impl PhysicalExpr for CaseExpr { } } +/// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`. +/// Returns a `Some` value containing the const result if so; otherwise returns `None`. +fn const_result_when_value_is_null( + predicate: &dyn PhysicalExpr, + value: &dyn PhysicalExpr, +) -> Option { + let predicate_any = predicate.as_any(); + if let Some(not_null) = predicate_any.downcast_ref::() { + if not_null.arg().as_ref().dyn_eq(value) { + Some(false) + } else { + None + } + } else if let Some(null) = predicate_any.downcast_ref::() { + if null.arg().as_ref().dyn_eq(value) { + Some(true) + } else { + None + } + } else if let Some(not) = predicate_any.downcast_ref::() { + const_result_when_value_is_null(not.arg().as_ref(), value).map(|b| !b) + } else if let Some(binary) = predicate_any.downcast_ref::() { + match binary.op() { + Operator::And => { + let l = const_result_when_value_is_null(binary.left().as_ref(), value); + let r = const_result_when_value_is_null(binary.right().as_ref(), value); + match (l, r) { + (Some(l), Some(r)) => Some(l && r), + (Some(l), None) => Some(l), + (None, Some(r)) => Some(r), + _ => None, + } + } + Operator::Or => { + let l = const_result_when_value_is_null(binary.left().as_ref(), value); + let r = const_result_when_value_is_null(binary.right().as_ref(), value); + match (l, r) { + (Some(l), Some(r)) => Some(l || r), + _ => None, + } + } + _ => None, + } + } else { + None + } +} + /// Create a CASE expression pub fn case( expr: Option>, @@ -601,7 +654,8 @@ pub fn case( mod tests { use super::*; - use crate::expressions::{binary, cast, col, lit, BinaryExpr}; + use crate::expressions; + use crate::expressions::{binary, cast, col, is_not_null, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::Field; @@ -609,7 +663,6 @@ mod tests { use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, 
TreeNode}; use datafusion_expr::type_coercion::binary::comparison_coercion; - use datafusion_expr::Operator; use datafusion_physical_expr_common::physical_expr::fmt_sql; #[test] @@ -1435,4 +1488,188 @@ mod tests { Ok(()) } + + fn when_then_else( + when: &Arc, + then: &Arc, + els: &Arc, + ) -> Result> { + let case = CaseExpr::try_new( + None, + vec![(Arc::clone(when), Arc::clone(then))], + Some(Arc::clone(els)), + )?; + Ok(Arc::new(case)) + } + + #[test] + fn test_case_expression_nullability_with_nullable_column() -> Result<()> { + case_expression_nullability(true) + } + + #[test] + fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> { + case_expression_nullability(false) + } + + fn case_expression_nullability(col_is_nullable: bool) -> Result<()> { + let schema = + Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]); + + let foo = col("foo", &schema)?; + let foo_is_not_null = is_not_null(Arc::clone(&foo))?; + let foo_is_null = expressions::is_null(Arc::clone(&foo))?; + let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?; + let zero = lit(0); + let foo_eq_zero = + binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?; + + assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema); + assert_not_nullable(when_then_else(&not_foo_is_null, &foo, &zero)?, &schema); + assert_nullability( + when_then_else(&foo_eq_zero, &foo, &zero)?, + &schema, + col_is_nullable, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::And, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_eq_zero), + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_nullability( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::Or, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, +
&zero, + )?, + &schema, + col_is_nullable, + ); + + assert_nullability( + when_then_else( + &binary( + Arc::clone(&foo_eq_zero), + Operator::Or, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_nullability( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::Or, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_nullability( + when_then_else( + &binary( + binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?, + Operator::Or, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_not_nullable( + when_then_else( + &binary( + binary( + binary( + Arc::clone(&foo), + Operator::Eq, + Arc::clone(&zero), + &schema, + )?, + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + Operator::Or, + binary( + binary( + Arc::clone(&foo), + Operator::Eq, + Arc::clone(&foo), + &schema, + )?, + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + Ok(()) + } + + fn assert_not_nullable(expr: Arc, schema: &Schema) { + assert!(!expr.nullable(schema).unwrap()); + } + + fn assert_nullable(expr: Arc, schema: &Schema) { + assert!(expr.nullable(schema).unwrap()); + } + + fn assert_nullability(expr: Arc, schema: &Schema, nullable: bool) { + if nullable { + assert_nullable(expr, schema); + } else { + assert_not_nullable(expr, schema); + } + } } From 045fc9c4676ed8e2183fc5e4c9394245bfcab5b8 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 29 Sep 2025 19:10:02 +0200 Subject: [PATCH 02/23] #17801 Clarify logical expression test cases --- datafusion/expr/src/expr_schema.rs | 212 +++++++++++------------------ 1 file changed, 76 insertions(+), 136 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 553882619252..8bc92cb2b84a 
100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -28,6 +28,7 @@ use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, @@ -827,10 +828,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result MockExprSchema, - ) -> Result<()> { + fn assert_nullability(expr: &Expr, schema: &dyn ExprSchema, nullable: bool) { assert_eq!( - expr.nullable(&get_schema(true))?, + expr.nullable(schema).unwrap(), nullable, - "Nullability of '{expr}' should be {nullable} when column is nullable" - ); - assert!( - !expr.nullable(&get_schema(false))?, - "Nullability of '{expr}' should be false when column is not nullable" + "Nullability of '{expr}' should be {nullable}" ); - Ok(()) + } + + fn assert_not_nullable(expr: &Expr, schema: &dyn ExprSchema) { + assert_nullability(expr, schema, false); + } + + fn assert_nullable(expr: &Expr, schema: &dyn ExprSchema) { + assert_nullability(expr, schema, true); } #[test] fn test_case_expression_nullability() -> Result<()> { - let get_schema = |nullable| { - MockExprSchema::new() - .with_data_type(DataType::Int32) - .with_nullable(nullable) - }; + let nullable_schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(true); - check_nullability( - when(is_not_null(col("foo")), col("foo")).otherwise(lit(0))?, - false, - get_schema, - )?; - - check_nullability( - when(not(is_null(col("foo"))), col("foo")).otherwise(lit(0))?, - false, - get_schema, - )?; - - check_nullability( - when(binary_expr(col("foo"), Operator::Eq, lit(5)), col("foo")) - .otherwise(lit(0))?, - true, - get_schema, - )?; - - check_nullability( - when( - and( - 
is_not_null(col("foo")), - binary_expr(col("foo"), Operator::Eq, lit(5)), - ), - col("foo"), - ) - .otherwise(lit(0))?, - false, - get_schema, - )?; - - check_nullability( - when( - and( - binary_expr(col("foo"), Operator::Eq, lit(5)), - is_not_null(col("foo")), - ), - col("foo"), - ) - .otherwise(lit(0))?, - false, - get_schema, - )?; - - check_nullability( - when( - or( - is_not_null(col("foo")), - binary_expr(col("foo"), Operator::Eq, lit(5)), - ), - col("foo"), - ) - .otherwise(lit(0))?, - true, - get_schema, - )?; - - check_nullability( - when( - or( - binary_expr(col("foo"), Operator::Eq, lit(5)), - is_not_null(col("foo")), - ), - col("foo"), - ) - .otherwise(lit(0))?, - true, - get_schema, - )?; - - check_nullability( - when( - or( - is_not_null(col("foo")), - binary_expr(col("foo"), Operator::Eq, lit(5)), - ), - col("foo"), - ) - .otherwise(lit(0))?, - true, - get_schema, - )?; - - check_nullability( - when( - or( - binary_expr(col("foo"), Operator::Eq, lit(5)), - is_not_null(col("foo")), - ), - col("foo"), - ) - .otherwise(lit(0))?, - true, - get_schema, - )?; - - check_nullability( - when( - or( - and( - binary_expr(col("foo"), Operator::Eq, lit(5)), - is_not_null(col("foo")), - ), - and( - binary_expr(col("foo"), Operator::Eq, col("bar")), - is_not_null(col("foo")), - ), - ), - col("foo"), - ) - .otherwise(lit(0))?, - false, - get_schema, - )?; + let not_nullable_schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(false); + + // CASE WHEN x IS NOT NULL THEN x ELSE 0 + let e1 = when(col("x").is_not_null(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e1, &nullable_schema); + assert_not_nullable(&e1, &not_nullable_schema); + + // CASE WHEN NOT x IS NULL THEN x ELSE 0 + let e2 = when(not(col("x").is_null()), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e2, &nullable_schema); + assert_not_nullable(&e2, &not_nullable_schema); + + // CASE WHEN X = 5 THEN x ELSE 0 + let e3 = when(col("x").eq(lit(5)),
col("x")).otherwise(lit(0))?; + assert_nullable(&e3, &nullable_schema); + assert_not_nullable(&e3, &not_nullable_schema); + + // CASE WHEN x IS NOT NULL AND x = 5 THEN x ELSE 0 + let e4 = when(and(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e4, &nullable_schema); + assert_not_nullable(&e4, &not_nullable_schema); + + // CASE WHEN x = 5 AND x IS NOT NULL THEN x ELSE 0 + let e5 = when(and(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e5, &nullable_schema); + assert_not_nullable(&e5, &not_nullable_schema); + + // CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0 + let e6 = when(or(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + .otherwise(lit(0))?; + assert_nullable(&e6, &nullable_schema); + assert_not_nullable(&e6, &not_nullable_schema); + + // CASE WHEN x = 5 OR x IS NOT NULL THEN x ELSE 0 + let e7 = when(or(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + .otherwise(lit(0))?; + assert_nullable(&e7, &nullable_schema); + assert_not_nullable(&e7, &not_nullable_schema); + + // CASE WHEN (x = 5 AND x IS NOT NULL) OR (x = bar AND x IS NOT NULL) THEN x ELSE 0 + let e8 = when( + or( + and(col("x").eq(lit(5)), col("x").is_not_null()), + and(col("x").eq(col("bar")), col("x").is_not_null()), + ), + col("x"), + ) + .otherwise(lit(0))?; + assert_not_nullable(&e8, &nullable_schema); + assert_not_nullable(&e8, &not_nullable_schema); + + // CASE WHEN x = 5 OR x IS NULL THEN x ELSE 0 + let e9 = when(or(col("x").eq(lit(5)), col("x").is_null()), col("x")) + .otherwise(lit(0))?; + assert_nullable(&e9, &nullable_schema); + assert_not_nullable(&e9, &not_nullable_schema); Ok(()) } From de8b780a1702a17a76f724b06bbabdc5dcce1401 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 30 Sep 2025 15:35:18 +0200 Subject: [PATCH 03/23] #17801 Attempt to clarify const evaluation logic --- datafusion/expr/src/expr_schema.rs | 87 +++++++++++++++++++----------- 1 file changed, 56
insertions(+), 31 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 8bc92cb2b84a..6c9dda23c318 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -28,11 +28,7 @@ use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::tree_node::TreeNode; -use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, - Result, Spans, TableReference, -}; +use datafusion_common::{not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, ScalarValue, Spans, TableReference}; use datafusion_expr_common::operator::Operator; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -281,18 +277,32 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()), Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { - // This expression is nullable if any of the input expressions are nullable + // This expression is nullable if any of the then expressions are nullable let then_nullable = case .when_then_expr .iter() - .filter(|(w, t)| { - // Disregard branches where we can determine statically that the predicate - // is always false when the then expression would evaluate to null - const_result_when_value_is_null(w, t).unwrap_or(true) + .filter_map(|(w, t)| { + match t.nullable(input_schema) { + // Branches with a then expressions that is not nullable can be skipped + Ok(false) => None, + // Pass error determining nullability on verbatim + err @ Err(_) => Some(err), + // For branches with a nullable then expressions try to determine + // using limited const evaluation if the branch will be taken when + // the then expression 
evaluates to null. + Ok(true) => match const_result_when_value_is_null(w, t, input_schema) { + // Const evaluation was inconclusive or determined the branch would + // be taken + None | Some(true) => Some(Ok(true)), + // Const evaluation proves the branch will never be taken. + // The most common pattern for this is + // `WHEN x IS NOT NULL THEN x`. + Some(false) => None, + }, + } }) - .map(|(_, t)| t.nullable(input_schema)) .collect::>>()?; - if then_nullable.contains(&true) { + if !then_nullable.is_empty() { Ok(true) } else if let Some(e) = &case.else_expr { e.nullable(input_schema) @@ -656,27 +666,23 @@ impl ExprSchemable for Expr { /// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`. /// Returns a `Some` value containing the const result if so; otherwise returns `None`. -fn const_result_when_value_is_null(predicate: &Expr, value: &Expr) -> Option { +fn const_result_when_value_is_null(predicate: &Expr, value: &Expr, input_schema: &dyn ExprSchema) -> Option { match predicate { Expr::IsNotNull(e) => { - if e.as_ref().eq(value) { - Some(false) - } else { - None - } + // If `e` is null, then `e IS NOT NULL` is false + // If `e` is not null, then `e IS NOT NULL` is true + is_null(e, value, input_schema).map(|is_null| !is_null) } Expr::IsNull(e) => { - if e.as_ref().eq(value) { - Some(true) - } else { - None - } + // If `e` is null, then `e IS NULL` is true + // If `e` is not null, then `e IS NULL` is false + is_null(e, value, input_schema) } - Expr::Not(e) => const_result_when_value_is_null(e, value).map(|b| !b), + Expr::Not(e) => const_result_when_value_is_null(e, value, input_schema).map(|b| !b), Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { Operator::And => { - let l = const_result_when_value_is_null(left, value); - let r = const_result_when_value_is_null(right, value); + let l = const_result_when_value_is_null(left, value, input_schema); + let r = const_result_when_value_is_null(right, value, 
input_schema); match (l, r) { (Some(l), Some(r)) => Some(l && r), (Some(l), None) => Some(l), @@ -685,8 +691,8 @@ fn const_result_when_value_is_null(predicate: &Expr, value: &Expr) -> Option { - let l = const_result_when_value_is_null(left, value); - let r = const_result_when_value_is_null(right, value); + let l = const_result_when_value_is_null(left, value, input_schema); + let r = const_result_when_value_is_null(right, value, input_schema); match (l, r) { (Some(l), Some(r)) => Some(l || r), _ => None, @@ -698,6 +704,25 @@ fn const_result_when_value_is_null(predicate: &Expr, value: &Expr) -> Option Option { + // We're assuming `value` is null + if expr.eq(value) { + return Some(true); + } + + match expr { + // Literal null is obviously null + Expr::Literal(ScalarValue::Null, _) => Some(true), + // We're assuming `value` is null + _ => match expr.nullable(input_schema) { + // If `expr` is not nullable, we can be certain `expr` is not null + Ok(false) => Some(false), + // Otherwise inconclusive + _ => None, + } + } +} + impl Expr { /// Common method for window functions that applies type coercion /// to all arguments of the window function to check if it matches @@ -881,11 +906,11 @@ mod tests { assert!(expr.nullable(&get_schema(false)).unwrap()); } - fn assert_nullability(expr: &Expr, schema: &dyn ExprSchema, nullable: bool) { + fn assert_nullability(expr: &Expr, schema: &dyn ExprSchema, expected: bool) { assert_eq!( expr.nullable(schema).unwrap(), - nullable, - "Nullability of '{expr}' should be {nullable}" + expected, + "Nullability of '{expr}' should be {expected}" ); } From bbd29490a5f12cf526cd38a99a47fc53eb18fd6e Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Tue, 30 Sep 2025 20:27:15 +0200 Subject: [PATCH 04/23] #17801 Extend predicate const evaluation --- datafusion/expr/src/expr_schema.rs | 390 ++++++++++++++++++++++------- 1 file changed, 301 insertions(+), 89 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs 
b/datafusion/expr/src/expr_schema.rs index 6c9dda23c318..ec2344f962f2 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -28,7 +28,10 @@ use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::{not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, ScalarValue, Spans, TableReference}; +use datafusion_common::{ + not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, + Result, ScalarValue, Spans, TableReference, +}; use datafusion_expr_common::operator::Operator; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -278,31 +281,43 @@ impl ExprSchemable for Expr { Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { // This expression is nullable if any of the then expressions are nullable - let then_nullable = case + let any_nullable_thens = !case .when_then_expr .iter() .filter_map(|(w, t)| { match t.nullable(input_schema) { - // Branches with a then expressions that is not nullable can be skipped + // Branches with a then expression that is not nullable can be skipped Ok(false) => None, // Pass error determining nullability on verbatim - err @ Err(_) => Some(err), + Err(e) => Some(Err(e)), // For branches with a nullable then expressions try to determine // using limited const evaluation if the branch will be taken when // the then expression evaluates to null. - Ok(true) => match const_result_when_value_is_null(w, t, input_schema) { - // Const evaluation was inconclusive or determined the branch would - // be taken - None | Some(true) => Some(Ok(true)), - // Const evaluation proves the branch will never be taken. - // The most common pattern for this is - // `WHEN x IS NOT NULL THEN x`. 
- Some(false) => None, - }, + Ok(true) => { + let const_result = WhenThenConstEvaluator { + then_expr: t, + input_schema, + } + .const_eval_predicate(w); + + match const_result { + // Const evaluation was inconclusive or determined the branch + // would be taken + None | Some(TriStateBool::True) => Some(Ok(())), + // Const evaluation proves the branch will never be taken. + // The most common pattern for this is + // `WHEN x IS NOT NULL THEN x`. + Some(TriStateBool::False) + | Some(TriStateBool::Uncertain) => None, + } + } } }) - .collect::>>()?; - if !then_nullable.is_empty() { + .collect::>>()? + .is_empty(); + + if any_nullable_thens { + // There is at least one reachable nullable then Ok(true) } else if let Some(e) = &case.else_expr { e.nullable(input_schema) @@ -664,61 +679,223 @@ impl ExprSchemable for Expr { } } -/// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`. -/// Returns a `Some` value containing the const result if so; otherwise returns `None`. 
-fn const_result_when_value_is_null(predicate: &Expr, value: &Expr, input_schema: &dyn ExprSchema) -> Option { - match predicate { - Expr::IsNotNull(e) => { - // If `e` is null, then `e IS NOT NULL` is false - // If `e` is not null, then `e IS NOT NULL` is true - is_null(e, value, input_schema).map(|is_null| !is_null) - } - Expr::IsNull(e) => { - // If `e` is null, then `e IS NULL` is true - // If `e` is not null, then `e IS NULL` is false - is_null(e, value, input_schema) - } - Expr::Not(e) => const_result_when_value_is_null(e, value, input_schema).map(|b| !b), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::And => { - let l = const_result_when_value_is_null(left, value, input_schema); - let r = const_result_when_value_is_null(right, value, input_schema); - match (l, r) { - (Some(l), Some(r)) => Some(l && r), - (Some(l), None) => Some(l), - (None, Some(r)) => Some(r), +enum TriStateBool { + True, + False, + Uncertain, +} + +struct WhenThenConstEvaluator<'a> { + then_expr: &'a Expr, + input_schema: &'a dyn ExprSchema, +} + +impl WhenThenConstEvaluator<'_> { + /// Attempts to const evaluate the given predicate. + /// Returns a `Some` value containing the const result if so; otherwise returns `None`. 
+ fn const_eval_predicate(&self, predicate: &Expr) -> Option { + match predicate { + // Literal null is equivalent to boolean uncertain + Expr::Literal(ScalarValue::Null, _) => Some(TriStateBool::Uncertain), + Expr::IsNotNull(e) => { + if let Ok(false) = e.nullable(self.input_schema) { + // If `e` is not nullable, then `e IS NOT NULL` is always true + return Some(TriStateBool::True); + } + + match e.get_type(self.input_schema) { + Ok(DataType::Boolean) => match self.const_eval_predicate(e) { + Some(TriStateBool::True) | Some(TriStateBool::False) => { + Some(TriStateBool::True) + } + Some(TriStateBool::Uncertain) => Some(TriStateBool::False), + None => None, + }, + Ok(_) => match self.is_null(e) { + Some(true) => Some(TriStateBool::False), + Some(false) => Some(TriStateBool::True), + None => None, + }, + Err(_) => None, + } + } + Expr::IsNull(e) => { + if let Ok(false) = e.nullable(self.input_schema) { + // If `e` is not nullable, then `e IS NULL` is always false + return Some(TriStateBool::False); + } + + match e.get_type(self.input_schema) { + Ok(DataType::Boolean) => match self.const_eval_predicate(e) { + Some(TriStateBool::True) | Some(TriStateBool::False) => { + Some(TriStateBool::False) + } + Some(TriStateBool::Uncertain) => Some(TriStateBool::True), + None => None, + }, + Ok(_) => match self.is_null(e) { + Some(true) => Some(TriStateBool::True), + Some(false) => Some(TriStateBool::False), + None => None, + }, + Err(_) => None, + } + } + Expr::IsTrue(e) => match self.const_eval_predicate(e) { + Some(TriStateBool::True) => Some(TriStateBool::True), + Some(_) => Some(TriStateBool::False), + _ => None, + }, + Expr::IsNotTrue(e) => match self.const_eval_predicate(e) { + Some(TriStateBool::True) => Some(TriStateBool::False), + Some(_) => Some(TriStateBool::True), + _ => None, + }, + Expr::IsFalse(e) => match self.const_eval_predicate(e) { + Some(TriStateBool::False) => Some(TriStateBool::True), + Some(_) => Some(TriStateBool::False), + _ => None, + }, + 
Expr::IsNotFalse(e) => match self.const_eval_predicate(e) { + Some(TriStateBool::False) => Some(TriStateBool::False), + Some(_) => Some(TriStateBool::True), + _ => None, + }, + Expr::IsUnknown(e) => match self.const_eval_predicate(e) { + Some(TriStateBool::Uncertain) => Some(TriStateBool::True), + Some(_) => Some(TriStateBool::False), + _ => None, + }, + Expr::IsNotUnknown(e) => match self.const_eval_predicate(e) { + Some(TriStateBool::Uncertain) => Some(TriStateBool::False), + Some(_) => Some(TriStateBool::True), + _ => None, + }, + Expr::Like(Like { expr, pattern, .. }) => { + match (self.is_null(expr), self.is_null(pattern)) { + (Some(true), _) | (_, Some(true)) => Some(TriStateBool::Uncertain), _ => None, } } - Operator::Or => { - let l = const_result_when_value_is_null(left, value, input_schema); - let r = const_result_when_value_is_null(right, value, input_schema); - match (l, r) { - (Some(l), Some(r)) => Some(l || r), + Expr::SimilarTo(Like { expr, pattern, .. }) => { + match (self.is_null(expr), self.is_null(pattern)) { + (Some(true), _) | (_, Some(true)) => Some(TriStateBool::Uncertain), _ => None, } } - _ => None, - }, - _ => None, - } -} - -fn is_null(expr: &Expr, value: &Expr, input_schema: &dyn ExprSchema) -> Option { - // We're assuming `value` is null - if expr.eq(value) { - return Some(true); + Expr::Between(Between { + expr, low, high, .. 
+ }) => match (self.is_null(expr), self.is_null(low), self.is_null(high)) { + (Some(true), _, _) | (_, Some(true), _) | (_, _, Some(true)) => { + Some(TriStateBool::Uncertain) + } + _ => None, + }, + Expr::Not(e) => match self.const_eval_predicate(e) { + Some(TriStateBool::True) => Some(TriStateBool::False), + Some(TriStateBool::False) => Some(TriStateBool::True), + Some(TriStateBool::Uncertain) => Some(TriStateBool::Uncertain), + None => None, + }, + Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Operator::And => { + match ( + self.const_eval_predicate(left), + self.const_eval_predicate(right), + ) { + (Some(TriStateBool::False), _) + | (_, Some(TriStateBool::False)) => Some(TriStateBool::False), + (Some(TriStateBool::True), Some(TriStateBool::True)) => { + Some(TriStateBool::True) + } + (Some(TriStateBool::Uncertain), Some(_)) + | (Some(_), Some(TriStateBool::Uncertain)) => { + Some(TriStateBool::Uncertain) + } + _ => None, + } + } + Operator::Or => { + match ( + self.const_eval_predicate(left), + self.const_eval_predicate(right), + ) { + (Some(TriStateBool::True), _) | (_, Some(TriStateBool::True)) => { + Some(TriStateBool::True) + } + (Some(TriStateBool::False), Some(TriStateBool::False)) => { + Some(TriStateBool::False) + } + (Some(TriStateBool::Uncertain), Some(_)) + | (Some(_), Some(TriStateBool::Uncertain)) => { + Some(TriStateBool::Uncertain) + } + _ => None, + } + } + _ => match (self.is_null(left), self.is_null(right)) { + (Some(true), _) | (_, Some(true)) => Some(TriStateBool::Uncertain), + _ => None, + }, + }, + e => match self.is_null(e) { + Some(true) => Some(TriStateBool::Uncertain), + _ => None, + }, + } } - match expr { - // Literal null is obviously null - Expr::Literal(ScalarValue::Null, _) => Some(true), - // We're assuming `value` is null - _ => match expr.nullable(input_schema) { - // If `expr` is not nullable, we can be certain `expr` is not null - Ok(false) => Some(false), - // Otherwise inconclusive - _ => None, + /// 
Determines if the given expression is null. + /// + /// This function returns: + /// - `Some(true)` is `expr` is certainly null + /// - `Some(false)` is `expr` can certainly not be null + /// - `None` if the result is inconclusive + fn is_null(&self, expr: &Expr) -> Option { + match expr { + // Literal null is obviously null + Expr::Literal(ScalarValue::Null, _) => Some(true), + Expr::Negative(e) => self.is_null(e), + Expr::Like(Like { expr, pattern, .. }) => { + match (self.is_null(expr), self.is_null(pattern)) { + (Some(true), _) | (_, Some(true)) => Some(true), + _ => None, + } + } + Expr::SimilarTo(Like { expr, pattern, .. }) => { + match (self.is_null(expr), self.is_null(pattern)) { + (Some(true), _) | (_, Some(true)) => Some(true), + _ => None, + } + } + Expr::Not(e) => self.is_null(e), + Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { + match (self.is_null(left), self.is_null(right)) { + (Some(true), _) | (_, Some(true)) => Some(true), + _ => None, + } + } + Expr::Between(Between { + expr, low, high, .. 
+ }) => match (self.is_null(expr), self.is_null(low), self.is_null(high)) { + (Some(true), _, _) | (_, Some(true), _) | (_, _, Some(true)) => { + Some(true) + } + _ => None, + }, + e => { + if e.eq(self.then_expr) { + // Evaluation occurs under the assumption that `then_expr` evaluates to null + Some(true) + } else { + match expr.nullable(self.input_schema) { + // If `expr` is not nullable, we can be certain `expr` is not null + Ok(false) => Some(false), + // Otherwise inconclusive + _ => None, + } + } + } } } } @@ -933,46 +1110,46 @@ mod tests { .with_nullable(false); // CASE WHEN x IS NOT NULL THEN x ELSE 0 - let e1 = when(col("x").is_not_null(), col("x")).otherwise(lit(0))?; - assert_not_nullable(&e1, &nullable_schema); - assert_not_nullable(&e1, ¬_nullable_schema); + let e = when(col("x").is_not_null(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); // CASE WHEN NOT x IS NULL THEN x ELSE 0 - let e2 = when(not(col("x").is_null()), col("x")).otherwise(lit(0))?; - assert_not_nullable(&e2, &nullable_schema); - assert_not_nullable(&e2, ¬_nullable_schema); + let e = when(not(col("x").is_null()), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); // CASE WHEN X = 5 THEN x ELSE 0 - let e3 = when(col("x").eq(lit(5)), col("x")).otherwise(lit(0))?; - assert_nullable(&e3, &nullable_schema); - assert_not_nullable(&e3, ¬_nullable_schema); + let e = when(col("x").eq(lit(5)), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); // CASE WHEN x IS NOT NULL AND x = 5 THEN x ELSE 0 - let e4 = when(and(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + let e = when(and(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) .otherwise(lit(0))?; - assert_not_nullable(&e4, &nullable_schema); - assert_not_nullable(&e4, ¬_nullable_schema); + assert_not_nullable(&e, 
&nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); // CASE WHEN x = 5 AND x IS NOT NULL THEN x ELSE 0 - let e5 = when(and(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + let e = when(and(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) .otherwise(lit(0))?; - assert_not_nullable(&e5, &nullable_schema); - assert_not_nullable(&e5, ¬_nullable_schema); + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); // CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0 - let e6 = when(or(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + let e = when(or(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) .otherwise(lit(0))?; - assert_nullable(&e6, &nullable_schema); - assert_not_nullable(&e6, ¬_nullable_schema); + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); // CASE WHEN x = 5 OR x IS NOT NULL THEN x ELSE 0 - let e7 = when(or(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + let e = when(or(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) .otherwise(lit(0))?; - assert_nullable(&e7, &nullable_schema); - assert_not_nullable(&e7, ¬_nullable_schema); + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); // CASE WHEN (x = 5 AND x IS NOT NULL) OR (x = bar AND x IS NOT NULL) THEN x ELSE 0 - let e8 = when( + let e = when( or( and(col("x").eq(lit(5)), col("x").is_not_null()), and(col("x").eq(col("bar")), col("x").is_not_null()), @@ -980,14 +1157,49 @@ mod tests { col("x"), ) .otherwise(lit(0))?; - assert_not_nullable(&e8, &nullable_schema); - assert_not_nullable(&e8, ¬_nullable_schema); + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); // CASE WHEN x = 5 OR x IS NULL THEN x ELSE 0 - let e9 = when(or(col("x").eq(lit(5)), col("x").is_null()), col("x")) + let e = when(or(col("x").eq(lit(5)), col("x").is_null()), col("x")) .otherwise(lit(0))?; - assert_nullable(&e9, &nullable_schema); 
- assert_not_nullable(&e9, ¬_nullable_schema); + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS TRUE THEN x ELSE 0 + let e = when(col("x").is_true(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT TRUE THEN x ELSE 0 + let e = when(col("x").is_not_true(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS FALSE THEN x ELSE 0 + let e = when(col("x").is_false(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT FALSE THEN x ELSE 0 + let e = when(col("x").is_not_false(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS UNKNOWN THEN x ELSE 0 + let e = when(col("x").is_unknown(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT UNKNOWN THEN x ELSE 0 + let e = when(col("x").is_not_unknown(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x LIKE 'x' THEN x ELSE 0 + let e = when(col("x").like(lit("x")), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); Ok(()) } From 2075f4b8127ff233d788d7453b997a111ae0e0d9 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Wed, 1 Oct 2025 11:45:36 +0200 Subject: [PATCH 05/23] #17801 Correctly report nullability of implicit casts in predicates --- datafusion/expr/src/expr_schema.rs | 33 ++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 
ec2344f962f2..9eeb2c723be3 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -679,12 +679,31 @@ impl ExprSchemable for Expr { } } +/// Represents the possible values for SQL's three valued logic. +/// `Option` is not used for this since `None` is used to represent +/// inconclusive answers already. enum TriStateBool { True, False, Uncertain, } +impl TryFrom<&ScalarValue> for TriStateBool { + type Error = DataFusionError; + + fn try_from(value: &ScalarValue) -> std::result::Result { + match value { + ScalarValue::Null => Ok(TriStateBool::Uncertain), + ScalarValue::Boolean(b) => Ok(match b { + None => TriStateBool::Uncertain, + Some(true) => TriStateBool::True, + Some(false) => TriStateBool::False, + }), + _ => Self::try_from(&value.cast_to(&DataType::Boolean)?) + } + } +} + struct WhenThenConstEvaluator<'a> { then_expr: &'a Expr, input_schema: &'a dyn ExprSchema, @@ -696,7 +715,7 @@ impl WhenThenConstEvaluator<'_> { fn const_eval_predicate(&self, predicate: &Expr) -> Option { match predicate { // Literal null is equivalent to boolean uncertain - Expr::Literal(ScalarValue::Null, _) => Some(TriStateBool::Uncertain), + Expr::Literal(scalar, _) => TriStateBool::try_from(scalar).ok(), Expr::IsNotNull(e) => { if let Ok(false) = e.nullable(self.input_schema) { // If `e` is not nullable, then `e IS NOT NULL` is always true @@ -845,7 +864,7 @@ impl WhenThenConstEvaluator<'_> { } } - /// Determines if the given expression is null. + /// Determines if the given expression evaluates to null. 
/// /// This function returns: /// - `Some(true)` is `expr` is certainly null @@ -1201,6 +1220,16 @@ mod tests { assert_not_nullable(&e, &nullable_schema); assert_not_nullable(&e, &not_nullable_schema); + // CASE WHEN 0 THEN x ELSE 0 + let e = when(lit(0), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + // CASE WHEN 1 THEN x ELSE 0 + let e = when(lit(1), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, &not_nullable_schema); + Ok(()) } From 8c87937efbe1f47339900eb5462912185fffcc01 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 6 Oct 2025 18:42:29 +0200 Subject: [PATCH 06/23] #17801 Code formatting --- datafusion/expr/src/expr_schema.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9eeb2c723be3..31467dff58e2 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -699,7 +699,7 @@ impl TryFrom<&ScalarValue> for TriStateBool { Some(true) => TriStateBool::True, Some(false) => TriStateBool::False, }), - _ => Self::try_from(&value.cast_to(&DataType::Boolean)?) 
+ _ => Self::try_from(&value.cast_to(&DataType::Boolean)?), } } } From ac4267c7fe8831e750a4e0d7df4c82f3e853daeb Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 9 Oct 2025 08:46:30 +0200 Subject: [PATCH 07/23] Add comment explaining why the logical plan optimizer is triggered Co-authored-by: Andrew Lamb --- datafusion/core/tests/tpcds_planning.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index bee3b48a574b..09efc7cff685 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1055,6 +1055,8 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { if create_physical { let _ = state.create_physical_plan(&plan).await?; } else { + // Run the logical optimizer even if we are not creating the physical plan + // to ensure it will properly succeed let _ = state.optimize(&plan)?; } } From 101db28290f4ea39ff388547fe805843a9242cdb Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 9 Oct 2025 19:43:15 +0200 Subject: [PATCH 08/23] Simplify predicate eval code --- datafusion/expr/src/expr_schema.rs | 254 +------------------------- datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/predicate_eval.rs | 234 ++++++++++++++++++++++++ 3 files changed, 241 insertions(+), 248 deletions(-) create mode 100644 datafusion/expr/src/predicate_eval.rs diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 31467dff58e2..b69d9f50d36d 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-use super::{Between, Expr, Like}; +use super::{predicate_eval, Between, Expr, Like}; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, FieldMetadata, InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; +use crate::predicate_eval::TriStateBool; use crate::type_coercion::functions::{ data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, }; @@ -30,9 +31,8 @@ use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, - Result, ScalarValue, Spans, TableReference, + Result, Spans, TableReference, }; -use datafusion_expr_common::operator::Operator; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::sync::Arc; @@ -294,11 +294,9 @@ impl ExprSchemable for Expr { // using limited const evaluation if the branch will be taken when // the then expression evaluates to null. Ok(true) => { - let const_result = WhenThenConstEvaluator { - then_expr: t, - input_schema, - } - .const_eval_predicate(w); + let const_result = predicate_eval::const_eval_predicate(w, input_schema, |expr| { + if expr.eq(t) { TriStateBool::True } else { TriStateBool::Uncertain } + }); match const_result { // Const evaluation was inconclusive or determined the branch @@ -679,246 +677,6 @@ impl ExprSchemable for Expr { } } -/// Represents the possible values for SQL's three valued logic. -/// `Option` is not used for this since `None` is used to represent -/// inconclusive answers already. 
-enum TriStateBool { - True, - False, - Uncertain, -} - -impl TryFrom<&ScalarValue> for TriStateBool { - type Error = DataFusionError; - - fn try_from(value: &ScalarValue) -> std::result::Result { - match value { - ScalarValue::Null => Ok(TriStateBool::Uncertain), - ScalarValue::Boolean(b) => Ok(match b { - None => TriStateBool::Uncertain, - Some(true) => TriStateBool::True, - Some(false) => TriStateBool::False, - }), - _ => Self::try_from(&value.cast_to(&DataType::Boolean)?), - } - } -} - -struct WhenThenConstEvaluator<'a> { - then_expr: &'a Expr, - input_schema: &'a dyn ExprSchema, -} - -impl WhenThenConstEvaluator<'_> { - /// Attempts to const evaluate the given predicate. - /// Returns a `Some` value containing the const result if so; otherwise returns `None`. - fn const_eval_predicate(&self, predicate: &Expr) -> Option { - match predicate { - // Literal null is equivalent to boolean uncertain - Expr::Literal(scalar, _) => TriStateBool::try_from(scalar).ok(), - Expr::IsNotNull(e) => { - if let Ok(false) = e.nullable(self.input_schema) { - // If `e` is not nullable, then `e IS NOT NULL` is always true - return Some(TriStateBool::True); - } - - match e.get_type(self.input_schema) { - Ok(DataType::Boolean) => match self.const_eval_predicate(e) { - Some(TriStateBool::True) | Some(TriStateBool::False) => { - Some(TriStateBool::True) - } - Some(TriStateBool::Uncertain) => Some(TriStateBool::False), - None => None, - }, - Ok(_) => match self.is_null(e) { - Some(true) => Some(TriStateBool::False), - Some(false) => Some(TriStateBool::True), - None => None, - }, - Err(_) => None, - } - } - Expr::IsNull(e) => { - if let Ok(false) = e.nullable(self.input_schema) { - // If `e` is not nullable, then `e IS NULL` is always false - return Some(TriStateBool::False); - } - - match e.get_type(self.input_schema) { - Ok(DataType::Boolean) => match self.const_eval_predicate(e) { - Some(TriStateBool::True) | Some(TriStateBool::False) => { - Some(TriStateBool::False) - } - 
Some(TriStateBool::Uncertain) => Some(TriStateBool::True), - None => None, - }, - Ok(_) => match self.is_null(e) { - Some(true) => Some(TriStateBool::True), - Some(false) => Some(TriStateBool::False), - None => None, - }, - Err(_) => None, - } - } - Expr::IsTrue(e) => match self.const_eval_predicate(e) { - Some(TriStateBool::True) => Some(TriStateBool::True), - Some(_) => Some(TriStateBool::False), - _ => None, - }, - Expr::IsNotTrue(e) => match self.const_eval_predicate(e) { - Some(TriStateBool::True) => Some(TriStateBool::False), - Some(_) => Some(TriStateBool::True), - _ => None, - }, - Expr::IsFalse(e) => match self.const_eval_predicate(e) { - Some(TriStateBool::False) => Some(TriStateBool::True), - Some(_) => Some(TriStateBool::False), - _ => None, - }, - Expr::IsNotFalse(e) => match self.const_eval_predicate(e) { - Some(TriStateBool::False) => Some(TriStateBool::False), - Some(_) => Some(TriStateBool::True), - _ => None, - }, - Expr::IsUnknown(e) => match self.const_eval_predicate(e) { - Some(TriStateBool::Uncertain) => Some(TriStateBool::True), - Some(_) => Some(TriStateBool::False), - _ => None, - }, - Expr::IsNotUnknown(e) => match self.const_eval_predicate(e) { - Some(TriStateBool::Uncertain) => Some(TriStateBool::False), - Some(_) => Some(TriStateBool::True), - _ => None, - }, - Expr::Like(Like { expr, pattern, .. }) => { - match (self.is_null(expr), self.is_null(pattern)) { - (Some(true), _) | (_, Some(true)) => Some(TriStateBool::Uncertain), - _ => None, - } - } - Expr::SimilarTo(Like { expr, pattern, .. }) => { - match (self.is_null(expr), self.is_null(pattern)) { - (Some(true), _) | (_, Some(true)) => Some(TriStateBool::Uncertain), - _ => None, - } - } - Expr::Between(Between { - expr, low, high, .. 
- }) => match (self.is_null(expr), self.is_null(low), self.is_null(high)) { - (Some(true), _, _) | (_, Some(true), _) | (_, _, Some(true)) => { - Some(TriStateBool::Uncertain) - } - _ => None, - }, - Expr::Not(e) => match self.const_eval_predicate(e) { - Some(TriStateBool::True) => Some(TriStateBool::False), - Some(TriStateBool::False) => Some(TriStateBool::True), - Some(TriStateBool::Uncertain) => Some(TriStateBool::Uncertain), - None => None, - }, - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::And => { - match ( - self.const_eval_predicate(left), - self.const_eval_predicate(right), - ) { - (Some(TriStateBool::False), _) - | (_, Some(TriStateBool::False)) => Some(TriStateBool::False), - (Some(TriStateBool::True), Some(TriStateBool::True)) => { - Some(TriStateBool::True) - } - (Some(TriStateBool::Uncertain), Some(_)) - | (Some(_), Some(TriStateBool::Uncertain)) => { - Some(TriStateBool::Uncertain) - } - _ => None, - } - } - Operator::Or => { - match ( - self.const_eval_predicate(left), - self.const_eval_predicate(right), - ) { - (Some(TriStateBool::True), _) | (_, Some(TriStateBool::True)) => { - Some(TriStateBool::True) - } - (Some(TriStateBool::False), Some(TriStateBool::False)) => { - Some(TriStateBool::False) - } - (Some(TriStateBool::Uncertain), Some(_)) - | (Some(_), Some(TriStateBool::Uncertain)) => { - Some(TriStateBool::Uncertain) - } - _ => None, - } - } - _ => match (self.is_null(left), self.is_null(right)) { - (Some(true), _) | (_, Some(true)) => Some(TriStateBool::Uncertain), - _ => None, - }, - }, - e => match self.is_null(e) { - Some(true) => Some(TriStateBool::Uncertain), - _ => None, - }, - } - } - - /// Determines if the given expression evaluates to null. 
- /// - /// This function returns: - /// - `Some(true)` is `expr` is certainly null - /// - `Some(false)` is `expr` can certainly not be null - /// - `None` if the result is inconclusive - fn is_null(&self, expr: &Expr) -> Option { - match expr { - // Literal null is obviously null - Expr::Literal(ScalarValue::Null, _) => Some(true), - Expr::Negative(e) => self.is_null(e), - Expr::Like(Like { expr, pattern, .. }) => { - match (self.is_null(expr), self.is_null(pattern)) { - (Some(true), _) | (_, Some(true)) => Some(true), - _ => None, - } - } - Expr::SimilarTo(Like { expr, pattern, .. }) => { - match (self.is_null(expr), self.is_null(pattern)) { - (Some(true), _) | (_, Some(true)) => Some(true), - _ => None, - } - } - Expr::Not(e) => self.is_null(e), - Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - match (self.is_null(left), self.is_null(right)) { - (Some(true), _) | (_, Some(true)) => Some(true), - _ => None, - } - } - Expr::Between(Between { - expr, low, high, .. - }) => match (self.is_null(expr), self.is_null(low), self.is_null(high)) { - (Some(true), _, _) | (_, Some(true), _) | (_, _, Some(true)) => { - Some(true) - } - _ => None, - }, - e => { - if e.eq(self.then_expr) { - // Evaluation occurs under the assumption that `then_expr` evaluates to null - Some(true) - } else { - match expr.nullable(self.input_schema) { - // If `expr` is not nullable, we can be certain `expr` is not null - Ok(false) => Some(false), - // Otherwise inconclusive - _ => None, - } - } - } - } - } -} - impl Expr { /// Common method for window functions that applies type coercion /// to all arguments of the window function to check if it matches diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 1c9734a89bd3..71a5d0ad60d1 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -78,6 +78,7 @@ pub mod utils; pub mod var_provider; pub mod window_frame; pub mod window_state; +mod predicate_eval; pub use datafusion_doc::{ 
aggregate_doc_sections, scalar_doc_sections, window_doc_sections, DocSection, diff --git a/datafusion/expr/src/predicate_eval.rs b/datafusion/expr/src/predicate_eval.rs new file mode 100644 index 000000000000..d995e468c3e6 --- /dev/null +++ b/datafusion/expr/src/predicate_eval.rs @@ -0,0 +1,234 @@ +use crate::predicate_eval::TriStateBool::{False, True, Uncertain}; +use crate::{BinaryExpr, Expr, ExprSchemable}; +use arrow::datatypes::DataType; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{DataFusionError, ExprSchema, ScalarValue}; +use datafusion_expr_common::operator::Operator; + +/// Represents the possible values for SQL's three valued logic. +/// `Option` is not used for this since `None` is used by +/// [const_eval_predicate] to represent inconclusive answers. +pub(super) enum TriStateBool { + True, + False, + Uncertain, +} + +impl TryFrom<&ScalarValue> for TriStateBool { + type Error = DataFusionError; + + fn try_from(value: &ScalarValue) -> Result { + match value { + ScalarValue::Null => { + // Literal null is equivalent to boolean uncertain + Ok(Uncertain) + }, + ScalarValue::Boolean(b) => Ok(match b { + Some(true) => True, + Some(false) => False, + None => Uncertain, + }), + _ => Self::try_from(&value.cast_to(&DataType::Boolean)?), + } + } +} + +/// Attempts to partially constant-evaluate a predicate under SQL three-valued logic. +/// +/// Semantics of the return value: +/// - `Some(True)` => predicate is provably true +/// - `Some(False)` => predicate is provably false +/// - `Some(Uncertain)` => predicate is provably unknown (i.e., can only be NULL) +/// - `None` => inconclusive with available static information +/// +/// The evaluation is conservative and only uses: +/// - Expression nullability from `input_schema` +/// - Simple type checks (e.g. whether an expression is Boolean) +/// - Syntactic patterns (IS NULL/IS NOT NULL/IS TRUE/IS FALSE/etc.) 
+/// - Three-valued boolean algebra for AND/OR/NOT +/// +/// It does not evaluate user-defined functions. +pub(super) fn const_eval_predicate( + predicate: &Expr, + input_schema: &dyn ExprSchema, + evaluates_to_null: F, +) -> Option +where + F: Fn(&Expr) -> TriStateBool, +{ + PredicateConstEvaluator { + input_schema, + evaluates_to_null, + } + .const_eval_predicate(predicate) +} + +pub(super) struct PredicateConstEvaluator<'a, F> { + input_schema: &'a dyn ExprSchema, + evaluates_to_null: F, +} + +impl PredicateConstEvaluator<'_, F> +where + F: Fn(&Expr) -> TriStateBool, +{ + fn const_eval_predicate(&self, predicate: &Expr) -> Option { + match predicate { + Expr::Literal(scalar, _) => TriStateBool::try_from(scalar).ok(), + Expr::IsNotNull(e) => { + if let Ok(false) = e.nullable(self.input_schema) { + // If `e` is not nullable, then `e IS NOT NULL` is always true + return Some(True); + } + + match e.get_type(self.input_schema) { + Ok(DataType::Boolean) => match self.const_eval_predicate(e) { + Some(True) | Some(False) => Some(True), + Some(Uncertain) => Some(False), + None => None, + }, + Ok(_) => match self.is_null(e) { + True => Some(False), + False => Some(True), + Uncertain => None, + }, + Err(_) => None, + } + } + Expr::IsNull(e) => { + if let Ok(false) = e.nullable(self.input_schema) { + // If `e` is not nullable, then `e IS NULL` is always false + return Some(False); + } + + match e.get_type(self.input_schema) { + Ok(DataType::Boolean) => match self.const_eval_predicate(e) { + Some(True) | Some(False) => Some(False), + Some(Uncertain) => Some(True), + None => None, + }, + Ok(_) => match self.is_null(e) { + True => Some(True), + False => Some(False), + Uncertain => None, + }, + Err(_) => None, + } + } + Expr::IsTrue(e) => match self.const_eval_predicate(e) { + Some(True) => Some(True), + Some(_) => Some(False), + None => None, + }, + Expr::IsNotTrue(e) => match self.const_eval_predicate(e) { + Some(True) => Some(False), + Some(_) => Some(True), + None => None, 
+ }, + Expr::IsFalse(e) => match self.const_eval_predicate(e) { + Some(False) => Some(True), + Some(_) => Some(False), + None => None, + }, + Expr::IsNotFalse(e) => match self.const_eval_predicate(e) { + Some(False) => Some(False), + Some(_) => Some(True), + None => None, + }, + Expr::IsUnknown(e) => match self.const_eval_predicate(e) { + Some(Uncertain) => Some(True), + Some(_) => Some(False), + None => None, + }, + Expr::IsNotUnknown(e) => match self.const_eval_predicate(e) { + Some(Uncertain) => Some(False), + Some(_) => Some(True), + None => None, + }, + Expr::Not(e) => match self.const_eval_predicate(e) { + Some(True) => Some(False), + Some(False) => Some(True), + Some(Uncertain) => Some(Uncertain), + None => None, + }, + Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right }) => { + match ( + self.const_eval_predicate(left), + self.const_eval_predicate(right), + ) { + (Some(False), _) | (_, Some(False)) => Some(False), + (Some(True), Some(True)) => Some(True), + (Some(Uncertain), Some(_)) | (Some(_), Some(Uncertain)) => { + Some(Uncertain) + } + _ => None, + } + }, + Expr::BinaryExpr(BinaryExpr { left, op: Operator::Or, right }) => { + match ( + self.const_eval_predicate(left), + self.const_eval_predicate(right), + ) { + (Some(True), _) | (_, Some(True)) => Some(True), + (Some(False), Some(False)) => Some(False), + (Some(Uncertain), Some(_)) | (Some(_), Some(Uncertain)) => { + Some(Uncertain) + } + _ => None, + } + }, + e => match self.is_null(e) { + True => Some(Uncertain), + _ => None, + }, + } + } + + /// Determines if the given expression evaluates to `NULL`. 
+ /// + /// This function returns: + /// - `True` if `expr` is provably `NULL` + /// - `False` if `expr` can provably not `NULL` + /// - `Uncertain` if the result is inconclusive + fn is_null(&self, expr: &Expr) -> TriStateBool { + match expr { + Expr::Literal(ScalarValue::Null, _) => { + // Literal null is obviously null + True + } + Expr::Alias(_) + | Expr::Between(_) + | Expr::BinaryExpr(_) + | Expr::Cast(_) + | Expr::Like(_) + | Expr::Negative(_) + | Expr::Not(_) + | Expr::SimilarTo(_) => { + // These expressions are null if any of their direct children is null + // If any child is inconclusive, the result for this expression is also inconclusive + let mut is_null = False; + let _ = expr.apply_children(|child| match self.is_null(child) { + True => { + is_null = True; + Ok(TreeNodeRecursion::Stop) + } + False => Ok(TreeNodeRecursion::Continue), + Uncertain => { + is_null = Uncertain; + Ok(TreeNodeRecursion::Stop) + } + }); + is_null + } + e => { + if let Ok(false) = e.nullable(self.input_schema) { + // If `expr` is not nullable, we can be certain `expr` is not null + False + } else { + // Finally, ask the callback if it knows the nullability of `expr` + (self.evaluates_to_null)(e) + } + } + } + } +} From f4c857913b5b6347a82967cbf796e1152304ef98 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 9 Oct 2025 19:44:56 +0200 Subject: [PATCH 09/23] Code formatting --- datafusion/core/tests/tpcds_planning.rs | 2 +- datafusion/expr/src/expr_schema.rs | 14 +++++++++++--- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/predicate_eval.rs | 18 +++++++++++++----- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 09efc7cff685..3ad74962bc2c 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1055,7 +1055,7 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { if create_physical { 
let _ = state.create_physical_plan(&plan).await?; } else { - // Run the logical optimizer even if we are not creating the physical plan + // Run the logical optimizer even if we are not creating the physical plan // to ensure it will properly succeed let _ = state.optimize(&plan)?; } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index b69d9f50d36d..aab01ad400af 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -294,9 +294,17 @@ impl ExprSchemable for Expr { // using limited const evaluation if the branch will be taken when // the then expression evaluates to null. Ok(true) => { - let const_result = predicate_eval::const_eval_predicate(w, input_schema, |expr| { - if expr.eq(t) { TriStateBool::True } else { TriStateBool::Uncertain } - }); + let const_result = predicate_eval::const_eval_predicate( + w, + input_schema, + |expr| { + if expr.eq(t) { + TriStateBool::True + } else { + TriStateBool::Uncertain + } + }, + ); match const_result { // Const evaluation was inconclusive or determined the branch diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 71a5d0ad60d1..121929f4a13c 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -69,6 +69,7 @@ pub mod async_udf; pub mod statistics { pub use datafusion_expr_common::statistics::*; } +mod predicate_eval; pub mod ptr_eq; pub mod test; pub mod tree_node; @@ -78,7 +79,6 @@ pub mod utils; pub mod var_provider; pub mod window_frame; pub mod window_state; -mod predicate_eval; pub use datafusion_doc::{ aggregate_doc_sections, scalar_doc_sections, window_doc_sections, DocSection, diff --git a/datafusion/expr/src/predicate_eval.rs b/datafusion/expr/src/predicate_eval.rs index d995e468c3e6..d8ce4b97e1eb 100644 --- a/datafusion/expr/src/predicate_eval.rs +++ b/datafusion/expr/src/predicate_eval.rs @@ -22,7 +22,7 @@ impl TryFrom<&ScalarValue> for TriStateBool { ScalarValue::Null => { // Literal null is 
equivalent to boolean uncertain Ok(Uncertain) - }, + } ScalarValue::Boolean(b) => Ok(match b { Some(true) => True, Some(false) => False, @@ -151,7 +151,11 @@ where Some(Uncertain) => Some(Uncertain), None => None, }, - Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right }) => { + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) => { match ( self.const_eval_predicate(left), self.const_eval_predicate(right), @@ -163,8 +167,12 @@ where } _ => None, } - }, - Expr::BinaryExpr(BinaryExpr { left, op: Operator::Or, right }) => { + } + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Or, + right, + }) => { match ( self.const_eval_predicate(left), self.const_eval_predicate(right), @@ -176,7 +184,7 @@ where } _ => None, } - }, + } e => match self.is_null(e) { True => Some(Uncertain), _ => None, From 81b6ec17d650426883871a7aebda18013a64f289 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 9 Oct 2025 19:50:02 +0200 Subject: [PATCH 10/23] Add license header --- datafusion/expr/src/predicate_eval.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/expr/src/predicate_eval.rs b/datafusion/expr/src/predicate_eval.rs index d8ce4b97e1eb..238e0d711a55 100644 --- a/datafusion/expr/src/predicate_eval.rs +++ b/datafusion/expr/src/predicate_eval.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use crate::predicate_eval::TriStateBool::{False, True, Uncertain}; use crate::{BinaryExpr, Expr, ExprSchemable}; use arrow::datatypes::DataType; From 313189908879be71c58386135e027ea052cf177a Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 6 Nov 2025 18:37:58 +0100 Subject: [PATCH 11/23] Try to align logical and physical implementations as much as possible --- datafusion/expr/src/expr_schema.rs | 37 ++- datafusion/expr/src/predicate_eval.rs | 219 +++++++++++------- .../physical-expr/src/expressions/case.rs | 150 ++++++------ 3 files changed, 227 insertions(+), 179 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index bada6db4ebb9..5bc28e267b1c 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,7 +21,6 @@ use crate::expr::{ InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; -use crate::predicate_eval::TriStateBool; use crate::type_coercion::functions::{ data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, }; @@ -291,33 +290,31 @@ impl ExprSchemable for Expr { match t.nullable(input_schema) { // Branches with a then expression that is not nullable can be skipped Ok(false) => None, - // Pass error determining nullability on verbatim + // Pass on error determining nullability verbatim Err(e) => Some(Err(e)), - // For branches with a nullable then expressions try to determine + // For branches with a nullable 'then' expression, try to determine // using limited const 
evaluation if the branch will be taken when - // the then expression evaluates to null. + // the 'then' expression evaluates to null. Ok(true) => { - let const_result = predicate_eval::const_eval_predicate( + let is_null = |expr: &Expr /* Type */| { + if expr.eq(t) { + Some(true) + } else { + None + } + }; + + match predicate_eval::const_eval_predicate( w, + is_null, input_schema, - |expr| { - if expr.eq(t) { - TriStateBool::True - } else { - TriStateBool::Uncertain - } - }, - ); - - match const_result { + ) { // Const evaluation was inconclusive or determined the branch // would be taken - None | Some(TriStateBool::True) => Some(Ok(())), + None | Some(true) => Some(Ok(())), // Const evaluation proves the branch will never be taken. - // The most common pattern for this is - // `WHEN x IS NOT NULL THEN x`. - Some(TriStateBool::False) - | Some(TriStateBool::Uncertain) => None, + // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. + Some(false) => None, } } } diff --git a/datafusion/expr/src/predicate_eval.rs b/datafusion/expr/src/predicate_eval.rs index 238e0d711a55..3718706de877 100644 --- a/datafusion/expr/src/predicate_eval.rs +++ b/datafusion/expr/src/predicate_eval.rs @@ -21,16 +21,25 @@ use arrow::datatypes::DataType; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, ExprSchema, ScalarValue}; use datafusion_expr_common::operator::Operator; +use std::ops::{BitAnd, BitOr, Not}; /// Represents the possible values for SQL's three valued logic. -/// `Option` is not used for this since `None` is used by -/// [const_eval_predicate] to represent inconclusive answers. 
-pub(super) enum TriStateBool { +enum TriStateBool { True, False, Uncertain, } +impl From> for TriStateBool { + fn from(value: Option) -> Self { + match value { + None => Uncertain, + Some(true) => True, + Some(false) => False, + } + } +} + impl TryFrom<&ScalarValue> for TriStateBool { type Error = DataFusionError; @@ -50,12 +59,77 @@ impl TryFrom<&ScalarValue> for TriStateBool { } } +impl TriStateBool { + fn is_null(&self) -> TriStateBool { + match self { + True | False => False, + Uncertain => True, + } + } + + fn is_true(&self) -> TriStateBool { + match self { + True => True, + Uncertain | False => False, + } + } + + fn is_false(&self) -> TriStateBool { + match self { + False => True, + Uncertain | True => False, + } + } + + fn is_unknown(&self) -> TriStateBool { + match self { + Uncertain => True, + True | False => False, + } + } +} + +impl Not for TriStateBool { + type Output = TriStateBool; + + fn not(self) -> Self::Output { + match self { + True => False, + False => True, + Uncertain => Uncertain, + } + } +} + +impl BitAnd for TriStateBool { + type Output = TriStateBool; + + fn bitand(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (False, _) | (_, False) => False, + (Uncertain, _) | (_, Uncertain) => Uncertain, + (True, True) => True, + } + } +} + +impl BitOr for TriStateBool { + type Output = TriStateBool; + + fn bitor(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (True, _) | (_, True) => True, + (Uncertain, _) | (_, Uncertain) => Uncertain, + (False, False) => False, + } + } +} + /// Attempts to partially constant-evaluate a predicate under SQL three-valued logic. 
/// /// Semantics of the return value: -/// - `Some(True)` => predicate is provably true -/// - `Some(False)` => predicate is provably false -/// - `Some(Uncertain)` => predicate is provably unknown (i.e., can only be NULL) +/// - `Some(true)` => predicate is provably true +/// - `Some(false)` => predicate is provably false /// - `None` => inconclusive with available static information /// /// The evaluation is conservative and only uses: @@ -67,17 +141,18 @@ impl TryFrom<&ScalarValue> for TriStateBool { /// It does not evaluate user-defined functions. pub(super) fn const_eval_predicate( predicate: &Expr, - input_schema: &dyn ExprSchema, evaluates_to_null: F, -) -> Option + input_schema: &dyn ExprSchema, +) -> Option where - F: Fn(&Expr) -> TriStateBool, + F: Fn(&Expr) -> Option, { PredicateConstEvaluator { input_schema, evaluates_to_null, } .const_eval_predicate(predicate) + .map(|b| matches!(b, True)) } pub(super) struct PredicateConstEvaluator<'a, F> { @@ -87,24 +162,22 @@ pub(super) struct PredicateConstEvaluator<'a, F> { impl PredicateConstEvaluator<'_, F> where - F: Fn(&Expr) -> TriStateBool, + F: Fn(&Expr) -> Option, { fn const_eval_predicate(&self, predicate: &Expr) -> Option { match predicate { Expr::Literal(scalar, _) => TriStateBool::try_from(scalar).ok(), Expr::IsNotNull(e) => { if let Ok(false) = e.nullable(self.input_schema) { - // If `e` is not nullable, then `e IS NOT NULL` is always true + // If `e` is not nullable -> `e IS NOT NULL` is true return Some(True); } match e.get_type(self.input_schema) { - Ok(DataType::Boolean) => match self.const_eval_predicate(e) { - Some(True) | Some(False) => Some(True), - Some(Uncertain) => Some(False), - None => None, - }, - Ok(_) => match self.is_null(e) { + Ok(DataType::Boolean) => { + self.const_eval_predicate(e).map(|b| b.is_null()) + } + Ok(_) => match self.evaluates_to_null(e) { True => Some(False), False => Some(True), Uncertain => None, @@ -114,17 +187,15 @@ where } Expr::IsNull(e) => { if let Ok(false) = 
e.nullable(self.input_schema) { - // If `e` is not nullable, then `e IS NULL` is always false + // If `e` is not nullable -> `e IS NULL` is false return Some(False); } match e.get_type(self.input_schema) { - Ok(DataType::Boolean) => match self.const_eval_predicate(e) { - Some(True) | Some(False) => Some(False), - Some(Uncertain) => Some(True), - None => None, - }, - Ok(_) => match self.is_null(e) { + Ok(DataType::Boolean) => { + self.const_eval_predicate(e).map(|b| !b.is_null()) + } + Ok(_) => match self.evaluates_to_null(e) { True => Some(True), False => Some(False), Uncertain => None, @@ -132,42 +203,15 @@ where Err(_) => None, } } - Expr::IsTrue(e) => match self.const_eval_predicate(e) { - Some(True) => Some(True), - Some(_) => Some(False), - None => None, - }, - Expr::IsNotTrue(e) => match self.const_eval_predicate(e) { - Some(True) => Some(False), - Some(_) => Some(True), - None => None, - }, - Expr::IsFalse(e) => match self.const_eval_predicate(e) { - Some(False) => Some(True), - Some(_) => Some(False), - None => None, - }, - Expr::IsNotFalse(e) => match self.const_eval_predicate(e) { - Some(False) => Some(False), - Some(_) => Some(True), - None => None, - }, - Expr::IsUnknown(e) => match self.const_eval_predicate(e) { - Some(Uncertain) => Some(True), - Some(_) => Some(False), - None => None, - }, - Expr::IsNotUnknown(e) => match self.const_eval_predicate(e) { - Some(Uncertain) => Some(False), - Some(_) => Some(True), - None => None, - }, - Expr::Not(e) => match self.const_eval_predicate(e) { - Some(True) => Some(False), - Some(False) => Some(True), - Some(Uncertain) => Some(Uncertain), - None => None, - }, + Expr::IsTrue(e) => self.const_eval_predicate(e).map(|b| b.is_true()), + Expr::IsNotTrue(e) => self.const_eval_predicate(e).map(|b| !b.is_true()), + Expr::IsFalse(e) => self.const_eval_predicate(e).map(|b| b.is_false()), + Expr::IsNotFalse(e) => self.const_eval_predicate(e).map(|b| !b.is_false()), + Expr::IsUnknown(e) => 
self.const_eval_predicate(e).map(|b| b.is_unknown()), + Expr::IsNotUnknown(e) => { + self.const_eval_predicate(e).map(|b| !b.is_unknown()) + } + Expr::Not(e) => self.const_eval_predicate(e).map(|b| !b), Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, @@ -178,11 +222,8 @@ where self.const_eval_predicate(right), ) { (Some(False), _) | (_, Some(False)) => Some(False), - (Some(True), Some(True)) => Some(True), - (Some(Uncertain), Some(_)) | (Some(_), Some(Uncertain)) => { - Some(Uncertain) - } - _ => None, + (None, _) | (_, None) => None, + (Some(l), Some(r)) => Some(l & r), } } Expr::BinaryExpr(BinaryExpr { @@ -195,14 +236,11 @@ where self.const_eval_predicate(right), ) { (Some(True), _) | (_, Some(True)) => Some(True), - (Some(False), Some(False)) => Some(False), - (Some(Uncertain), Some(_)) | (Some(_), Some(Uncertain)) => { - Some(Uncertain) - } - _ => None, + (None, _) | (_, None) => None, + (Some(l), Some(r)) => Some(l | r), } } - e => match self.is_null(e) { + e => match self.evaluates_to_null(e) { True => Some(Uncertain), _ => None, }, @@ -213,13 +251,17 @@ where /// /// This function returns: /// - `True` if `expr` is provably `NULL` - /// - `False` if `expr` can provably not `NULL` + /// - `False` if `expr` is provably not `NULL` /// - `Uncertain` if the result is inconclusive - fn is_null(&self, expr: &Expr) -> TriStateBool { + fn evaluates_to_null(&self, expr: &Expr) -> TriStateBool { match expr { - Expr::Literal(ScalarValue::Null, _) => { + Expr::Literal(s, _) => { // Literal null is obviously null - True + if s.is_null() { + True + } else { + False + } } Expr::Alias(_) | Expr::Between(_) @@ -232,17 +274,18 @@ where // These expressions are null if any of their direct children is null // If any child is inconclusive, the result for this expression is also inconclusive let mut is_null = False; - let _ = expr.apply_children(|child| match self.is_null(child) { - True => { - is_null = True; - Ok(TreeNodeRecursion::Stop) - } - False => 
Ok(TreeNodeRecursion::Continue), - Uncertain => { - is_null = Uncertain; - Ok(TreeNodeRecursion::Stop) - } - }); + let _ = + expr.apply_children(|child| match self.evaluates_to_null(child) { + True => { + is_null = True; + Ok(TreeNodeRecursion::Stop) + } + False => Ok(TreeNodeRecursion::Continue), + Uncertain => { + is_null = Uncertain; + Ok(TreeNodeRecursion::Stop) + } + }); is_null } e => { @@ -250,8 +293,8 @@ where // If `expr` is not nullable, we can be certain `expr` is not null False } else { - // Finally, ask the callback if it knows the nullability of `expr` - (self.evaluates_to_null)(e) + // Finally, ask the callback if it knows the nullness of `expr` + (self.evaluates_to_null)(e).into() } } } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 095d659e8ecf..9ef5311a073e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -17,8 +17,7 @@ use super::{Column, Literal}; use crate::expressions::case::ResultState::{Complete, Empty, Partial}; -use crate::expressions::try_cast; -use crate::expressions::{try_cast, BinaryExpr, IsNotNullExpr, IsNullExpr, NotExpr}; +use crate::expressions::{lit, try_cast}; use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; @@ -39,14 +38,9 @@ use std::borrow::Cow; use std::hash::Hash; use std::{any::Any, sync::Arc}; -use super::{Column, Literal}; -use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; -use std::borrow::Cow; use std::fmt::{Debug, Formatter}; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; type WhenThen = (Arc, Arc); @@ -1290,18 +1284,39 @@ impl PhysicalExpr for CaseExpr { fn nullable(&self, input_schema: &Schema) -> Result { // this expression is nullable if any of the input expressions are nullable - let then_nullable = self + let any_nullable_thens = !self .body 
.when_then_expr .iter() - .filter(|(w, t)| { - // Disregard branches where we can determine statically that the predicate - // is always false when the then expression would evaluate to null - const_result_when_value_is_null(w.as_ref(), t.as_ref()).unwrap_or(true) + .filter_map(|(w, t)| { + match t.nullable(input_schema) { + // Branches with a then expression that is not nullable can be skipped + Ok(false) => None, + // Pass on error determining nullability verbatim + Err(e) => Some(Err(e)), + // For branches with a nullable 'then' expression, try to determine + // using const evaluation if the branch will be taken when + // the 'then' expression evaluates to null. + Ok(true) => { + let is_null = + |expr: &dyn PhysicalExpr /* Type */| expr.dyn_eq(t.as_ref()); + + match const_eval_predicate(w, is_null, input_schema) { + // Const evaluation was inconclusive or determined the branch + // would be taken + Ok(None) | Ok(Some(true)) => Some(Ok(())), + // Const evaluation proves the branch will never be taken. + // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. + Ok(Some(false)) => None, + Err(e) => Some(Err(e)), + } + } + } }) - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { + .collect::>>()? + .is_empty(); + + if any_nullable_thens { Ok(true) } else if let Some(e) = &self.body.else_expr { e.nullable(input_schema) @@ -1406,52 +1421,50 @@ impl PhysicalExpr for CaseExpr { } } -/// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`. -/// Returns a `Some` value containing the const result if so; otherwise returns `None`. 
-fn const_result_when_value_is_null( - predicate: &dyn PhysicalExpr, - value: &dyn PhysicalExpr, -) -> Option { - let predicate_any = predicate.as_any(); - if let Some(not_null) = predicate_any.downcast_ref::() { - if not_null.arg().as_ref().dyn_eq(value) { - Some(false) - } else { - None - } - } else if let Some(null) = predicate_any.downcast_ref::() { - if null.arg().as_ref().dyn_eq(value) { - Some(true) - } else { - None - } - } else if let Some(not) = predicate_any.downcast_ref::() { - const_result_when_value_is_null(not.arg().as_ref(), value).map(|b| !b) - } else if let Some(binary) = predicate_any.downcast_ref::() { - match binary.op() { - Operator::And => { - let l = const_result_when_value_is_null(binary.left().as_ref(), value); - let r = const_result_when_value_is_null(binary.right().as_ref(), value); - match (l, r) { - (Some(l), Some(r)) => Some(l && r), - (Some(l), None) => Some(l), - (None, Some(r)) => Some(r), - _ => None, - } - } - Operator::Or => { - let l = const_result_when_value_is_null(binary.left().as_ref(), value); - let r = const_result_when_value_is_null(binary.right().as_ref(), value); - match (l, r) { - (Some(l), Some(r)) => Some(l || r), - _ => None, - } +/// Attempts to const evaluate the given `predicate` with the assumption that `value` evaluates to `NULL`. +/// Returns: +/// - `Some(true)` if the predicate evaluates to a truthy value. +/// - `Some(false)` if the predicate evaluates to a falsy value. +/// - `None` if the predicate could not be evaluated. 
+fn const_eval_predicate( + predicate: &Arc, + evaluates_to_null: F, + input_schema: &Schema, +) -> Result> +where + F: Fn(&dyn PhysicalExpr) -> bool, +{ + // Replace `value` with `NULL` in `predicate` + let with_null = Arc::clone(predicate) + .transform_down(|e| { + if evaluates_to_null(e.as_ref()) { + let data_type = e.data_type(input_schema)?; + let null_literal = lit(ScalarValue::try_new_null(&data_type)?); + Ok(Transformed::yes(null_literal)) + } else { + Ok(Transformed::no(e)) } - _ => None, - } - } else { - None - } + })? + .data; + + // Create a dummy record with no columns and one row + let batch = RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(1)), + )?; + + // Evaluate the predicate and interpret the result as a boolean + let result = match with_null.evaluate(&batch) { + // An error during evaluation means we couldn't const evaluate the predicate, so return `None` + Err(_) => None, + Ok(ColumnarValue::Array(array)) => Some( + ScalarValue::try_from_array(array.as_ref(), 0)? 
+ .cast_to(&DataType::Boolean)?, + ), + Ok(ColumnarValue::Scalar(scalar)) => Some(scalar.cast_to(&DataType::Boolean)?), + }; + Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true))))) } /// Create a CASE expression @@ -1476,6 +1489,7 @@ mod tests { use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::type_coercion::binary::comparison_coercion; + use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::physical_expr::fmt_sql; #[test] @@ -2390,11 +2404,7 @@ mod tests { assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema); assert_not_nullable(when_then_else(¬_foo_is_null, &foo, &zero)?, &schema); - assert_nullability( - when_then_else(&foo_eq_zero, &foo, &zero)?, - &schema, - col_is_nullable, - ); + assert_not_nullable(when_then_else(&foo_eq_zero, &foo, &zero)?, &schema); assert_not_nullable( when_then_else( @@ -2424,7 +2434,7 @@ mod tests { &schema, ); - assert_nullability( + assert_not_nullable( when_then_else( &binary( Arc::clone(&foo_is_not_null), @@ -2436,10 +2446,9 @@ mod tests { &zero, )?, &schema, - col_is_nullable, ); - assert_nullability( + assert_not_nullable( when_then_else( &binary( Arc::clone(&foo_eq_zero), @@ -2451,13 +2460,12 @@ mod tests { &zero, )?, &schema, - col_is_nullable, ); assert_nullability( when_then_else( &binary( - Arc::clone(&foo_is_not_null), + Arc::clone(&foo_is_null), Operator::Or, Arc::clone(&foo_eq_zero), &schema, @@ -2474,7 +2482,7 @@ mod tests { &binary( binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?, Operator::Or, - Arc::clone(&foo_is_not_null), + Arc::clone(&foo_is_null), &schema, )?, &foo, From 3da92e544ab2ecc230cb4339d6bbda3e44fe38b1 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Thu, 6 Nov 2025 21:51:40 +0100 Subject: [PATCH 12/23] Allow optimizations to change fields from nullable to not-nullable --- datafusion/core/src/physical_planner.rs | 32 
++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c280b50a9f07..b4b9c479a1c9 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -60,6 +60,7 @@ use crate::schema_equivalence::schema_satisfied_by; use arrow::array::{builder::StringBuilder, RecordBatch}; use arrow::compute::SortOptions; use arrow::datatypes::Schema; +use arrow_schema::Field; use datafusion_catalog::ScanArgs; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::format::ExplainAnalyzeLevel; @@ -2516,7 +2517,9 @@ impl<'a> OptimizationInvariantChecker<'a> { previous_schema: Arc, ) -> Result<()> { // if the rule is not permitted to change the schema, confirm that it did not change. - if self.rule.schema_check() && plan.schema() != previous_schema { + if self.rule.schema_check() + && !is_allowed_schema_change(previous_schema.as_ref(), plan.schema().as_ref()) + { internal_err!("PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {:?}, got new schema: {:?}", self.rule.name(), previous_schema, @@ -2532,6 +2535,33 @@ impl<'a> OptimizationInvariantChecker<'a> { } } +/// Checks if the change from `old` schema to `new` is allowed or not. +/// The current implementation only allows nullability of individual fields to change +/// from 'nullable' to 'not nullable'. 
+fn is_allowed_schema_change(old: &Schema, new: &Schema) -> bool { + if new.metadata != old.metadata { + return false; + } + + if new.fields.len() != old.fields.len() { + return false; + } + + let new_fields = new.fields.iter().map(|f| f.as_ref()); + let old_fields = old.fields.iter().map(|f| f.as_ref()); + old_fields + .zip(new_fields) + .all(|(old, new)| is_allowed_field_change(old, new)) +} + +fn is_allowed_field_change(old_field: &Field, new_field: &Field) -> bool { + new_field.name() == old_field.name() + && new_field.data_type() == old_field.data_type() + && new_field.metadata() == old_field.metadata() + && (new_field.is_nullable() == old_field.is_nullable() + || !new_field.is_nullable()) +} + impl<'n> TreeNodeVisitor<'n> for OptimizationInvariantChecker<'_> { type Node = Arc; From 0a6b2e777b8fc11c107044b246b55ad50d2730f6 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Fri, 7 Nov 2025 09:07:11 +0100 Subject: [PATCH 13/23] Correctly handle case-with-expression nullability analysis --- datafusion/expr/src/expr_schema.rs | 91 +++++++++++-------- .../physical-expr/src/expressions/case.rs | 82 ++++++++++------- datafusion/sqllogictest/test_files/case.slt | 6 ++ 3 files changed, 111 insertions(+), 68 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 5bc28e267b1c..f520dd07adbd 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -282,49 +282,66 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()), Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { - // This expression is nullable if any of the then expressions are nullable - let any_nullable_thens = !case - .when_then_expr - .iter() - .filter_map(|(w, t)| { - match t.nullable(input_schema) { - // Branches with a then expression that is not nullable can be skipped - Ok(false) => None, - // Pass on error determining nullability verbatim + let 
nullable_then = if case.expr.is_some() { + // Case-with-expression is nullable if any of the 'then' expressions. + // Assume all 'then' expressions are reachable + case.when_then_expr + .iter() + .filter_map(|(_, t)| match t.nullable(input_schema) { + Ok(n) => { + if n { + Some(Ok(())) + } else { + None + } + } Err(e) => Some(Err(e)), - // For branches with a nullable 'then' expression, try to determine - // using limited const evaluation if the branch will be taken when - // the 'then' expression evaluates to null. - Ok(true) => { - let is_null = |expr: &Expr /* Type */| { - if expr.eq(t) { - Some(true) - } else { - None + }) + .next() + } else { + // case-without-expression is nullable if any of the 'then' expressions is nullable + // and reachable when the 'then' expression evaluates to `null`. + case.when_then_expr + .iter() + .filter_map(|(w, t)| { + match t.nullable(input_schema) { + // Branches with a then expression that is not nullable can be skipped + Ok(false) => None, + // Pass on error determining nullability verbatim + Err(e) => Some(Err(e)), + // For branches with a nullable 'then' expression, try to determine + // using limited const evaluation if the branch will be taken when + // the 'then' expression evaluates to null. + Ok(true) => { + let is_null = |expr: &Expr /* Type */| { + if expr.eq(t) { + Some(true) + } else { + None + } + }; + + match predicate_eval::const_eval_predicate( + w, + is_null, + input_schema, + ) { + // Const evaluation was inconclusive or determined the branch + // would be taken + None | Some(true) => Some(Ok(())), + // Const evaluation proves the branch will never be taken. + // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. 
+ Some(false) => None, } - }; - - match predicate_eval::const_eval_predicate( - w, - is_null, - input_schema, - ) { - // Const evaluation was inconclusive or determined the branch - // would be taken - None | Some(true) => Some(Ok(())), - // Const evaluation proves the branch will never be taken. - // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. - Some(false) => None, } } - } - }) - .collect::>>()? - .is_empty(); + }) + .next() + }; - if any_nullable_thens { + if let Some(nullable_then) = nullable_then { // There is at least one reachable nullable then - Ok(true) + nullable_then.map(|_| true) } else if let Some(e) = &case.else_expr { e.nullable(input_schema) } else { diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 9ef5311a073e..845242e92102 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1283,41 +1283,61 @@ impl PhysicalExpr for CaseExpr { } fn nullable(&self, input_schema: &Schema) -> Result { - // this expression is nullable if any of the input expressions are nullable - let any_nullable_thens = !self - .body - .when_then_expr - .iter() - .filter_map(|(w, t)| { - match t.nullable(input_schema) { - // Branches with a then expression that is not nullable can be skipped - Ok(false) => None, - // Pass on error determining nullability verbatim + let nullable_then = if self.body.expr.is_some() { + // Case-with-expression is nullable if any of the 'then' expressions. + // Assume all 'then' expressions are reachable + self.body + .when_then_expr + .iter() + .filter_map(|(_, t)| match t.nullable(input_schema) { + Ok(n) => { + if n { + Some(Ok(())) + } else { + None + } + } Err(e) => Some(Err(e)), - // For branches with a nullable 'then' expression, try to determine - // using const evaluation if the branch will be taken when - // the 'then' expression evaluates to null. 
- Ok(true) => { - let is_null = - |expr: &dyn PhysicalExpr /* Type */| expr.dyn_eq(t.as_ref()); - - match const_eval_predicate(w, is_null, input_schema) { - // Const evaluation was inconclusive or determined the branch - // would be taken - Ok(None) | Ok(Some(true)) => Some(Ok(())), - // Const evaluation proves the branch will never be taken. - // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. - Ok(Some(false)) => None, - Err(e) => Some(Err(e)), + }) + .next() + } else { + // case-without-expression is nullable if any of the 'then' expressions is nullable + // and reachable when the 'then' expression evaluates to `null`. + self.body + .when_then_expr + .iter() + .filter_map(|(w, t)| { + match t.nullable(input_schema) { + // Branches with a then expression that is not nullable can be skipped + Ok(false) => None, + // Pass on error determining nullability verbatim + Err(e) => Some(Err(e)), + Ok(true) => { + // For branches with a nullable 'then' expression, try to determine + // using const evaluation if the branch will be taken when + // the 'then' expression evaluates to null. + let is_null = |expr: &dyn PhysicalExpr /* Type */| { + expr.dyn_eq(t.as_ref()) + }; + + match const_eval_predicate(w, is_null, input_schema) { + // Const evaluation was inconclusive or determined the branch + // would be taken + Ok(None) | Ok(Some(true)) => Some(Ok(())), + // Const evaluation proves the branch will never be taken. + // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. + Ok(Some(false)) => None, + Err(e) => Some(Err(e)), + } } } - } - }) - .collect::>>()? 
- .is_empty(); + }) + .next() + }; - if any_nullable_thens { - Ok(true) + if let Some(nullable_then) = nullable_then { + // There is at least one reachable nullable then + nullable_then.map(|_| true) } else if let Some(e) = &self.body.else_expr { e.nullable(input_schema) } else { diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 1a4b6a7a2b4a..2bd644d8a8ac 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -683,3 +683,9 @@ FROM ( 10 10 100 -20 20 200 NULL 30 300 + +# Case-with-expression that was incorrectly classified as not-nullable, but evaluates to null +query I +SELECT CASE 0 WHEN 0 THEN NULL WHEN SUM(1) + COUNT(*) THEN 10 ELSE 20 END +---- +NULL From 113e899fa2de08b550122a1bc5e74c2973198094 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Fri, 7 Nov 2025 12:44:39 +0100 Subject: [PATCH 14/23] Add unit tests for predicate_eval --- datafusion/expr/src/predicate_eval.rs | 435 +++++++++++++++++++++++++- 1 file changed, 430 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/predicate_eval.rs b/datafusion/expr/src/predicate_eval.rs index 3718706de877..4e3cbafcc2e4 100644 --- a/datafusion/expr/src/predicate_eval.rs +++ b/datafusion/expr/src/predicate_eval.rs @@ -24,14 +24,15 @@ use datafusion_expr_common::operator::Operator; use std::ops::{BitAnd, BitOr, Not}; /// Represents the possible values for SQL's three valued logic. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum TriStateBool { True, False, Uncertain, } -impl From> for TriStateBool { - fn from(value: Option) -> Self { +impl From<&Option> for TriStateBool { + fn from(value: &Option) -> Self { match value { None => Uncertain, Some(true) => True, @@ -60,6 +61,14 @@ impl TryFrom<&ScalarValue> for TriStateBool { } impl TriStateBool { + fn try_from_no_cooerce(value: &ScalarValue) -> Option { + match value { + ScalarValue::Null => Some(Uncertain), + ScalarValue::Boolean(b) => Some(TriStateBool::from(b)), + _ => None, + } + } + fn is_null(&self) -> TriStateBool { match self { True | False => False, @@ -151,7 +160,7 @@ where input_schema, evaluates_to_null, } - .const_eval_predicate(predicate) + .const_eval_predicate_coerced(predicate) .map(|b| matches!(b, True)) } @@ -164,9 +173,16 @@ impl PredicateConstEvaluator<'_, F> where F: Fn(&Expr) -> Option, { - fn const_eval_predicate(&self, predicate: &Expr) -> Option { + fn const_eval_predicate_coerced(&self, predicate: &Expr) -> Option { match predicate { Expr::Literal(scalar, _) => TriStateBool::try_from(scalar).ok(), + e => self.const_eval_predicate(e), + } + } + + fn const_eval_predicate(&self, predicate: &Expr) -> Option { + match predicate { + Expr::Literal(scalar, _) => TriStateBool::try_from_no_cooerce(scalar), Expr::IsNotNull(e) => { if let Ok(false) = e.nullable(self.input_schema) { // If `e` is not nullable -> `e IS NOT NULL` is true @@ -294,9 +310,418 @@ where False } else { // Finally, ask the callback if it knows the nullness of `expr` - (self.evaluates_to_null)(e).into() + let evaluates_to_null = (self.evaluates_to_null)(e); + TriStateBool::from(&evaluates_to_null) } } } } } + +#[cfg(test)] +mod tests { + use crate::expr::ScalarFunction; + use crate::predicate_eval::TriStateBool::*; + use crate::predicate_eval::{const_eval_predicate, TriStateBool}; + use crate::{ + binary_expr, create_udf, is_false, is_not_false, is_not_null, is_not_true, + is_not_unknown, is_null, 
is_true, is_unknown, lit, not, Expr, + }; + use arrow::datatypes::{DataType, Schema}; + use datafusion_common::{DFSchema, ScalarValue}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::operator::Operator; + use datafusion_expr_common::signature::Volatility; + use std::sync::Arc; + use Operator::{And, Or}; + + #[test] + fn tristate_bool_from_option() { + assert_eq!(TriStateBool::from(&None), Uncertain); + assert_eq!(TriStateBool::from(&Some(true)), True); + assert_eq!(TriStateBool::from(&Some(false)), False); + } + + #[test] + fn tristate_bool_from_scalar() { + assert_eq!( + TriStateBool::try_from(&ScalarValue::Null).unwrap(), + Uncertain + ); + + assert_eq!( + TriStateBool::try_from(&ScalarValue::Boolean(None)).unwrap(), + Uncertain + ); + assert_eq!( + TriStateBool::try_from(&ScalarValue::Boolean(Some(true))).unwrap(), + True + ); + assert_eq!( + TriStateBool::try_from(&ScalarValue::Boolean(Some(false))).unwrap(), + False + ); + + assert_eq!( + TriStateBool::try_from(&ScalarValue::UInt8(None)).unwrap(), + Uncertain + ); + assert_eq!( + TriStateBool::try_from(&ScalarValue::UInt8(Some(0))).unwrap(), + False + ); + assert_eq!( + TriStateBool::try_from(&ScalarValue::UInt8(Some(1))).unwrap(), + True + ); + } + + #[test] + fn tristate_bool_from_scalar_no_cooerce() { + assert_eq!( + TriStateBool::try_from_no_cooerce(&ScalarValue::Null).unwrap(), + Uncertain + ); + + assert_eq!( + TriStateBool::try_from_no_cooerce(&ScalarValue::Boolean(None)).unwrap(), + Uncertain + ); + assert_eq!( + TriStateBool::try_from_no_cooerce(&ScalarValue::Boolean(Some(true))).unwrap(), + True + ); + assert_eq!( + TriStateBool::try_from_no_cooerce(&ScalarValue::Boolean(Some(false))) + .unwrap(), + False + ); + + assert_eq!( + TriStateBool::try_from_no_cooerce(&ScalarValue::UInt8(None)), + None + ); + assert_eq!( + TriStateBool::try_from_no_cooerce(&ScalarValue::UInt8(Some(0))), + None + ); + assert_eq!( + 
TriStateBool::try_from_no_cooerce(&ScalarValue::UInt8(Some(1))), + None + ); + } + + #[test] + fn tristate_bool_not() { + assert_eq!(!Uncertain, Uncertain); + assert_eq!(!False, True); + assert_eq!(!True, False); + } + + #[test] + fn tristate_bool_and() { + assert_eq!(Uncertain & Uncertain, Uncertain); + assert_eq!(Uncertain & True, Uncertain); + assert_eq!(Uncertain & False, False); + assert_eq!(True & Uncertain, Uncertain); + assert_eq!(True & True, True); + assert_eq!(True & False, False); + assert_eq!(False & Uncertain, False); + assert_eq!(False & True, False); + assert_eq!(False & False, False); + } + + #[test] + fn tristate_bool_or() { + assert_eq!(Uncertain | Uncertain, Uncertain); + assert_eq!(Uncertain | True, True); + assert_eq!(Uncertain | False, Uncertain); + assert_eq!(True | Uncertain, True); + assert_eq!(True | True, True); + assert_eq!(True | False, True); + assert_eq!(False | Uncertain, Uncertain); + assert_eq!(False | True, True); + assert_eq!(False | False, False); + } + + fn const_eval(predicate: &Expr) -> Option { + let schema = DFSchema::try_from(Schema::empty()).unwrap(); + const_eval_predicate(predicate, |_| None, &schema) + } + + #[test] + fn predicate_eval_literal() { + assert_eq!(const_eval(&lit(ScalarValue::Null)), Some(false)); + + assert_eq!(const_eval(&lit(false)), Some(false)); + assert_eq!(const_eval(&lit(true)), Some(true)); + + assert_eq!(const_eval(&lit(0)), Some(false)); + assert_eq!(const_eval(&lit(1)), Some(true)); + + assert_eq!(const_eval(&lit("foo")), None); + assert_eq!(const_eval(&lit(ScalarValue::Utf8(None))), Some(false)); + } + + #[test] + fn predicate_eval_and() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let func = make_scalar_func_expr(); + + assert_eq!( + const_eval(&binary_expr(null.clone(), And, null.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(null.clone(), And, one.clone())), + None + ); + assert_eq!( + 
const_eval(&binary_expr(null.clone(), And, zero.clone())), + None + ); + + assert_eq!( + const_eval(&binary_expr(one.clone(), And, one.clone())), + None + ); + assert_eq!( + const_eval(&binary_expr(one.clone(), And, zero.clone())), + None + ); + + assert_eq!( + const_eval(&binary_expr(null.clone(), And, t.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(t.clone(), And, null.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(null.clone(), And, f.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(f.clone(), And, null.clone())), + Some(false) + ); + + assert_eq!( + const_eval(&binary_expr(t.clone(), And, t.clone())), + Some(true) + ); + assert_eq!( + const_eval(&binary_expr(t.clone(), And, f.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(f.clone(), And, t.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(f.clone(), And, f.clone())), + Some(false) + ); + + assert_eq!(const_eval(&binary_expr(t.clone(), And, func.clone())), None); + assert_eq!(const_eval(&binary_expr(func.clone(), And, t.clone())), None); + assert_eq!( + const_eval(&binary_expr(f.clone(), And, func.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(func.clone(), And, f.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(null.clone(), And, func.clone())), + None + ); + assert_eq!( + const_eval(&binary_expr(func.clone(), And, null.clone())), + None + ); + } + + #[test] + fn predicate_eval_or() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let func = make_scalar_func_expr(); + + assert_eq!( + const_eval(&binary_expr(null.clone(), Or, null.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(null.clone(), Or, one.clone())), + None + ); + assert_eq!( + const_eval(&binary_expr(null.clone(), Or, zero.clone())), + None + ); + + assert_eq!(const_eval(&binary_expr(one.clone(), Or, 
one.clone())), None); + assert_eq!( + const_eval(&binary_expr(one.clone(), Or, zero.clone())), + None + ); + + assert_eq!( + const_eval(&binary_expr(null.clone(), Or, t.clone())), + Some(true) + ); + assert_eq!( + const_eval(&binary_expr(t.clone(), Or, null.clone())), + Some(true) + ); + assert_eq!( + const_eval(&binary_expr(null.clone(), Or, f.clone())), + Some(false) + ); + assert_eq!( + const_eval(&binary_expr(f.clone(), Or, null.clone())), + Some(false) + ); + + assert_eq!( + const_eval(&binary_expr(t.clone(), Or, t.clone())), + Some(true) + ); + assert_eq!( + const_eval(&binary_expr(t.clone(), Or, f.clone())), + Some(true) + ); + assert_eq!( + const_eval(&binary_expr(f.clone(), Or, t.clone())), + Some(true) + ); + assert_eq!( + const_eval(&binary_expr(f.clone(), Or, f.clone())), + Some(false) + ); + + assert_eq!( + const_eval(&binary_expr(t.clone(), Or, func.clone())), + Some(true) + ); + assert_eq!( + const_eval(&binary_expr(func.clone(), Or, t.clone())), + Some(true) + ); + assert_eq!(const_eval(&binary_expr(f.clone(), Or, func.clone())), None); + assert_eq!(const_eval(&binary_expr(func.clone(), Or, f.clone())), None); + assert_eq!( + const_eval(&binary_expr(null.clone(), Or, func.clone())), + None + ); + assert_eq!( + const_eval(&binary_expr(func.clone(), Or, null.clone())), + None + ); + } + + #[test] + fn predicate_eval_not() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let func = make_scalar_func_expr(); + + assert_eq!(const_eval(¬(null.clone())), Some(false)); + assert_eq!(const_eval(¬(one.clone())), None); + assert_eq!(const_eval(¬(zero.clone())), None); + + assert_eq!(const_eval(¬(t.clone())), Some(false)); + assert_eq!(const_eval(¬(f.clone())), Some(true)); + + assert_eq!(const_eval(¬(func.clone())), None); + } + + #[test] + fn predicate_eval_is() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + + 
assert_eq!(const_eval(&is_null(null.clone())), Some(true)); + assert_eq!(const_eval(&is_null(one.clone())), Some(false)); + + assert_eq!(const_eval(&is_not_null(null.clone())), Some(false)); + assert_eq!(const_eval(&is_not_null(one.clone())), Some(true)); + + assert_eq!(const_eval(&is_true(null.clone())), Some(false)); + assert_eq!(const_eval(&is_true(t.clone())), Some(true)); + assert_eq!(const_eval(&is_true(f.clone())), Some(false)); + assert_eq!(const_eval(&is_true(zero.clone())), None); + assert_eq!(const_eval(&is_true(one.clone())), None); + + assert_eq!(const_eval(&is_not_true(null.clone())), Some(true)); + assert_eq!(const_eval(&is_not_true(t.clone())), Some(false)); + assert_eq!(const_eval(&is_not_true(f.clone())), Some(true)); + assert_eq!(const_eval(&is_not_true(zero.clone())), None); + assert_eq!(const_eval(&is_not_true(one.clone())), None); + + assert_eq!(const_eval(&is_false(null.clone())), Some(false)); + assert_eq!(const_eval(&is_false(t.clone())), Some(false)); + assert_eq!(const_eval(&is_false(f.clone())), Some(true)); + assert_eq!(const_eval(&is_false(zero.clone())), None); + assert_eq!(const_eval(&is_false(one.clone())), None); + + assert_eq!(const_eval(&is_not_false(null.clone())), Some(true)); + assert_eq!(const_eval(&is_not_false(t.clone())), Some(true)); + assert_eq!(const_eval(&is_not_false(f.clone())), Some(false)); + assert_eq!(const_eval(&is_not_false(zero.clone())), None); + assert_eq!(const_eval(&is_not_false(one.clone())), None); + + assert_eq!(const_eval(&is_unknown(null.clone())), Some(true)); + assert_eq!(const_eval(&is_unknown(t.clone())), Some(false)); + assert_eq!(const_eval(&is_unknown(f.clone())), Some(false)); + assert_eq!(const_eval(&is_unknown(zero.clone())), None); + assert_eq!(const_eval(&is_unknown(one.clone())), None); + + assert_eq!(const_eval(&is_not_unknown(null.clone())), Some(false)); + assert_eq!(const_eval(&is_not_unknown(t.clone())), Some(true)); + assert_eq!(const_eval(&is_not_unknown(f.clone())), Some(true)); + 
assert_eq!(const_eval(&is_not_unknown(zero.clone())), None); + assert_eq!(const_eval(&is_not_unknown(one.clone())), None); + } + + #[test] + fn predicate_eval_udf() { + let func = make_scalar_func_expr(); + + assert_eq!(const_eval(&func.clone()), None); + assert_eq!(const_eval(¬(func.clone())), None); + assert_eq!( + const_eval(&binary_expr(func.clone(), And, func.clone())), + None + ); + } + + fn make_scalar_func_expr() -> Expr { + let scalar_func_impl = + |_: &[ColumnarValue]| Ok(ColumnarValue::Scalar(ScalarValue::Null)); + let udf = create_udf( + "foo", + vec![], + DataType::Boolean, + Volatility::Stable, + Arc::new(scalar_func_impl), + ); + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), vec![])) + } +} From 9dee1e898e416315e29b2828a7ea100974ed9bbf Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 8 Nov 2025 00:14:24 +0100 Subject: [PATCH 15/23] Another attempt to make the code easier to read --- datafusion/expr/src/predicate_eval.rs | 431 +++++++++++++++----------- 1 file changed, 243 insertions(+), 188 deletions(-) diff --git a/datafusion/expr/src/predicate_eval.rs b/datafusion/expr/src/predicate_eval.rs index 4e3cbafcc2e4..02d66b515d94 100644 --- a/datafusion/expr/src/predicate_eval.rs +++ b/datafusion/expr/src/predicate_eval.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. 
-use crate::predicate_eval::TriStateBool::{False, True, Uncertain}; +use crate::predicate_eval::TriStateBool::{False, True, Unknown}; use crate::{BinaryExpr, Expr, ExprSchemable}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{DataFusionError, ExprSchema, ScalarValue}; +use datafusion_common::{ExprSchema, ScalarValue}; use datafusion_expr_common::operator::Operator; use std::ops::{BitAnd, BitOr, Not}; @@ -28,71 +28,48 @@ use std::ops::{BitAnd, BitOr, Not}; enum TriStateBool { True, False, - Uncertain, + Unknown, } impl From<&Option> for TriStateBool { fn from(value: &Option) -> Self { match value { - None => Uncertain, + None => Unknown, Some(true) => True, Some(false) => False, } } } -impl TryFrom<&ScalarValue> for TriStateBool { - type Error = DataFusionError; - - fn try_from(value: &ScalarValue) -> Result { - match value { - ScalarValue::Null => { - // Literal null is equivalent to boolean uncertain - Ok(Uncertain) - } - ScalarValue::Boolean(b) => Ok(match b { - Some(true) => True, - Some(false) => False, - None => Uncertain, - }), - _ => Self::try_from(&value.cast_to(&DataType::Boolean)?), - } - } -} - impl TriStateBool { - fn try_from_no_cooerce(value: &ScalarValue) -> Option { + fn try_from(value: &ScalarValue) -> Option { match value { - ScalarValue::Null => Some(Uncertain), + ScalarValue::Null => Some(Unknown), ScalarValue::Boolean(b) => Some(TriStateBool::from(b)), - _ => None, - } - } - - fn is_null(&self) -> TriStateBool { - match self { - True | False => False, - Uncertain => True, + _ => Self::try_from(&value.cast_to(&DataType::Boolean).ok()?), } } + /// [True] if self is [True], [False] otherwise. fn is_true(&self) -> TriStateBool { match self { True => True, - Uncertain | False => False, + Unknown | False => False, } } + /// [True] if self is [False], [False] otherwise. 
fn is_false(&self) -> TriStateBool { match self { False => True, - Uncertain | True => False, + Unknown | True => False, } } + /// [True] if self is [Unknown], [False] otherwise. fn is_unknown(&self) -> TriStateBool { match self { - Uncertain => True, + Unknown => True, True | False => False, } } @@ -102,10 +79,17 @@ impl Not for TriStateBool { type Output = TriStateBool; fn not(self) -> Self::Output { + // SQL three-valued logic NOT truth table: + // + // P | !P + // --------|------- + // TRUE | FALSE + // FALSE | TRUE + // UNKNOWN | UNKNOWN match self { True => False, False => True, - Uncertain => Uncertain, + Unknown => Unknown, } } } @@ -114,10 +98,23 @@ impl BitAnd for TriStateBool { type Output = TriStateBool; fn bitand(self, rhs: Self) -> Self::Output { + // SQL three-valued logic AND truth table: + // + // P | Q | P AND Q + // --------|---------|---------- + // TRUE | TRUE | TRUE + // TRUE | FALSE | FALSE + // FALSE | TRUE | FALSE + // FALSE | FALSE | FALSE + // FALSE | UNKNOWN | FALSE + // UNKNOWN | FALSE | FALSE + // TRUE | UNKNOWN | UNKNOWN + // UNKNOWN | TRUE | UNKNOWN + // UNKNOWN | UNKNOWN | UNKNOWN match (self, rhs) { - (False, _) | (_, False) => False, - (Uncertain, _) | (_, Uncertain) => Uncertain, (True, True) => True, + (False, _) | (_, False) => False, + (Unknown, _) | (_, Unknown) => Unknown, } } } @@ -126,19 +123,32 @@ impl BitOr for TriStateBool { type Output = TriStateBool; fn bitor(self, rhs: Self) -> Self::Output { + // SQL three-valued logic OR truth table: + // + // P | Q | P OR Q + // --------|---------|---------- + // FALSE | FALSE | FALSE + // TRUE | TRUE | TRUE + // TRUE | FALSE | TRUE + // FALSE | TRUE | TRUE + // TRUE | UNKNOWN | TRUE + // UNKNOWN | TRUE | TRUE + // FALSE | UNKNOWN | UNKNOWN + // UNKNOWN | FALSE | UNKNOWN + // UNKNOWN | UNKNOWN | UNKNOWN match (self, rhs) { - (True, _) | (_, True) => True, - (Uncertain, _) | (_, Uncertain) => Uncertain, (False, False) => False, + (True, _) | (_, True) => True, + (Unknown, _) | (_, 
Unknown) => Unknown, } } } -/// Attempts to partially constant-evaluate a predicate under SQL three-valued logic. +/// Attempts to const evaluate a predicate using SQL three-valued logic. /// /// Semantics of the return value: -/// - `Some(true)` => predicate is provably true -/// - `Some(false)` => predicate is provably false +/// - `Some(true)` => predicate is provably truthy +/// - `Some(false)` => predicate is provably falsy /// - `None` => inconclusive with available static information /// /// The evaluation is conservative and only uses: @@ -160,7 +170,7 @@ where input_schema, evaluates_to_null, } - .const_eval_predicate_coerced(predicate) + .eval_predicate(predicate) .map(|b| matches!(b, True)) } @@ -173,72 +183,77 @@ impl PredicateConstEvaluator<'_, F> where F: Fn(&Expr) -> Option, { - fn const_eval_predicate_coerced(&self, predicate: &Expr) -> Option { + /// Attempts to const evaluate a boolean predicate. + fn eval_predicate(&self, predicate: &Expr) -> Option { match predicate { - Expr::Literal(scalar, _) => TriStateBool::try_from(scalar).ok(), - e => self.const_eval_predicate(e), - } - } - - fn const_eval_predicate(&self, predicate: &Expr) -> Option { - match predicate { - Expr::Literal(scalar, _) => TriStateBool::try_from_no_cooerce(scalar), - Expr::IsNotNull(e) => { + Expr::Literal(scalar, _) => { + // Interpret literals as boolean, coercing if necessary and allowed + TriStateBool::try_from(scalar) + } + Expr::Negative(e) => self.eval_predicate(e), + Expr::IsNull(e) => { + // If `e` is not nullable, then `e IS NULL` is provably false if let Ok(false) = e.nullable(self.input_schema) { - // If `e` is not nullable -> `e IS NOT NULL` is true - return Some(True); + return Some(False); } match e.get_type(self.input_schema) { + // If `e` is a boolean expression, try to evaluate it and test for unknown Ok(DataType::Boolean) => { - self.const_eval_predicate(e).map(|b| b.is_null()) + self.eval_predicate(e).map(|b| b.is_unknown()) } - Ok(_) => match 
self.evaluates_to_null(e) { - True => Some(False), - False => Some(True), - Uncertain => None, + // If `e` is not a boolean expression, check if `e` is provably null + Ok(_) => match self.is_null(e) { + // If `e` is provably null, then `e IS NULL` is provably true + True => Some(True), + // If `e` is provably not null, then `e IS NULL` is provably false + False => Some(False), + Unknown => None, }, Err(_) => None, } } - Expr::IsNull(e) => { + Expr::IsNotNull(e) => { + // If `e` is not nullable, then `e IS NOT NULL` is provably true if let Ok(false) = e.nullable(self.input_schema) { - // If `e` is not nullable -> `e IS NULL` is false - return Some(False); + // If `e` is not nullable -> `e IS NOT NULL` is true + return Some(True); } match e.get_type(self.input_schema) { + // If `e` is a boolean expression, try to evaluate it and test for not unknown Ok(DataType::Boolean) => { - self.const_eval_predicate(e).map(|b| !b.is_null()) + self.eval_predicate(e).map(|b| !b.is_unknown()) } - Ok(_) => match self.evaluates_to_null(e) { - True => Some(True), - False => Some(False), - Uncertain => None, + // If `e` is not a boolean expression, check if `e` is provably null + Ok(_) => match self.is_null(e) { + // If `e` is provably null, then `e IS NOT NULL` is provably false + True => Some(False), + // If `e` is provably not null, then `e IS NOT NULL` is provably true + False => Some(True), + Unknown => None, }, Err(_) => None, } } - Expr::IsTrue(e) => self.const_eval_predicate(e).map(|b| b.is_true()), - Expr::IsNotTrue(e) => self.const_eval_predicate(e).map(|b| !b.is_true()), - Expr::IsFalse(e) => self.const_eval_predicate(e).map(|b| b.is_false()), - Expr::IsNotFalse(e) => self.const_eval_predicate(e).map(|b| !b.is_false()), - Expr::IsUnknown(e) => self.const_eval_predicate(e).map(|b| b.is_unknown()), - Expr::IsNotUnknown(e) => { - self.const_eval_predicate(e).map(|b| !b.is_unknown()) - } - Expr::Not(e) => self.const_eval_predicate(e).map(|b| !b), + Expr::IsTrue(e) => 
self.eval_predicate(e).map(|b| b.is_true()), + Expr::IsNotTrue(e) => self.eval_predicate(e).map(|b| !b.is_true()), + Expr::IsFalse(e) => self.eval_predicate(e).map(|b| b.is_false()), + Expr::IsNotFalse(e) => self.eval_predicate(e).map(|b| !b.is_false()), + Expr::IsUnknown(e) => self.eval_predicate(e).map(|b| b.is_unknown()), + Expr::IsNotUnknown(e) => self.eval_predicate(e).map(|b| !b.is_unknown()), + Expr::Not(e) => self.eval_predicate(e).map(|b| !b), Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right, }) => { - match ( - self.const_eval_predicate(left), - self.const_eval_predicate(right), - ) { + match (self.eval_predicate(left), self.eval_predicate(right)) { + // If either side is false, then the result is false regardless of the other side (Some(False), _) | (_, Some(False)) => Some(False), + // If either side is inconclusive, then the result is inconclusive as well (None, _) | (_, None) => None, + // Otherwise, defer to the tristate boolean algebra (Some(l), Some(r)) => Some(l & r), } } @@ -247,32 +262,28 @@ where op: Operator::Or, right, }) => { - match ( - self.const_eval_predicate(left), - self.const_eval_predicate(right), - ) { + match (self.eval_predicate(left), self.eval_predicate(right)) { + // If either side is true, then the result is true regardless of the other side (Some(True), _) | (_, Some(True)) => Some(True), + // If either side is inconclusive, then the result is inconclusive as well (None, _) | (_, None) => None, + // Otherwise, defer to the tristate boolean algebra (Some(l), Some(r)) => Some(l | r), } } - e => match self.evaluates_to_null(e) { - True => Some(Uncertain), - _ => None, + e => match self.is_null(e) { + // Null values coerce to unknown + True => Some(Unknown), + // Not null, but some unknown value -> inconclusive + False | Unknown => None, }, } } /// Determines if the given expression evaluates to `NULL`. 
- /// - /// This function returns: - /// - `True` if `expr` is provably `NULL` - /// - `False` if `expr` is provably not `NULL` - /// - `Uncertain` if the result is inconclusive - fn evaluates_to_null(&self, expr: &Expr) -> TriStateBool { + fn is_null(&self, expr: &Expr) -> TriStateBool { match expr { Expr::Literal(s, _) => { - // Literal null is obviously null if s.is_null() { True } else { @@ -290,18 +301,15 @@ where // These expressions are null if any of their direct children is null // If any child is inconclusive, the result for this expression is also inconclusive let mut is_null = False; - let _ = - expr.apply_children(|child| match self.evaluates_to_null(child) { - True => { - is_null = True; - Ok(TreeNodeRecursion::Stop) - } - False => Ok(TreeNodeRecursion::Continue), - Uncertain => { - is_null = Uncertain; - Ok(TreeNodeRecursion::Stop) - } - }); + let _ = expr.apply_children(|child| match self.is_null(child) { + False => Ok(TreeNodeRecursion::Continue), + n @ True | n @ Unknown => { + // If any child is null or inconclusive, this result applies to the + // entire expression and we can stop traversing + is_null = n; + Ok(TreeNodeRecursion::Stop) + } + }); is_null } e => { @@ -310,8 +318,11 @@ where False } else { // Finally, ask the callback if it knows the nullness of `expr` - let evaluates_to_null = (self.evaluates_to_null)(e); - TriStateBool::from(&evaluates_to_null) + match (self.evaluates_to_null)(e) { + Some(true) => True, + Some(false) => False, + None => Unknown, + } } } } @@ -324,34 +335,32 @@ mod tests { use crate::predicate_eval::TriStateBool::*; use crate::predicate_eval::{const_eval_predicate, TriStateBool}; use crate::{ - binary_expr, create_udf, is_false, is_not_false, is_not_null, is_not_true, + binary_expr, col, create_udf, is_false, is_not_false, is_not_null, is_not_true, is_not_unknown, is_null, is_true, is_unknown, lit, not, Expr, }; - use arrow::datatypes::{DataType, Schema}; + use arrow::datatypes::{DataType, Field, Schema}; use 
datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::operator::Operator; + use datafusion_expr_common::operator::Operator::Eq; use datafusion_expr_common::signature::Volatility; use std::sync::Arc; use Operator::{And, Or}; #[test] fn tristate_bool_from_option() { - assert_eq!(TriStateBool::from(&None), Uncertain); + assert_eq!(TriStateBool::from(&None), Unknown); assert_eq!(TriStateBool::from(&Some(true)), True); assert_eq!(TriStateBool::from(&Some(false)), False); } #[test] fn tristate_bool_from_scalar() { - assert_eq!( - TriStateBool::try_from(&ScalarValue::Null).unwrap(), - Uncertain - ); + assert_eq!(TriStateBool::try_from(&ScalarValue::Null).unwrap(), Unknown); assert_eq!( TriStateBool::try_from(&ScalarValue::Boolean(None)).unwrap(), - Uncertain + Unknown ); assert_eq!( TriStateBool::try_from(&ScalarValue::Boolean(Some(true))).unwrap(), @@ -364,7 +373,7 @@ mod tests { assert_eq!( TriStateBool::try_from(&ScalarValue::UInt8(None)).unwrap(), - Uncertain + Unknown ); assert_eq!( TriStateBool::try_from(&ScalarValue::UInt8(Some(0))).unwrap(), @@ -376,70 +385,35 @@ mod tests { ); } - #[test] - fn tristate_bool_from_scalar_no_cooerce() { - assert_eq!( - TriStateBool::try_from_no_cooerce(&ScalarValue::Null).unwrap(), - Uncertain - ); - - assert_eq!( - TriStateBool::try_from_no_cooerce(&ScalarValue::Boolean(None)).unwrap(), - Uncertain - ); - assert_eq!( - TriStateBool::try_from_no_cooerce(&ScalarValue::Boolean(Some(true))).unwrap(), - True - ); - assert_eq!( - TriStateBool::try_from_no_cooerce(&ScalarValue::Boolean(Some(false))) - .unwrap(), - False - ); - - assert_eq!( - TriStateBool::try_from_no_cooerce(&ScalarValue::UInt8(None)), - None - ); - assert_eq!( - TriStateBool::try_from_no_cooerce(&ScalarValue::UInt8(Some(0))), - None - ); - assert_eq!( - TriStateBool::try_from_no_cooerce(&ScalarValue::UInt8(Some(1))), - None - ); - } - #[test] fn tristate_bool_not() { - assert_eq!(!Uncertain, 
Uncertain); + assert_eq!(!Unknown, Unknown); assert_eq!(!False, True); assert_eq!(!True, False); } #[test] fn tristate_bool_and() { - assert_eq!(Uncertain & Uncertain, Uncertain); - assert_eq!(Uncertain & True, Uncertain); - assert_eq!(Uncertain & False, False); - assert_eq!(True & Uncertain, Uncertain); + assert_eq!(Unknown & Unknown, Unknown); + assert_eq!(Unknown & True, Unknown); + assert_eq!(Unknown & False, False); + assert_eq!(True & Unknown, Unknown); assert_eq!(True & True, True); assert_eq!(True & False, False); - assert_eq!(False & Uncertain, False); + assert_eq!(False & Unknown, False); assert_eq!(False & True, False); assert_eq!(False & False, False); } #[test] fn tristate_bool_or() { - assert_eq!(Uncertain | Uncertain, Uncertain); - assert_eq!(Uncertain | True, True); - assert_eq!(Uncertain | False, Uncertain); - assert_eq!(True | Uncertain, True); + assert_eq!(Unknown | Unknown, Unknown); + assert_eq!(Unknown | True, True); + assert_eq!(Unknown | False, Unknown); + assert_eq!(True | Unknown, True); assert_eq!(True | True, True); assert_eq!(True | False, True); - assert_eq!(False | Uncertain, Uncertain); + assert_eq!(False | Unknown, Unknown); assert_eq!(False | True, True); assert_eq!(False | False, False); } @@ -449,6 +423,24 @@ mod tests { const_eval_predicate(predicate, |_| None, &schema) } + fn const_eval_with_null( + predicate: &Expr, + schema: &DFSchema, + null_expr: &Expr, + ) -> Option { + const_eval_predicate( + predicate, + |e| { + if e.eq(null_expr) { + Some(true) + } else { + None + } + }, + schema, + ) + } + #[test] fn predicate_eval_literal() { assert_eq!(const_eval(&lit(ScalarValue::Null)), Some(false)); @@ -478,20 +470,20 @@ mod tests { ); assert_eq!( const_eval(&binary_expr(null.clone(), And, one.clone())), - None + Some(false) ); assert_eq!( const_eval(&binary_expr(null.clone(), And, zero.clone())), - None + Some(false) ); assert_eq!( const_eval(&binary_expr(one.clone(), And, one.clone())), - None + Some(true) ); assert_eq!( 
const_eval(&binary_expr(one.clone(), And, zero.clone())), - None + Some(false) ); assert_eq!( @@ -563,17 +555,20 @@ mod tests { ); assert_eq!( const_eval(&binary_expr(null.clone(), Or, one.clone())), - None + Some(true) ); assert_eq!( const_eval(&binary_expr(null.clone(), Or, zero.clone())), - None + Some(false) ); - assert_eq!(const_eval(&binary_expr(one.clone(), Or, one.clone())), None); + assert_eq!( + const_eval(&binary_expr(one.clone(), Or, one.clone())), + Some(true) + ); assert_eq!( const_eval(&binary_expr(one.clone(), Or, zero.clone())), - None + Some(true) ); assert_eq!( @@ -640,8 +635,8 @@ mod tests { let func = make_scalar_func_expr(); assert_eq!(const_eval(¬(null.clone())), Some(false)); - assert_eq!(const_eval(¬(one.clone())), None); - assert_eq!(const_eval(¬(zero.clone())), None); + assert_eq!(const_eval(¬(one.clone())), Some(false)); + assert_eq!(const_eval(¬(zero.clone())), Some(true)); assert_eq!(const_eval(¬(t.clone())), Some(false)); assert_eq!(const_eval(¬(f.clone())), Some(true)); @@ -656,48 +651,77 @@ mod tests { let one = lit(1); let t = lit(true); let f = lit(false); + let col = col("col"); + let nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::UInt8, + true, + )])) + .unwrap(); + let not_nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::UInt8, + false, + )])) + .unwrap(); assert_eq!(const_eval(&is_null(null.clone())), Some(true)); assert_eq!(const_eval(&is_null(one.clone())), Some(false)); + assert_eq!( + const_eval_with_null(&is_null(col.clone()), &nullable_schema, &col), + Some(true) + ); + assert_eq!( + const_eval_with_null(&is_null(col.clone()), ¬_nullable_schema, &col), + Some(false) + ); assert_eq!(const_eval(&is_not_null(null.clone())), Some(false)); assert_eq!(const_eval(&is_not_null(one.clone())), Some(true)); + assert_eq!( + const_eval_with_null(&is_not_null(col.clone()), &nullable_schema, &col), + Some(false) + ); + assert_eq!( + 
const_eval_with_null(&is_not_null(col.clone()), ¬_nullable_schema, &col), + Some(true) + ); assert_eq!(const_eval(&is_true(null.clone())), Some(false)); assert_eq!(const_eval(&is_true(t.clone())), Some(true)); assert_eq!(const_eval(&is_true(f.clone())), Some(false)); - assert_eq!(const_eval(&is_true(zero.clone())), None); - assert_eq!(const_eval(&is_true(one.clone())), None); + assert_eq!(const_eval(&is_true(zero.clone())), Some(false)); + assert_eq!(const_eval(&is_true(one.clone())), Some(true)); assert_eq!(const_eval(&is_not_true(null.clone())), Some(true)); assert_eq!(const_eval(&is_not_true(t.clone())), Some(false)); assert_eq!(const_eval(&is_not_true(f.clone())), Some(true)); - assert_eq!(const_eval(&is_not_true(zero.clone())), None); - assert_eq!(const_eval(&is_not_true(one.clone())), None); + assert_eq!(const_eval(&is_not_true(zero.clone())), Some(true)); + assert_eq!(const_eval(&is_not_true(one.clone())), Some(false)); assert_eq!(const_eval(&is_false(null.clone())), Some(false)); assert_eq!(const_eval(&is_false(t.clone())), Some(false)); assert_eq!(const_eval(&is_false(f.clone())), Some(true)); - assert_eq!(const_eval(&is_false(zero.clone())), None); - assert_eq!(const_eval(&is_false(one.clone())), None); + assert_eq!(const_eval(&is_false(zero.clone())), Some(true)); + assert_eq!(const_eval(&is_false(one.clone())), Some(false)); assert_eq!(const_eval(&is_not_false(null.clone())), Some(true)); assert_eq!(const_eval(&is_not_false(t.clone())), Some(true)); assert_eq!(const_eval(&is_not_false(f.clone())), Some(false)); - assert_eq!(const_eval(&is_not_false(zero.clone())), None); - assert_eq!(const_eval(&is_not_false(one.clone())), None); + assert_eq!(const_eval(&is_not_false(zero.clone())), Some(false)); + assert_eq!(const_eval(&is_not_false(one.clone())), Some(true)); assert_eq!(const_eval(&is_unknown(null.clone())), Some(true)); assert_eq!(const_eval(&is_unknown(t.clone())), Some(false)); assert_eq!(const_eval(&is_unknown(f.clone())), Some(false)); - 
assert_eq!(const_eval(&is_unknown(zero.clone())), None); - assert_eq!(const_eval(&is_unknown(one.clone())), None); + assert_eq!(const_eval(&is_unknown(zero.clone())), Some(false)); + assert_eq!(const_eval(&is_unknown(one.clone())), Some(false)); assert_eq!(const_eval(&is_not_unknown(null.clone())), Some(false)); assert_eq!(const_eval(&is_not_unknown(t.clone())), Some(true)); assert_eq!(const_eval(&is_not_unknown(f.clone())), Some(true)); - assert_eq!(const_eval(&is_not_unknown(zero.clone())), None); - assert_eq!(const_eval(&is_not_unknown(one.clone())), None); + assert_eq!(const_eval(&is_not_unknown(zero.clone())), Some(true)); + assert_eq!(const_eval(&is_not_unknown(one.clone())), Some(true)); } #[test] @@ -724,4 +748,35 @@ mod tests { ); Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), vec![])) } + + #[test] + fn predicate_eval_when_then() { + let nullable_schema = + DFSchema::try_from(Schema::new(vec![Field::new("x", DataType::UInt8, true)])) + .unwrap(); + let not_nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "x", + DataType::UInt8, + false, + )])) + .unwrap(); + + let x = col("x"); + + // CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0 END + let when = binary_expr( + is_not_null(x.clone()), + Or, + binary_expr(x.clone(), Eq, lit(5)), + ); + + assert_eq!( + const_eval_with_null(&when, &nullable_schema, &x), + Some(false) + ); + assert_eq!( + const_eval_with_null(&when, &not_nullable_schema, &x), + Some(true) + ); + } } From 4a22dfc0043b961ca9824ef3574fdb13b7a90079 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 8 Nov 2025 16:59:25 +0100 Subject: [PATCH 16/23] Rework predicate_eval to use set arithmetic --- Cargo.lock | 1 + datafusion/expr/Cargo.toml | 1 + datafusion/expr/src/predicate_eval.rs | 667 ++++++++++++++++---------- 3 files changed, 426 insertions(+), 243 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f500265108ff..821a4a2bcc53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2253,6 +2253,7 @@ version = 
"50.3.0" dependencies = [ "arrow", "async-trait", + "bitflags 2.9.4", "chrono", "ctor", "datafusion-common", diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index e6b2734cfff3..0b313429a248 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -45,6 +45,7 @@ sql = ["sqlparser"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } +bitflags = "2.9.4" chrono = { workspace = true } datafusion-common = { workspace = true, default-features = false } datafusion-doc = { workspace = true } diff --git a/datafusion/expr/src/predicate_eval.rs b/datafusion/expr/src/predicate_eval.rs index 02d66b515d94..c29ca6839070 100644 --- a/datafusion/expr/src/predicate_eval.rs +++ b/datafusion/expr/src/predicate_eval.rs @@ -15,132 +15,185 @@ // specific language governing permissions and limitations // under the License. -use crate::predicate_eval::TriStateBool::{False, True, Unknown}; use crate::{BinaryExpr, Expr, ExprSchemable}; use arrow::datatypes::DataType; +use bitflags::bitflags; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{ExprSchema, ScalarValue}; use datafusion_expr_common::operator::Operator; -use std::ops::{BitAnd, BitOr, Not}; - -/// Represents the possible values for SQL's three valued logic. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -enum TriStateBool { - True, - False, - Unknown, -} -impl From<&Option> for TriStateBool { - fn from(value: &Option) -> Self { - match value { - None => Unknown, - Some(true) => True, - Some(false) => False, - } +bitflags! 
{ + #[derive(PartialEq, Eq, Clone, Debug)] + struct TriBoolSet: u8 { + const TRUE = 0b1; + const FALSE = 0b10; + const UNKNOWN = 0b100; } } -impl TriStateBool { - fn try_from(value: &ScalarValue) -> Option { +impl TriBoolSet { + fn try_from(value: &ScalarValue) -> TriBoolSet { match value { - ScalarValue::Null => Some(Unknown), - ScalarValue::Boolean(b) => Some(TriStateBool::from(b)), - _ => Self::try_from(&value.cast_to(&DataType::Boolean).ok()?), + ScalarValue::Null => TriBoolSet::UNKNOWN, + ScalarValue::Boolean(b) => match b { + Some(true) => TriBoolSet::TRUE, + Some(false) => TriBoolSet::FALSE, + None => TriBoolSet::UNKNOWN, + }, + _ => { + if let Ok(b) = value.cast_to(&DataType::Boolean) { + Self::try_from(&b) + } else { + TriBoolSet::empty() + } + } } } - /// [True] if self is [True], [False] otherwise. - fn is_true(&self) -> TriStateBool { - match self { - True => True, - Unknown | False => False, + /// Returns the set of possible values after applying `IS TRUE` on all + /// values in this set + fn is_true(&self) -> Self { + let mut is_true = Self::empty(); + if self.contains(Self::TRUE) { + is_true.toggle(Self::TRUE); + } + if self.intersects(Self::UNKNOWN | Self::FALSE) { + is_true.toggle(Self::FALSE); } + is_true } - /// [True] if self is [False], [False] otherwise. - fn is_false(&self) -> TriStateBool { - match self { - False => True, - Unknown | True => False, + /// Returns the set of possible values after applying `IS FALSE` on all + /// values in this set + fn is_false(&self) -> Self { + let mut is_false = Self::empty(); + if self.contains(Self::FALSE) { + is_false.toggle(Self::TRUE); } + if self.intersects(Self::UNKNOWN | Self::TRUE) { + is_false.toggle(Self::FALSE); + } + is_false } - /// [True] if self is [Unknown], [False] otherwise. 
- fn is_unknown(&self) -> TriStateBool { - match self { - Unknown => True, - True | False => False, + /// Returns the set of possible values after applying `IS UNKNOWN` on all + /// values in this set + fn is_unknown(&self) -> Self { + let mut is_unknown = Self::empty(); + if self.contains(Self::UNKNOWN) { + is_unknown.toggle(Self::TRUE); + } + if self.intersects(Self::TRUE | Self::FALSE) { + is_unknown.toggle(Self::FALSE); } + is_unknown } -} -impl Not for TriStateBool { - type Output = TriStateBool; - - fn not(self) -> Self::Output { - // SQL three-valued logic NOT truth table: - // - // P | !P - // --------|------- - // TRUE | FALSE - // FALSE | TRUE - // UNKNOWN | UNKNOWN - match self { - True => False, - False => True, - Unknown => Unknown, + /// Returns the set of possible values after applying SQL three-valued logical NOT + /// on each value in `value`. + /// + /// This method uses the following truth table. + /// + /// ``` + /// P | !P + /// --------|------- + /// TRUE | FALSE + /// FALSE | TRUE + /// UNKNOWN | UNKNOWN + /// ``` + fn not(set: Self) -> Self { + let mut not = Self::empty(); + if set.contains(Self::TRUE) { + not.toggle(Self::FALSE); + } + if set.contains(Self::FALSE) { + not.toggle(Self::TRUE); + } + if set.contains(Self::UNKNOWN) { + not.toggle(Self::UNKNOWN); } + not } -} -impl BitAnd for TriStateBool { - type Output = TriStateBool; - - fn bitand(self, rhs: Self) -> Self::Output { - // SQL three-valued logic AND truth table: - // - // P | Q | P AND Q - // --------|---------|---------- - // TRUE | TRUE | TRUE - // TRUE | FALSE | FALSE - // FALSE | TRUE | FALSE - // FALSE | FALSE | FALSE - // FALSE | UNKNOWN | FALSE - // UNKNOWN | FALSE | FALSE - // TRUE | UNKNOWN | UNKNOWN - // UNKNOWN | TRUE | UNKNOWN - // UNKNOWN | UNKNOWN | UNKNOWN - match (self, rhs) { - (True, True) => True, - (False, _) | (_, False) => False, - (Unknown, _) | (_, Unknown) => Unknown, + /// Returns the set of possible values after applying SQL three-valued logical AND + 
/// on each combination of values from `lhs` and `rhs`. + /// + /// This method uses the following truth table. + /// + /// ``` + /// P | Q | P AND Q + /// --------|---------|---------- + /// TRUE | TRUE | TRUE + /// TRUE | FALSE | FALSE + /// FALSE | TRUE | FALSE + /// FALSE | FALSE | FALSE + /// FALSE | UNKNOWN | FALSE + /// UNKNOWN | FALSE | FALSE + /// TRUE | UNKNOWN | UNKNOWN + /// UNKNOWN | TRUE | UNKNOWN + /// UNKNOWN | UNKNOWN | UNKNOWN + /// ``` + fn and(lhs: Self, rhs: Self) -> Self { + if lhs.is_empty() || rhs.is_empty() { + return Self::empty(); + } + + let mut and = Self::empty(); + if lhs.contains(Self::FALSE) || rhs.contains(Self::FALSE) { + and.toggle(Self::FALSE); + } + + if (lhs.contains(Self::UNKNOWN) && rhs.intersects(Self::TRUE | Self::UNKNOWN)) + || (rhs.contains(Self::UNKNOWN) && lhs.intersects(Self::TRUE | Self::UNKNOWN)) + { + and.toggle(Self::UNKNOWN); + } + + if lhs.contains(Self::TRUE) && rhs.contains(Self::TRUE) { + and.toggle(Self::TRUE); } + + and } -} -impl BitOr for TriStateBool { - type Output = TriStateBool; - - fn bitor(self, rhs: Self) -> Self::Output { - // SQL three-valued logic OR truth table: - // - // P | Q | P OR Q - // --------|---------|---------- - // FALSE | FALSE | FALSE - // TRUE | TRUE | TRUE - // TRUE | FALSE | TRUE - // FALSE | TRUE | TRUE - // TRUE | UNKNOWN | TRUE - // UNKNOWN | TRUE | TRUE - // FALSE | UNKNOWN | UNKNOWN - // UNKNOWN | FALSE | UNKNOWN - // UNKNOWN | UNKNOWN | UNKNOWN - match (self, rhs) { - (False, False) => False, - (True, _) | (_, True) => True, - (Unknown, _) | (_, Unknown) => Unknown, + /// Returns the set of possible values after applying SQL three-valued logical OR + /// on each combination of values from `lhs` and `rhs`. + /// + /// This method uses the following truth table. 
+ /// + /// ``` + /// SQL three-valued logic OR truth table: + /// + /// P | Q | P OR Q + /// --------|---------|---------- + /// FALSE | FALSE | FALSE + /// TRUE | TRUE | TRUE + /// TRUE | FALSE | TRUE + /// FALSE | TRUE | TRUE + /// TRUE | UNKNOWN | TRUE + /// UNKNOWN | TRUE | TRUE + /// FALSE | UNKNOWN | UNKNOWN + /// UNKNOWN | FALSE | UNKNOWN + /// UNKNOWN | UNKNOWN | UNKNOWN + /// ``` + fn or(lhs: Self, rhs: Self) -> Self { + let mut or = Self::empty(); + if lhs.contains(Self::TRUE) || rhs.contains(Self::TRUE) { + or.toggle(Self::TRUE); + } + + if (lhs.contains(Self::UNKNOWN) && rhs.intersects(Self::FALSE | Self::UNKNOWN)) + || (rhs.contains(Self::UNKNOWN) + && lhs.intersects(Self::FALSE | Self::UNKNOWN)) + { + or.toggle(Self::UNKNOWN); + } + + if lhs.contains(Self::FALSE) && rhs.contains(Self::FALSE) { + or.toggle(Self::FALSE); } + + or } } @@ -166,12 +219,23 @@ pub(super) fn const_eval_predicate( where F: Fn(&Expr) -> Option, { - PredicateConstEvaluator { + let evaluator = PredicateConstEvaluator { input_schema, evaluates_to_null, + }; + let possible_results = evaluator.eval_predicate(predicate); + + if !possible_results.is_empty() { + if possible_results == TriBoolSet::TRUE { + // Provably true + return Some(true); + } else if !possible_results.contains(TriBoolSet::TRUE) { + // Provably not true + return Some(false); + } } - .eval_predicate(predicate) - .map(|b| matches!(b, True)) + + None } pub(super) struct PredicateConstEvaluator<'a, F> { @@ -184,110 +248,105 @@ where F: Fn(&Expr) -> Option, { /// Attempts to const evaluate a boolean predicate. 
- fn eval_predicate(&self, predicate: &Expr) -> Option { + fn eval_predicate(&self, predicate: &Expr) -> TriBoolSet { match predicate { Expr::Literal(scalar, _) => { - // Interpret literals as boolean, coercing if necessary and allowed - TriStateBool::try_from(scalar) + // Interpret literals as boolean, coercing if necessary + TriBoolSet::try_from(scalar) } Expr::Negative(e) => self.eval_predicate(e), Expr::IsNull(e) => { // If `e` is not nullable, then `e IS NULL` is provably false if let Ok(false) = e.nullable(self.input_schema) { - return Some(False); + return TriBoolSet::FALSE; } match e.get_type(self.input_schema) { // If `e` is a boolean expression, try to evaluate it and test for unknown - Ok(DataType::Boolean) => { - self.eval_predicate(e).map(|b| b.is_unknown()) - } + Ok(DataType::Boolean) => self.eval_predicate(e).is_unknown(), // If `e` is not a boolean expression, check if `e` is provably null - Ok(_) => match self.is_null(e) { - // If `e` is provably null, then `e IS NULL` is provably true - True => Some(True), - // If `e` is provably not null, then `e IS NULL` is provably false - False => Some(False), - Unknown => None, - }, - Err(_) => None, + Ok(_) => self.is_null(e), + Err(_) => TriBoolSet::empty(), } } Expr::IsNotNull(e) => { // If `e` is not nullable, then `e IS NOT NULL` is provably true if let Ok(false) = e.nullable(self.input_schema) { - // If `e` is not nullable -> `e IS NOT NULL` is true - return Some(True); + return TriBoolSet::TRUE; } match e.get_type(self.input_schema) { // If `e` is a boolean expression, try to evaluate it and test for not unknown Ok(DataType::Boolean) => { - self.eval_predicate(e).map(|b| !b.is_unknown()) + TriBoolSet::not(self.eval_predicate(e).is_unknown()) } // If `e` is not a boolean expression, check if `e` is provably null - Ok(_) => match self.is_null(e) { - // If `e` is provably null, then `e IS NOT NULL` is provably false - True => Some(False), - // If `e` is provably not null, then `e IS NOT NULL` is provably 
true - False => Some(True), - Unknown => None, - }, - Err(_) => None, + Ok(_) => TriBoolSet::not(self.is_null(e)), + Err(_) => TriBoolSet::empty(), } } - Expr::IsTrue(e) => self.eval_predicate(e).map(|b| b.is_true()), - Expr::IsNotTrue(e) => self.eval_predicate(e).map(|b| !b.is_true()), - Expr::IsFalse(e) => self.eval_predicate(e).map(|b| b.is_false()), - Expr::IsNotFalse(e) => self.eval_predicate(e).map(|b| !b.is_false()), - Expr::IsUnknown(e) => self.eval_predicate(e).map(|b| b.is_unknown()), - Expr::IsNotUnknown(e) => self.eval_predicate(e).map(|b| !b.is_unknown()), - Expr::Not(e) => self.eval_predicate(e).map(|b| !b), + Expr::IsTrue(e) => self.eval_predicate(e).is_true(), + Expr::IsNotTrue(e) => TriBoolSet::not(self.eval_predicate(e).is_true()), + Expr::IsFalse(e) => self.eval_predicate(e).is_false(), + Expr::IsNotFalse(e) => TriBoolSet::not(self.eval_predicate(e).is_false()), + Expr::IsUnknown(e) => self.eval_predicate(e).is_unknown(), + Expr::IsNotUnknown(e) => TriBoolSet::not(self.eval_predicate(e).is_unknown()), + Expr::Not(e) => TriBoolSet::not(self.eval_predicate(e)), Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right, - }) => { - match (self.eval_predicate(left), self.eval_predicate(right)) { - // If either side is false, then the result is false regardless of the other side - (Some(False), _) | (_, Some(False)) => Some(False), - // If either side is inconclusive, then the result is inconclusive as well - (None, _) | (_, None) => None, - // Otherwise, defer to the tristate boolean algebra - (Some(l), Some(r)) => Some(l & r), - } - } + }) => TriBoolSet::and(self.eval_predicate(left), self.eval_predicate(right)), Expr::BinaryExpr(BinaryExpr { left, op: Operator::Or, right, - }) => { - match (self.eval_predicate(left), self.eval_predicate(right)) { - // If either side is true, then the result is true regardless of the other side - (Some(True), _) | (_, Some(True)) => Some(True), - // If either side is inconclusive, then the result is inconclusive 
as well - (None, _) | (_, None) => None, - // Otherwise, defer to the tristate boolean algebra - (Some(l), Some(r)) => Some(l | r), + }) => TriBoolSet::or(self.eval_predicate(left), self.eval_predicate(right)), + e => { + let mut result = TriBoolSet::empty(); + let is_null = self.is_null(e); + + // If an expression is null, then it's value is UNKNOWN + if is_null.contains(TriBoolSet::TRUE) { + result |= TriBoolSet::UNKNOWN + } + + // If an expression is not null, then it's either TRUE or FALSE + if is_null.contains(TriBoolSet::FALSE) { + result |= TriBoolSet::TRUE | TriBoolSet::FALSE } + + result } - e => match self.is_null(e) { - // Null values coerce to unknown - True => Some(Unknown), - // Not null, but some unknown value -> inconclusive - False | Unknown => None, - }, } } - /// Determines if the given expression evaluates to `NULL`. - fn is_null(&self, expr: &Expr) -> TriStateBool { + /// Determines if the given expression can evaluate to `NULL`. + /// + /// This method only returns sets containing `TRUE`, `FALSE`, or both. + fn is_null(&self, expr: &Expr) -> TriBoolSet { + // If `expr` is not nullable, we can be certain `expr` is not null + if let Ok(false) = expr.nullable(self.input_schema) { + return TriBoolSet::FALSE; + } + + // Check if the callback can decide for us + if let Some(is_null) = (self.evaluates_to_null)(expr) { + return if is_null { + TriBoolSet::TRUE + } else { + TriBoolSet::FALSE + }; + } + + // `expr` is nullable, so our default answer is { TRUE, FALSE }. + // Try to see if we can narrow it down to one of the two. 
match expr { Expr::Literal(s, _) => { if s.is_null() { - True + TriBoolSet::TRUE } else { - False + TriBoolSet::FALSE } } Expr::Alias(_) @@ -300,31 +359,27 @@ where | Expr::SimilarTo(_) => { // These expressions are null if any of their direct children is null // If any child is inconclusive, the result for this expression is also inconclusive - let mut is_null = False; - let _ = expr.apply_children(|child| match self.is_null(child) { - False => Ok(TreeNodeRecursion::Continue), - n @ True | n @ Unknown => { - // If any child is null or inconclusive, this result applies to the - // entire expression and we can stop traversing - is_null = n; + let mut is_null = TriBoolSet::FALSE.clone(); + let _ = expr.apply_children(|child| { + let child_is_null = self.is_null(child); + + if child_is_null.contains(TriBoolSet::TRUE) { + // If a child might be null, then the result may also be null + is_null.insert(TriBoolSet::TRUE); + } + + if !child_is_null.contains(TriBoolSet::FALSE) { + // If the child is never not null, then the result can also never be not null + // and we can stop traversing the children + is_null.remove(TriBoolSet::FALSE); Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) } }); is_null } - e => { - if let Ok(false) = e.nullable(self.input_schema) { - // If `expr` is not nullable, we can be certain `expr` is not null - False - } else { - // Finally, ask the callback if it knows the nullness of `expr` - match (self.evaluates_to_null)(e) { - Some(true) => True, - Some(false) => False, - None => Unknown, - } - } - } + _ => TriBoolSet::TRUE | TriBoolSet::FALSE, } } } @@ -332,8 +387,7 @@ where #[cfg(test)] mod tests { use crate::expr::ScalarFunction; - use crate::predicate_eval::TriStateBool::*; - use crate::predicate_eval::{const_eval_predicate, TriStateBool}; + use crate::predicate_eval::{const_eval_predicate, TriBoolSet}; use crate::{ binary_expr, col, create_udf, is_false, is_not_false, is_not_null, is_not_true, is_not_unknown, is_null, 
is_true, is_unknown, lit, not, Expr, @@ -341,81 +395,208 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; - use datafusion_expr_common::operator::Operator; - use datafusion_expr_common::operator::Operator::Eq; + use datafusion_expr_common::operator::Operator::{And, Eq, Or}; use datafusion_expr_common::signature::Volatility; use std::sync::Arc; - use Operator::{And, Or}; - - #[test] - fn tristate_bool_from_option() { - assert_eq!(TriStateBool::from(&None), Unknown); - assert_eq!(TriStateBool::from(&Some(true)), True); - assert_eq!(TriStateBool::from(&Some(false)), False); - } #[test] fn tristate_bool_from_scalar() { - assert_eq!(TriStateBool::try_from(&ScalarValue::Null).unwrap(), Unknown); - - assert_eq!( - TriStateBool::try_from(&ScalarValue::Boolean(None)).unwrap(), - Unknown - ); - assert_eq!( - TriStateBool::try_from(&ScalarValue::Boolean(Some(true))).unwrap(), - True - ); - assert_eq!( - TriStateBool::try_from(&ScalarValue::Boolean(Some(false))).unwrap(), - False - ); - - assert_eq!( - TriStateBool::try_from(&ScalarValue::UInt8(None)).unwrap(), - Unknown - ); - assert_eq!( - TriStateBool::try_from(&ScalarValue::UInt8(Some(0))).unwrap(), - False - ); - assert_eq!( - TriStateBool::try_from(&ScalarValue::UInt8(Some(1))).unwrap(), - True - ); + let cases = vec![ + (ScalarValue::Null, TriBoolSet::UNKNOWN), + (ScalarValue::Boolean(None), TriBoolSet::UNKNOWN), + (ScalarValue::Boolean(Some(true)), TriBoolSet::TRUE), + (ScalarValue::Boolean(Some(false)), TriBoolSet::FALSE), + (ScalarValue::UInt8(None), TriBoolSet::UNKNOWN), + (ScalarValue::UInt8(Some(0)), TriBoolSet::FALSE), + (ScalarValue::UInt8(Some(1)), TriBoolSet::TRUE), + ( + ScalarValue::Utf8(Some("abc".to_string())), + TriBoolSet::empty(), + ), + ]; + + for case in cases { + assert_eq!(TriBoolSet::try_from(&case.0), case.1); + } } #[test] fn tristate_bool_not() { - assert_eq!(!Unknown, 
Unknown); - assert_eq!(!False, True); - assert_eq!(!True, False); + let cases = vec![ + (TriBoolSet::UNKNOWN, TriBoolSet::UNKNOWN), + (TriBoolSet::TRUE, TriBoolSet::FALSE), + (TriBoolSet::FALSE, TriBoolSet::TRUE), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE, + TriBoolSet::TRUE | TriBoolSet::FALSE, + ), + ( + TriBoolSet::TRUE | TriBoolSet::UNKNOWN, + TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ]; + + for case in cases { + assert_eq!(TriBoolSet::not(case.0), case.1); + } } #[test] fn tristate_bool_and() { - assert_eq!(Unknown & Unknown, Unknown); - assert_eq!(Unknown & True, Unknown); - assert_eq!(Unknown & False, False); - assert_eq!(True & Unknown, Unknown); - assert_eq!(True & True, True); - assert_eq!(True & False, False); - assert_eq!(False & Unknown, False); - assert_eq!(False & True, False); - assert_eq!(False & False, False); + let cases = vec![ + ( + TriBoolSet::UNKNOWN, + TriBoolSet::UNKNOWN, + TriBoolSet::UNKNOWN, + ), + (TriBoolSet::UNKNOWN, TriBoolSet::TRUE, TriBoolSet::UNKNOWN), + (TriBoolSet::UNKNOWN, TriBoolSet::FALSE, TriBoolSet::FALSE), + (TriBoolSet::TRUE, TriBoolSet::TRUE, TriBoolSet::TRUE), + (TriBoolSet::TRUE, TriBoolSet::FALSE, TriBoolSet::FALSE), + (TriBoolSet::FALSE, TriBoolSet::FALSE, TriBoolSet::FALSE), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE, + TriBoolSet::FALSE, + TriBoolSet::FALSE, + ), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE, + TriBoolSet::TRUE, + TriBoolSet::TRUE | TriBoolSet::FALSE, + ), + ( + TriBoolSet::TRUE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE, + TriBoolSet::TRUE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::TRUE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + 
TriBoolSet::TRUE, + TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE, + TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ]; + + for case in cases { + assert_eq!( + TriBoolSet::and(case.0.clone(), case.1.clone()), + case.2.clone(), + "{:?} & {:?} = {:?}", + case.0.clone(), + case.1.clone(), + case.2.clone() + ); + assert_eq!( + TriBoolSet::and(case.1.clone(), case.0.clone()), + case.2.clone(), + "{:?} & {:?} = {:?}", + case.1.clone(), + case.0.clone(), + case.2.clone() + ); + } } #[test] fn tristate_bool_or() { - assert_eq!(Unknown | Unknown, Unknown); - assert_eq!(Unknown | True, True); - assert_eq!(Unknown | False, Unknown); - assert_eq!(True | Unknown, True); - assert_eq!(True | True, True); - assert_eq!(True | False, True); - assert_eq!(False | Unknown, Unknown); - assert_eq!(False | True, True); - assert_eq!(False | False, False); + let cases = vec![ + ( + TriBoolSet::UNKNOWN, + TriBoolSet::UNKNOWN, + TriBoolSet::UNKNOWN, + ), + (TriBoolSet::UNKNOWN, TriBoolSet::TRUE, TriBoolSet::TRUE), + (TriBoolSet::UNKNOWN, TriBoolSet::FALSE, TriBoolSet::UNKNOWN), + (TriBoolSet::TRUE, TriBoolSet::TRUE, TriBoolSet::TRUE), + (TriBoolSet::TRUE, TriBoolSet::FALSE, TriBoolSet::TRUE), + (TriBoolSet::FALSE, TriBoolSet::FALSE, TriBoolSet::FALSE), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE, + TriBoolSet::FALSE, + TriBoolSet::TRUE | TriBoolSet::FALSE, + ), + ( + TriBoolSet::TRUE | 
TriBoolSet::UNKNOWN, + TriBoolSet::TRUE, + TriBoolSet::TRUE, + ), + ( + TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE, + TriBoolSet::TRUE, + ), + ( + TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE, + TriBoolSet::TRUE, + ), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ( + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + ), + ]; + + for case in cases { + assert_eq!( + TriBoolSet::or(case.0.clone(), case.1.clone()), + case.2.clone(), + "{:?} | {:?} = {:?}", + case.0.clone(), + case.1.clone(), + case.2.clone() + ); + assert_eq!( + TriBoolSet::or(case.1.clone(), case.0.clone()), + case.2.clone(), + "{:?} | {:?} = {:?}", + case.1.clone(), + case.0.clone(), + case.2.clone() + ); + } } fn const_eval(predicate: &Expr) -> Option { @@ -532,11 +713,11 @@ mod tests { ); assert_eq!( const_eval(&binary_expr(null.clone(), And, func.clone())), - None + Some(false) ); assert_eq!( const_eval(&binary_expr(func.clone(), And, null.clone())), - None + Some(false) ); } From a1bc263e1fa92cfdc2daf2c440c8acfa9ca3ef02 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sat, 8 Nov 2025 23:13:08 +0100 Subject: [PATCH 17/23] Rename predicate_eval to predicate_bounds --- .../expr-common/src/interval_arithmetic.rs | 10 + datafusion/expr/src/expr_schema.rs | 17 +- datafusion/expr/src/lib.rs | 2 +- ...{predicate_eval.rs => predicate_bounds.rs} | 493 ++++++++++-------- 4 files changed, 301 insertions(+), 221 deletions(-) rename datafusion/expr/src/{predicate_eval.rs => predicate_bounds.rs} (63%) diff --git 
a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 7515b59b9221..61af080fe514 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -1762,6 +1762,16 @@ impl NullableInterval { } } + /// Return true if the value is definitely not true (either null or false). + pub fn is_certainly_not_true(&self) -> bool { + match self { + Self::Null { .. } => true, + Self::MaybeNull { values } | Self::NotNull { values } => { + values == &Interval::CERTAINLY_FALSE + } + } + } + /// Return true if the value is definitely false (and not null). pub fn is_certainly_false(&self) -> bool { match self { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index f520dd07adbd..0fc475660b8d 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use super::{predicate_eval, Between, Expr, Like}; +use super::{predicate_bounds, Between, Expr, Like}; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, @@ -321,17 +321,18 @@ impl ExprSchemable for Expr { } }; - match predicate_eval::const_eval_predicate( + let bounds = predicate_bounds::evaluate_bounds( w, is_null, input_schema, - ) { - // Const evaluation was inconclusive or determined the branch - // would be taken - None | Some(true) => Some(Ok(())), - // Const evaluation proves the branch will never be taken. + ); + if bounds.is_certainly_not_true() { + // The branch will certainly never be taken. // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. 
- Some(false) => None, + None + } else { + // The branch might be taken + Some(Ok(())) } } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 701f40ee678b..d9a55d489b6c 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -70,7 +70,7 @@ pub mod async_udf; pub mod statistics { pub use datafusion_expr_common::statistics::*; } -mod predicate_eval; +mod predicate_bounds; pub mod ptr_eq; pub mod test; pub mod tree_node; diff --git a/datafusion/expr/src/predicate_eval.rs b/datafusion/expr/src/predicate_bounds.rs similarity index 63% rename from datafusion/expr/src/predicate_eval.rs rename to datafusion/expr/src/predicate_bounds.rs index c29ca6839070..da2927a8d720 100644 --- a/datafusion/expr/src/predicate_eval.rs +++ b/datafusion/expr/src/predicate_bounds.rs @@ -20,38 +20,41 @@ use arrow::datatypes::DataType; use bitflags::bitflags; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{ExprSchema, ScalarValue}; +use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr_common::operator::Operator; bitflags! 
{ + /// A set representing the possible outcomes of a SQL boolean expression #[derive(PartialEq, Eq, Clone, Debug)] - struct TriBoolSet: u8 { + struct TernarySet: u8 { const TRUE = 0b1; const FALSE = 0b10; const UNKNOWN = 0b100; } } -impl TriBoolSet { - fn try_from(value: &ScalarValue) -> TriBoolSet { +impl TernarySet { + fn try_from(value: &ScalarValue) -> TernarySet { match value { - ScalarValue::Null => TriBoolSet::UNKNOWN, + ScalarValue::Null => TernarySet::UNKNOWN, ScalarValue::Boolean(b) => match b { - Some(true) => TriBoolSet::TRUE, - Some(false) => TriBoolSet::FALSE, - None => TriBoolSet::UNKNOWN, + Some(true) => TernarySet::TRUE, + Some(false) => TernarySet::FALSE, + None => TernarySet::UNKNOWN, }, _ => { if let Ok(b) = value.cast_to(&DataType::Boolean) { Self::try_from(&b) } else { - TriBoolSet::empty() + TernarySet::empty() } } } } - /// Returns the set of possible values after applying `IS TRUE` on all - /// values in this set + /// Returns the set of possible values after applying the `is true` test on all + /// values in this set. + /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. fn is_true(&self) -> Self { let mut is_true = Self::empty(); if self.contains(Self::TRUE) { @@ -63,8 +66,9 @@ impl TriBoolSet { is_true } - /// Returns the set of possible values after applying `IS FALSE` on all - /// values in this set + /// Returns the set of possible values after applying the `is false` test on all + /// values in this set. + /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. fn is_false(&self) -> Self { let mut is_false = Self::empty(); if self.contains(Self::FALSE) { @@ -76,8 +80,9 @@ impl TriBoolSet { is_false } - /// Returns the set of possible values after applying `IS UNKNOWN` on all - /// values in this set + /// Returns the set of possible values after applying the `is unknown` test on all + /// values in this set. 
+ /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. fn is_unknown(&self) -> Self { let mut is_unknown = Self::empty(); if self.contains(Self::UNKNOWN) { @@ -94,12 +99,12 @@ impl TriBoolSet { /// /// This method uses the following truth table. /// - /// ``` - /// P | !P - /// --------|------- - /// TRUE | FALSE - /// FALSE | TRUE - /// UNKNOWN | UNKNOWN + /// ```text + /// A | ¬A + /// ----|---- + /// F | T + /// U | U + /// T | F /// ``` fn not(set: Self) -> Self { let mut not = Self::empty(); @@ -120,18 +125,12 @@ impl TriBoolSet { /// /// This method uses the following truth table. /// - /// ``` - /// P | Q | P AND Q - /// --------|---------|---------- - /// TRUE | TRUE | TRUE - /// TRUE | FALSE | FALSE - /// FALSE | TRUE | FALSE - /// FALSE | FALSE | FALSE - /// FALSE | UNKNOWN | FALSE - /// UNKNOWN | FALSE | FALSE - /// TRUE | UNKNOWN | UNKNOWN - /// UNKNOWN | TRUE | UNKNOWN - /// UNKNOWN | UNKNOWN | UNKNOWN + /// ```text + /// A ∧ B │ F U T + /// ──────┼────── + /// F │ F F F + /// U │ F U U + /// T │ F U T /// ``` fn and(lhs: Self, rhs: Self) -> Self { if lhs.is_empty() || rhs.is_empty() { @@ -161,20 +160,12 @@ impl TriBoolSet { /// /// This method uses the following truth table. /// - /// ``` - /// SQL three-valued logic OR truth table: - /// - /// P | Q | P OR Q - /// --------|---------|---------- - /// FALSE | FALSE | FALSE - /// TRUE | TRUE | TRUE - /// TRUE | FALSE | TRUE - /// FALSE | TRUE | TRUE - /// TRUE | UNKNOWN | TRUE - /// UNKNOWN | TRUE | TRUE - /// FALSE | UNKNOWN | UNKNOWN - /// UNKNOWN | FALSE | UNKNOWN - /// UNKNOWN | UNKNOWN | UNKNOWN + /// ```text + /// A ∨ B │ F U T + /// ──────┼────── + /// F │ F U T + /// U │ U U T + /// T │ T T T /// ``` fn or(lhs: Self, rhs: Self) -> Self { let mut or = Self::empty(); @@ -197,123 +188,180 @@ impl TriBoolSet { } } -/// Attempts to const evaluate a predicate using SQL three-valued logic. 
+/// Computes the output interval for the given boolean expression based on statically +/// available information. +/// +/// # Arguments +/// +/// * `predicate` - The boolean expression to analyze +/// * `is_null` - A callback function that provides additional nullability information for +/// expressions. When called with an expression, it should return: +/// - `Some(true)` if the expression is known to evaluate to NULL +/// - `Some(false)` if the expression is known to NOT evaluate to NULL +/// - `None` if the nullability cannot be determined +/// +/// This callback allows the caller to provide context-specific knowledge about expression +/// nullability that cannot be determined from the schema alone. For example, it can be used +/// to indicate that a particular column reference is known to be NULL in a specific context, +/// or that certain expressions will never be NULL based on runtime constraints. +/// +/// * `input_schema` - Schema information for resolving expression types and nullability +/// +/// # Return Value /// -/// Semantics of the return value: -/// - `Some(true)` => predicate is provably truthy -/// - `Some(false)` => predicate is provably falsy -/// - `None` => inconclusive with available static information +/// The function returns a [NullableInterval] that describes the possible boolean values the +/// predicate can evaluate to. The return value will be one of the following: /// -/// The evaluation is conservative and only uses: -/// - Expression nullability from `input_schema` -/// - Simple type checks (e.g. whether an expression is Boolean) -/// - Syntactic patterns (IS NULL/IS NOT NULL/IS TRUE/IS FALSE/etc.) -/// - Three-valued boolean algebra for AND/OR/NOT +/// * `NullableInterval::NotNull { values: Interval::CERTAINLY_TRUE }` - The predicate will +/// always evaluate to TRUE (never FALSE or NULL) /// -/// It does not evaluate user-defined functions. 
-pub(super) fn const_eval_predicate( +/// * `NullableInterval::NotNull { values: Interval::CERTAINLY_FALSE }` - The predicate will +/// always evaluate to FALSE (never TRUE or NULL) +/// +/// * `NullableInterval::NotNull { values: Interval::UNCERTAIN }` - The predicate will never +/// evaluate to NULL, but may be either TRUE or FALSE +/// +/// * `NullableInterval::Null { datatype: DataType::Boolean }` - The predicate will always +/// evaluate to NULL (SQL UNKNOWN in three-valued logic) +/// +/// * `NullableInterval::MaybeNull { values: Interval::CERTAINLY_TRUE }` - The predicate may +/// evaluate to TRUE or NULL, but never FALSE +/// +/// * `NullableInterval::MaybeNull { values: Interval::CERTAINLY_FALSE }` - The predicate may +/// evaluate to FALSE or NULL, but never TRUE +/// +/// * `NullableInterval::MaybeNull { values: Interval::UNCERTAIN }` - The predicate may +/// evaluate to any of TRUE, FALSE, or NULL +/// +pub(super) fn evaluate_bounds( predicate: &Expr, - evaluates_to_null: F, + is_null: F, input_schema: &dyn ExprSchema, -) -> Option +) -> NullableInterval where F: Fn(&Expr) -> Option, { - let evaluator = PredicateConstEvaluator { + let evaluator = PredicateBoundsEvaluator { input_schema, - evaluates_to_null, + is_null, }; - let possible_results = evaluator.eval_predicate(predicate); + let possible_results = evaluator.evaluate_bounds(predicate); - if !possible_results.is_empty() { - if possible_results == TriBoolSet::TRUE { - // Provably true - return Some(true); - } else if !possible_results.contains(TriBoolSet::TRUE) { - // Provably not true - return Some(false); + if possible_results.is_empty() || possible_results == TernarySet::all() { + NullableInterval::MaybeNull { + values: Interval::UNCERTAIN, + } + } else if possible_results == TernarySet::TRUE { + NullableInterval::NotNull { + values: Interval::CERTAINLY_TRUE, + } + } else if possible_results == TernarySet::FALSE { + NullableInterval::NotNull { + values: Interval::CERTAINLY_FALSE, + } + } else 
if possible_results == TernarySet::UNKNOWN { + NullableInterval::Null { + datatype: DataType::Boolean, + } + } else { + let t = possible_results.contains(TernarySet::TRUE); + let f = possible_results.contains(TernarySet::FALSE); + let values = if t && f { + Interval::UNCERTAIN + } else if t { + Interval::CERTAINLY_TRUE + } else { + Interval::CERTAINLY_FALSE + }; + + if possible_results.contains(TernarySet::UNKNOWN) { + NullableInterval::MaybeNull { values } + } else { + NullableInterval::NotNull { values } } } - - None } -pub(super) struct PredicateConstEvaluator<'a, F> { +pub(super) struct PredicateBoundsEvaluator<'a, F> { input_schema: &'a dyn ExprSchema, - evaluates_to_null: F, + is_null: F, } -impl PredicateConstEvaluator<'_, F> +impl PredicateBoundsEvaluator<'_, F> where F: Fn(&Expr) -> Option, { - /// Attempts to const evaluate a boolean predicate. - fn eval_predicate(&self, predicate: &Expr) -> TriBoolSet { + /// Derives the bounds of the given boolean expression + fn evaluate_bounds(&self, predicate: &Expr) -> TernarySet { match predicate { Expr::Literal(scalar, _) => { // Interpret literals as boolean, coercing if necessary - TriBoolSet::try_from(scalar) + TernarySet::try_from(scalar) } - Expr::Negative(e) => self.eval_predicate(e), + Expr::Negative(e) => self.evaluate_bounds(e), Expr::IsNull(e) => { // If `e` is not nullable, then `e IS NULL` is provably false if let Ok(false) = e.nullable(self.input_schema) { - return TriBoolSet::FALSE; + return TernarySet::FALSE; } match e.get_type(self.input_schema) { // If `e` is a boolean expression, try to evaluate it and test for unknown - Ok(DataType::Boolean) => self.eval_predicate(e).is_unknown(), + Ok(DataType::Boolean) => self.evaluate_bounds(e).is_unknown(), // If `e` is not a boolean expression, check if `e` is provably null Ok(_) => self.is_null(e), - Err(_) => TriBoolSet::empty(), + Err(_) => TernarySet::empty(), } } Expr::IsNotNull(e) => { // If `e` is not nullable, then `e IS NOT NULL` is provably true 
if let Ok(false) = e.nullable(self.input_schema) { - return TriBoolSet::TRUE; + return TernarySet::TRUE; } match e.get_type(self.input_schema) { // If `e` is a boolean expression, try to evaluate it and test for not unknown Ok(DataType::Boolean) => { - TriBoolSet::not(self.eval_predicate(e).is_unknown()) + TernarySet::not(self.evaluate_bounds(e).is_unknown()) } // If `e` is not a boolean expression, check if `e` is provably null - Ok(_) => TriBoolSet::not(self.is_null(e)), - Err(_) => TriBoolSet::empty(), + Ok(_) => TernarySet::not(self.is_null(e)), + Err(_) => TernarySet::empty(), } } - Expr::IsTrue(e) => self.eval_predicate(e).is_true(), - Expr::IsNotTrue(e) => TriBoolSet::not(self.eval_predicate(e).is_true()), - Expr::IsFalse(e) => self.eval_predicate(e).is_false(), - Expr::IsNotFalse(e) => TriBoolSet::not(self.eval_predicate(e).is_false()), - Expr::IsUnknown(e) => self.eval_predicate(e).is_unknown(), - Expr::IsNotUnknown(e) => TriBoolSet::not(self.eval_predicate(e).is_unknown()), - Expr::Not(e) => TriBoolSet::not(self.eval_predicate(e)), + Expr::IsTrue(e) => self.evaluate_bounds(e).is_true(), + Expr::IsNotTrue(e) => TernarySet::not(self.evaluate_bounds(e).is_true()), + Expr::IsFalse(e) => self.evaluate_bounds(e).is_false(), + Expr::IsNotFalse(e) => TernarySet::not(self.evaluate_bounds(e).is_false()), + Expr::IsUnknown(e) => self.evaluate_bounds(e).is_unknown(), + Expr::IsNotUnknown(e) => { + TernarySet::not(self.evaluate_bounds(e).is_unknown()) + } + Expr::Not(e) => TernarySet::not(self.evaluate_bounds(e)), Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right, - }) => TriBoolSet::and(self.eval_predicate(left), self.eval_predicate(right)), + }) => { + TernarySet::and(self.evaluate_bounds(left), self.evaluate_bounds(right)) + } Expr::BinaryExpr(BinaryExpr { left, op: Operator::Or, right, - }) => TriBoolSet::or(self.eval_predicate(left), self.eval_predicate(right)), + }) => TernarySet::or(self.evaluate_bounds(left), self.evaluate_bounds(right)), e => { - 
let mut result = TriBoolSet::empty(); + let mut result = TernarySet::empty(); let is_null = self.is_null(e); // If an expression is null, then it's value is UNKNOWN - if is_null.contains(TriBoolSet::TRUE) { - result |= TriBoolSet::UNKNOWN + if is_null.contains(TernarySet::TRUE) { + result |= TernarySet::UNKNOWN } // If an expression is not null, then it's either TRUE or FALSE - if is_null.contains(TriBoolSet::FALSE) { - result |= TriBoolSet::TRUE | TriBoolSet::FALSE + if is_null.contains(TernarySet::FALSE) { + result |= TernarySet::TRUE | TernarySet::FALSE } result @@ -324,31 +372,33 @@ where /// Determines if the given expression can evaluate to `NULL`. /// /// This method only returns sets containing `TRUE`, `FALSE`, or both. - fn is_null(&self, expr: &Expr) -> TriBoolSet { + fn is_null(&self, expr: &Expr) -> TernarySet { + // Fast path for literals + if let Expr::Literal(scalar, _) = expr { + if scalar.is_null() { + return TernarySet::TRUE; + } else { + return TernarySet::FALSE; + } + } + // If `expr` is not nullable, we can be certain `expr` is not null if let Ok(false) = expr.nullable(self.input_schema) { - return TriBoolSet::FALSE; + return TernarySet::FALSE; } // Check if the callback can decide for us - if let Some(is_null) = (self.evaluates_to_null)(expr) { - return if is_null { - TriBoolSet::TRUE + if let Some(expr_is_null) = (self.is_null)(expr) { + return if expr_is_null { + TernarySet::TRUE } else { - TriBoolSet::FALSE + TernarySet::FALSE }; } - // `expr` is nullable, so our default answer is { TRUE, FALSE }. - // Try to see if we can narrow it down to one of the two. + // `expr` is nullable, so our default answer for `is null` is going to be `{ TRUE, FALSE }`. + // Try to see if we can narrow it down to just one option. 
match expr { - Expr::Literal(s, _) => { - if s.is_null() { - TriBoolSet::TRUE - } else { - TriBoolSet::FALSE - } - } Expr::Alias(_) | Expr::Between(_) | Expr::BinaryExpr(_) @@ -359,19 +409,19 @@ where | Expr::SimilarTo(_) => { // These expressions are null if any of their direct children is null // If any child is inconclusive, the result for this expression is also inconclusive - let mut is_null = TriBoolSet::FALSE.clone(); + let mut is_null = TernarySet::FALSE.clone(); let _ = expr.apply_children(|child| { let child_is_null = self.is_null(child); - if child_is_null.contains(TriBoolSet::TRUE) { + if child_is_null.contains(TernarySet::TRUE) { // If a child might be null, then the result may also be null - is_null.insert(TriBoolSet::TRUE); + is_null.insert(TernarySet::TRUE); } - if !child_is_null.contains(TriBoolSet::FALSE) { + if !child_is_null.contains(TernarySet::FALSE) { // If the child is never not null, then the result can also never be not null // and we can stop traversing the children - is_null.remove(TriBoolSet::FALSE); + is_null.remove(TernarySet::FALSE); Ok(TreeNodeRecursion::Stop) } else { Ok(TreeNodeRecursion::Continue) @@ -379,7 +429,7 @@ where }); is_null } - _ => TriBoolSet::TRUE | TriBoolSet::FALSE, + _ => TernarySet::TRUE | TernarySet::FALSE, } } } @@ -387,13 +437,13 @@ where #[cfg(test)] mod tests { use crate::expr::ScalarFunction; - use crate::predicate_eval::{const_eval_predicate, TriBoolSet}; + use crate::predicate_bounds::{evaluate_bounds, TernarySet}; use crate::{ binary_expr, col, create_udf, is_false, is_not_false, is_not_null, is_not_true, is_not_unknown, is_null, is_true, is_unknown, lit, not, Expr, }; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{DFSchema, ScalarValue}; + use datafusion_common::{DFSchema, ExprSchema, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::operator::Operator::{And, Eq, Or}; use datafusion_expr_common::signature::Volatility; @@ 
-402,50 +452,50 @@ mod tests { #[test] fn tristate_bool_from_scalar() { let cases = vec![ - (ScalarValue::Null, TriBoolSet::UNKNOWN), - (ScalarValue::Boolean(None), TriBoolSet::UNKNOWN), - (ScalarValue::Boolean(Some(true)), TriBoolSet::TRUE), - (ScalarValue::Boolean(Some(false)), TriBoolSet::FALSE), - (ScalarValue::UInt8(None), TriBoolSet::UNKNOWN), - (ScalarValue::UInt8(Some(0)), TriBoolSet::FALSE), - (ScalarValue::UInt8(Some(1)), TriBoolSet::TRUE), + (ScalarValue::Null, TernarySet::UNKNOWN), + (ScalarValue::Boolean(None), TernarySet::UNKNOWN), + (ScalarValue::Boolean(Some(true)), TernarySet::TRUE), + (ScalarValue::Boolean(Some(false)), TernarySet::FALSE), + (ScalarValue::UInt8(None), TernarySet::UNKNOWN), + (ScalarValue::UInt8(Some(0)), TernarySet::FALSE), + (ScalarValue::UInt8(Some(1)), TernarySet::TRUE), ( ScalarValue::Utf8(Some("abc".to_string())), - TriBoolSet::empty(), + TernarySet::empty(), ), ]; for case in cases { - assert_eq!(TriBoolSet::try_from(&case.0), case.1); + assert_eq!(TernarySet::try_from(&case.0), case.1); } } #[test] fn tristate_bool_not() { let cases = vec![ - (TriBoolSet::UNKNOWN, TriBoolSet::UNKNOWN), - (TriBoolSet::TRUE, TriBoolSet::FALSE), - (TriBoolSet::FALSE, TriBoolSet::TRUE), + (TernarySet::UNKNOWN, TernarySet::UNKNOWN), + (TernarySet::TRUE, TernarySet::FALSE), + (TernarySet::FALSE, TernarySet::TRUE), ( - TriBoolSet::TRUE | TriBoolSet::FALSE, - TriBoolSet::TRUE | TriBoolSet::FALSE, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE, ), ( - TriBoolSet::TRUE | TriBoolSet::UNKNOWN, - TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::TRUE | TernarySet::UNKNOWN, + TernarySet::FALSE | TernarySet::UNKNOWN, ), ( - TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::UNKNOWN, + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::UNKNOWN, ), ( - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + 
TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, ), ]; for case in cases { - assert_eq!(TriBoolSet::not(case.0), case.1); + assert_eq!(TernarySet::not(case.0), case.1); } } @@ -453,65 +503,65 @@ mod tests { fn tristate_bool_and() { let cases = vec![ ( - TriBoolSet::UNKNOWN, - TriBoolSet::UNKNOWN, - TriBoolSet::UNKNOWN, + TernarySet::UNKNOWN, + TernarySet::UNKNOWN, + TernarySet::UNKNOWN, ), - (TriBoolSet::UNKNOWN, TriBoolSet::TRUE, TriBoolSet::UNKNOWN), - (TriBoolSet::UNKNOWN, TriBoolSet::FALSE, TriBoolSet::FALSE), - (TriBoolSet::TRUE, TriBoolSet::TRUE, TriBoolSet::TRUE), - (TriBoolSet::TRUE, TriBoolSet::FALSE, TriBoolSet::FALSE), - (TriBoolSet::FALSE, TriBoolSet::FALSE, TriBoolSet::FALSE), + (TernarySet::UNKNOWN, TernarySet::TRUE, TernarySet::UNKNOWN), + (TernarySet::UNKNOWN, TernarySet::FALSE, TernarySet::FALSE), + (TernarySet::TRUE, TernarySet::TRUE, TernarySet::TRUE), + (TernarySet::TRUE, TernarySet::FALSE, TernarySet::FALSE), + (TernarySet::FALSE, TernarySet::FALSE, TernarySet::FALSE), ( - TriBoolSet::TRUE | TriBoolSet::FALSE, - TriBoolSet::FALSE, - TriBoolSet::FALSE, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::FALSE, + TernarySet::FALSE, ), ( - TriBoolSet::TRUE | TriBoolSet::FALSE, - TriBoolSet::TRUE, - TriBoolSet::TRUE | TriBoolSet::FALSE, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE, + TernarySet::TRUE | TernarySet::FALSE, ), ( - TriBoolSet::TRUE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE, - TriBoolSet::TRUE | TriBoolSet::UNKNOWN, + TernarySet::TRUE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE | TernarySet::UNKNOWN, ), ( - TriBoolSet::TRUE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::TRUE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, ), ( - TriBoolSet::FALSE | 
TriBoolSet::UNKNOWN, - TriBoolSet::TRUE, - TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::FALSE | TernarySet::UNKNOWN, ), ( - TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE, - TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::FALSE | TernarySet::UNKNOWN, ), ( - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, ), ( - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, ), ( - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, ), ]; for case in cases { assert_eq!( - TriBoolSet::and(case.0.clone(), case.1.clone()), + TernarySet::and(case.0.clone(), case.1.clone()), case.2.clone(), "{:?} & {:?} = {:?}", case.0.clone(), @@ -519,7 +569,7 @@ mod tests { case.2.clone() ); assert_eq!( - TriBoolSet::and(case.1.clone(), case.0.clone()), + TernarySet::and(case.1.clone(), case.0.clone()), case.2.clone(), "{:?} & {:?} = {:?}", case.1.clone(), @@ -533,55 +583,55 @@ mod tests { fn tristate_bool_or() { let cases = vec![ ( - TriBoolSet::UNKNOWN, - TriBoolSet::UNKNOWN, - 
TriBoolSet::UNKNOWN, + TernarySet::UNKNOWN, + TernarySet::UNKNOWN, + TernarySet::UNKNOWN, ), - (TriBoolSet::UNKNOWN, TriBoolSet::TRUE, TriBoolSet::TRUE), - (TriBoolSet::UNKNOWN, TriBoolSet::FALSE, TriBoolSet::UNKNOWN), - (TriBoolSet::TRUE, TriBoolSet::TRUE, TriBoolSet::TRUE), - (TriBoolSet::TRUE, TriBoolSet::FALSE, TriBoolSet::TRUE), - (TriBoolSet::FALSE, TriBoolSet::FALSE, TriBoolSet::FALSE), + (TernarySet::UNKNOWN, TernarySet::TRUE, TernarySet::TRUE), + (TernarySet::UNKNOWN, TernarySet::FALSE, TernarySet::UNKNOWN), + (TernarySet::TRUE, TernarySet::TRUE, TernarySet::TRUE), + (TernarySet::TRUE, TernarySet::FALSE, TernarySet::TRUE), + (TernarySet::FALSE, TernarySet::FALSE, TernarySet::FALSE), ( - TriBoolSet::TRUE | TriBoolSet::FALSE, - TriBoolSet::FALSE, - TriBoolSet::TRUE | TriBoolSet::FALSE, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE, ), ( - TriBoolSet::TRUE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE, - TriBoolSet::TRUE, + TernarySet::TRUE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE, ), ( - TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE, - TriBoolSet::TRUE, + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE, ), ( - TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, ), ( - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE, - TriBoolSet::TRUE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE, ), ( - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE 
| TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, ), ( - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, - TriBoolSet::TRUE | TriBoolSet::FALSE | TriBoolSet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, ), ]; for case in cases { assert_eq!( - TriBoolSet::or(case.0.clone(), case.1.clone()), + TernarySet::or(case.0.clone(), case.1.clone()), case.2.clone(), "{:?} | {:?} = {:?}", case.0.clone(), @@ -589,7 +639,7 @@ mod tests { case.2.clone() ); assert_eq!( - TriBoolSet::or(case.1.clone(), case.0.clone()), + TernarySet::or(case.1.clone(), case.0.clone()), case.2.clone(), "{:?} | {:?} = {:?}", case.1.clone(), @@ -599,6 +649,25 @@ mod tests { } } + fn const_eval_predicate( + predicate: &Expr, + evaluates_to_null: F, + input_schema: &dyn ExprSchema, + ) -> Option + where + F: Fn(&Expr) -> Option, + { + let bounds = evaluate_bounds(predicate, evaluates_to_null, input_schema); + + if bounds.is_certainly_true() { + Some(true) + } else if bounds.is_certainly_not_true() { + Some(false) + } else { + None + } + } + fn const_eval(predicate: &Expr) -> Option { let schema = DFSchema::try_from(Schema::empty()).unwrap(); const_eval_predicate(predicate, |_| None, &schema) From ac765e9bc0e290ed2ef12237bddf537c608c3fc3 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sun, 9 Nov 2025 00:19:51 +0100 Subject: [PATCH 18/23] Add unit tests for NullableInterval::is_certainly_... 
--- .../expr-common/src/interval_arithmetic.rs | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 61af080fe514..1e670d2de60a 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -1980,6 +1980,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::rounding::{next_down, next_up}; use datafusion_common::{Result, ScalarValue}; + use crate::interval_arithmetic::NullableInterval; #[test] fn test_next_prev_value() -> Result<()> { @@ -4113,4 +4114,67 @@ mod tests { Ok(()) } + + #[test] + fn test_is_certainly_true() { + let test_cases = vec![ + (NullableInterval::Null { datatype: DataType::Boolean }, false), + (NullableInterval::MaybeNull { values: Interval::CERTAINLY_TRUE }, false), + (NullableInterval::MaybeNull { values: Interval::CERTAINLY_FALSE }, false), + (NullableInterval::MaybeNull { values: Interval::UNCERTAIN }, false), + (NullableInterval::NotNull { values: Interval::CERTAINLY_TRUE }, true), + (NullableInterval::NotNull { values: Interval::CERTAINLY_FALSE }, false), + (NullableInterval::NotNull { values: Interval::UNCERTAIN }, false), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_certainly_true(); + assert_eq!( + result, expected, + "Failed for interval: {interval}", + ); + } + } + + #[test] + fn test_is_certainly_not_true() { + let test_cases = vec![ + (NullableInterval::Null { datatype: DataType::Boolean }, true), + (NullableInterval::MaybeNull { values: Interval::CERTAINLY_TRUE }, false), + (NullableInterval::MaybeNull { values: Interval::CERTAINLY_FALSE }, true), + (NullableInterval::MaybeNull { values: Interval::UNCERTAIN }, false), + (NullableInterval::NotNull { values: Interval::CERTAINLY_TRUE }, false), + (NullableInterval::NotNull { values: Interval::CERTAINLY_FALSE }, true), + (NullableInterval::NotNull { 
values: Interval::UNCERTAIN }, false), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_certainly_not_true(); + assert_eq!( + result, expected, + "Failed for interval: {interval}", + ); + } + } + + #[test] + fn test_is_certainly_false() { + let test_cases = vec![ + (NullableInterval::Null { datatype: DataType::Boolean }, false), + (NullableInterval::MaybeNull { values: Interval::CERTAINLY_TRUE }, false), + (NullableInterval::MaybeNull { values: Interval::CERTAINLY_FALSE }, false), + (NullableInterval::MaybeNull { values: Interval::UNCERTAIN }, false), + (NullableInterval::NotNull { values: Interval::CERTAINLY_TRUE }, false), + (NullableInterval::NotNull { values: Interval::CERTAINLY_FALSE }, true), + (NullableInterval::NotNull { values: Interval::UNCERTAIN }, false), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_certainly_false(); + assert_eq!( + result, expected, + "Failed for interval: {interval}", + ); + } + } } From 51af7499f5b59e7bff1c90eb7cad8757c66e72f7 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sun, 9 Nov 2025 00:49:56 +0100 Subject: [PATCH 19/23] Formatting --- .../expr-common/src/interval_arithmetic.rs | 164 ++++++++++++++---- 1 file changed, 130 insertions(+), 34 deletions(-) diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 1e670d2de60a..478e95520f95 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -1977,10 +1977,10 @@ mod tests { operator::Operator, }; + use crate::interval_arithmetic::NullableInterval; use arrow::datatypes::DataType; use datafusion_common::rounding::{next_down, next_up}; use datafusion_common::{Result, ScalarValue}; - use crate::interval_arithmetic::NullableInterval; #[test] fn test_next_prev_value() -> Result<()> { @@ -4118,63 +4118,159 @@ mod tests { #[test] fn test_is_certainly_true() { let test_cases = vec![ - 
(NullableInterval::Null { datatype: DataType::Boolean }, false), - (NullableInterval::MaybeNull { values: Interval::CERTAINLY_TRUE }, false), - (NullableInterval::MaybeNull { values: Interval::CERTAINLY_FALSE }, false), - (NullableInterval::MaybeNull { values: Interval::UNCERTAIN }, false), - (NullableInterval::NotNull { values: Interval::CERTAINLY_TRUE }, true), - (NullableInterval::NotNull { values: Interval::CERTAINLY_FALSE }, false), - (NullableInterval::NotNull { values: Interval::UNCERTAIN }, false), + ( + NullableInterval::Null { + datatype: DataType::Boolean, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_FALSE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::UNCERTAIN, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_TRUE, + }, + true, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_FALSE, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::UNCERTAIN, + }, + false, + ), ]; for (interval, expected) in test_cases { let result = interval.is_certainly_true(); - assert_eq!( - result, expected, - "Failed for interval: {interval}", - ); + assert_eq!(result, expected, "Failed for interval: {interval}",); } } #[test] fn test_is_certainly_not_true() { let test_cases = vec![ - (NullableInterval::Null { datatype: DataType::Boolean }, true), - (NullableInterval::MaybeNull { values: Interval::CERTAINLY_TRUE }, false), - (NullableInterval::MaybeNull { values: Interval::CERTAINLY_FALSE }, true), - (NullableInterval::MaybeNull { values: Interval::UNCERTAIN }, false), - (NullableInterval::NotNull { values: Interval::CERTAINLY_TRUE }, false), - (NullableInterval::NotNull { values: Interval::CERTAINLY_FALSE }, true), - (NullableInterval::NotNull { values: Interval::UNCERTAIN }, false), + ( + NullableInterval::Null { + datatype: DataType::Boolean, + 
}, + true, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_FALSE, + }, + true, + ), + ( + NullableInterval::MaybeNull { + values: Interval::UNCERTAIN, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_FALSE, + }, + true, + ), + ( + NullableInterval::NotNull { + values: Interval::UNCERTAIN, + }, + false, + ), ]; for (interval, expected) in test_cases { let result = interval.is_certainly_not_true(); - assert_eq!( - result, expected, - "Failed for interval: {interval}", - ); + assert_eq!(result, expected, "Failed for interval: {interval}",); } } #[test] fn test_is_certainly_false() { let test_cases = vec![ - (NullableInterval::Null { datatype: DataType::Boolean }, false), - (NullableInterval::MaybeNull { values: Interval::CERTAINLY_TRUE }, false), - (NullableInterval::MaybeNull { values: Interval::CERTAINLY_FALSE }, false), - (NullableInterval::MaybeNull { values: Interval::UNCERTAIN }, false), - (NullableInterval::NotNull { values: Interval::CERTAINLY_TRUE }, false), - (NullableInterval::NotNull { values: Interval::CERTAINLY_FALSE }, true), - (NullableInterval::NotNull { values: Interval::UNCERTAIN }, false), + ( + NullableInterval::Null { + datatype: DataType::Boolean, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_FALSE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::UNCERTAIN, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_FALSE, + }, + true, + ), + ( + NullableInterval::NotNull { + values: Interval::UNCERTAIN, + }, + false, + ), ]; for (interval, 
expected) in test_cases { let result = interval.is_certainly_false(); - assert_eq!( - result, expected, - "Failed for interval: {interval}", - ); + assert_eq!(result, expected, "Failed for interval: {interval}",); } } } From 4af84a7ff1738537954d67ebfa3b0775d0d9e99e Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Sun, 9 Nov 2025 11:07:08 +0100 Subject: [PATCH 20/23] Simplify logical and physical case branch filtering logic --- datafusion/expr/src/expr_schema.rs | 100 ++++++------ .../physical-expr/src/expressions/case.rs | 143 +++++++++--------- 2 files changed, 113 insertions(+), 130 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 0fc475660b8d..5e716c1a64d3 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -282,63 +282,51 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()), Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { - let nullable_then = if case.expr.is_some() { - // Case-with-expression is nullable if any of the 'then' expressions. - // Assume all 'then' expressions are reachable - case.when_then_expr - .iter() - .filter_map(|(_, t)| match t.nullable(input_schema) { - Ok(n) => { - if n { - Some(Ok(())) - } else { - None - } - } - Err(e) => Some(Err(e)), - }) - .next() - } else { - // case-without-expression is nullable if any of the 'then' expressions is nullable - // and reachable when the 'then' expression evaluates to `null`. - case.when_then_expr - .iter() - .filter_map(|(w, t)| { - match t.nullable(input_schema) { - // Branches with a then expression that is not nullable can be skipped - Ok(false) => None, - // Pass on error determining nullability verbatim - Err(e) => Some(Err(e)), - // For branches with a nullable 'then' expression, try to determine - // using limited const evaluation if the branch will be taken when - // the 'then' expression evaluates to null. 
- Ok(true) => { - let is_null = |expr: &Expr /* Type */| { - if expr.eq(t) { - Some(true) - } else { - None - } - }; - - let bounds = predicate_bounds::evaluate_bounds( - w, - is_null, - input_schema, - ); - if bounds.is_certainly_not_true() { - // The branch will certainly never be taken. - // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. - None - } else { - // The branch might be taken - Some(Ok(())) - } - } + let nullable_then = case + .when_then_expr + .iter() + .filter_map(|(w, t)| { + let is_nullable = match t.nullable(input_schema) { + Err(e) => return Some(Err(e)), + Ok(n) => n, + }; + + // Branches with a then expression that is not nullable do not impact the + // nullability of the case expression. + if !is_nullable { + return None; + } + + // For case-with-expression assume all 'then' expressions are reachable + if case.expr.is_some() { + return Some(Ok(())); + } + + // For branches with a nullable 'then' expression, try to determine + // if the 'then' expression is ever reachable in the situation where + // it would evaluate to null. + let is_null = |expr: &Expr /* Type */| { + if expr.eq(t) { + Some(true) + } else { + None } - }) - .next() - }; + }; + + let bounds = + predicate_bounds::evaluate_bounds(w, is_null, input_schema); + + if bounds.is_certainly_not_true() { + // The predicate will never evaluate to true, so the 'then' expression + // is never reachable. + // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. 
+ None + } else { + // The branch might be taken + Some(Ok(())) + } + }) + .next(); if let Some(nullable_then) = nullable_then { // There is at least one reachable nullable then diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 845242e92102..0d090d6b8001 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1283,57 +1283,53 @@ impl PhysicalExpr for CaseExpr { } fn nullable(&self, input_schema: &Schema) -> Result { - let nullable_then = if self.body.expr.is_some() { - // Case-with-expression is nullable if any of the 'then' expressions. - // Assume all 'then' expressions are reachable - self.body - .when_then_expr - .iter() - .filter_map(|(_, t)| match t.nullable(input_schema) { - Ok(n) => { - if n { - Some(Ok(())) - } else { - None - } - } - Err(e) => Some(Err(e)), - }) - .next() - } else { - // case-without-expression is nullable if any of the 'then' expressions is nullable - // and reachable when the 'then' expression evaluates to `null`. - self.body - .when_then_expr - .iter() - .filter_map(|(w, t)| { - match t.nullable(input_schema) { - // Branches with a then expression that is not nullable can be skipped - Ok(false) => None, - // Pass on error determining nullability verbatim - Err(e) => Some(Err(e)), - Ok(true) => { - // For branches with a nullable 'then' expression, try to determine - // using const evaluation if the branch will be taken when - // the 'then' expression evaluates to null. - let is_null = |expr: &dyn PhysicalExpr /* Type */| { - expr.dyn_eq(t.as_ref()) - }; - - match const_eval_predicate(w, is_null, input_schema) { - // Const evaluation was inconclusive or determined the branch - // would be taken - Ok(None) | Ok(Some(true)) => Some(Ok(())), - // Const evaluation proves the branch will never be taken. - // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. 
- Ok(Some(false)) => None, - Err(e) => Some(Err(e)), - } - } - } - }) - .next() - }; + let nullable_then = self + .body + .when_then_expr + .iter() + .filter_map(|(w, t)| { + let is_nullable = match t.nullable(input_schema) { + // Pass on error determining nullability verbatim + Err(e) => return Some(Err(e)), + Ok(n) => n, + }; + + // Branches with a then expression that is not nullable do not impact the + // nullability of the case expression. + if !is_nullable { + return None; + } + + // For case-with-expression assume all 'then' expressions are reachable + if self.body.expr.is_some() { + return Some(Ok(())); + } + + // For branches with a nullable 'then' expression, try to determine + // if the 'then' expression is ever reachable in the situation where + // it would evaluate to null. + + // Replace the `then` expression with `NULL` in the `when` expression + let with_null = match replace_with_null(w, t.as_ref(), input_schema) { + Err(e) => return Some(Err(e)), + Ok(e) => e, + }; + + // Try to const evaluate the modified `when` expression. + let predicate_result = match evaluate_predicate(&with_null) { + Err(e) => return Some(Err(e)), + Ok(b) => b, + }; + + match predicate_result { + // Evaluation was inconclusive or true, so the 'then' expression is reachable + None | Some(true) => Some(Ok(())), + // Evaluation proves the branch will never be taken. + // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. + Some(false) => None, + } + }) + .next(); if let Some(nullable_then) = nullable_then { // There is at least one reachable nullable then @@ -1441,32 +1437,12 @@ impl PhysicalExpr for CaseExpr { } } -/// Attempts to const evaluate the given `predicate` with the assumption that `value` evaluates to `NULL`. +/// Attempts to const evaluate the given `predicate`. /// Returns: /// - `Some(true)` if the predicate evaluates to a truthy value. /// - `Some(false)` if the predicate evaluates to a falsy value. 
/// - `None` if the predicate could not be evaluated. -fn const_eval_predicate( - predicate: &Arc, - evaluates_to_null: F, - input_schema: &Schema, -) -> Result> -where - F: Fn(&dyn PhysicalExpr) -> bool, -{ - // Replace `value` with `NULL` in `predicate` - let with_null = Arc::clone(predicate) - .transform_down(|e| { - if evaluates_to_null(e.as_ref()) { - let data_type = e.data_type(input_schema)?; - let null_literal = lit(ScalarValue::try_new_null(&data_type)?); - Ok(Transformed::yes(null_literal)) - } else { - Ok(Transformed::no(e)) - } - })? - .data; - +fn evaluate_predicate(predicate: &Arc) -> Result> { // Create a dummy record with no columns and one row let batch = RecordBatch::try_new_with_options( Arc::new(Schema::empty()), @@ -1475,7 +1451,7 @@ where )?; // Evaluate the predicate and interpret the result as a boolean - let result = match with_null.evaluate(&batch) { + let result = match predicate.evaluate(&batch) { // An error during evaluation means we couldn't const evaluate the predicate, so return `None` Err(_) => None, Ok(ColumnarValue::Array(array)) => Some( @@ -1487,6 +1463,25 @@ where Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true))))) } +fn replace_with_null( + expr: &Arc, + expr_to_replace: &dyn PhysicalExpr, + input_schema: &Schema, +) -> Result, DataFusionError> { + let with_null = Arc::clone(expr) + .transform_down(|e| { + if e.as_ref().dyn_eq(expr_to_replace) { + let data_type = e.data_type(input_schema)?; + let null_literal = lit(ScalarValue::try_new_null(&data_type)?); + Ok(Transformed::yes(null_literal)) + } else { + Ok(Transformed::no(e)) + } + })? 
+ .data; + Ok(with_null) +} + /// Create a CASE expression pub fn case( expr: Option>, From 427fc30fed1bfdfd49f5098a3e1ba955e5452704 Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 10 Nov 2025 19:20:09 +0100 Subject: [PATCH 21/23] Further simplification of `is_null` --- datafusion/expr/src/expr_schema.rs | 16 +- datafusion/expr/src/predicate_bounds.rs | 493 ++++++++++---------- datafusion/sqllogictest/test_files/case.slt | 14 + 3 files changed, 273 insertions(+), 250 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 5e716c1a64d3..9b471d17ae25 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -305,17 +305,15 @@ impl ExprSchemable for Expr { // For branches with a nullable 'then' expression, try to determine // if the 'then' expression is ever reachable in the situation where // it would evaluate to null. - let is_null = |expr: &Expr /* Type */| { - if expr.eq(t) { - Some(true) - } else { - None - } + let bounds = match predicate_bounds::evaluate_bounds( + w, + Some(t), + input_schema, + ) { + Err(e) => return Some(Err(e)), + Ok(b) => b, }; - let bounds = - predicate_bounds::evaluate_bounds(w, is_null, input_schema); - if bounds.is_certainly_not_true() { // The predicate will never evaluate to true, so the 'then' expression // is never reachable. 
diff --git a/datafusion/expr/src/predicate_bounds.rs b/datafusion/expr/src/predicate_bounds.rs index da2927a8d720..f4cbc53f8ec1 100644 --- a/datafusion/expr/src/predicate_bounds.rs +++ b/datafusion/expr/src/predicate_bounds.rs @@ -19,7 +19,7 @@ use crate::{BinaryExpr, Expr, ExprSchemable}; use arrow::datatypes::DataType; use bitflags::bitflags; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{ExprSchema, ScalarValue}; +use datafusion_common::{DataFusionError, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr_common::operator::Operator; @@ -34,24 +34,6 @@ bitflags! { } impl TernarySet { - fn try_from(value: &ScalarValue) -> TernarySet { - match value { - ScalarValue::Null => TernarySet::UNKNOWN, - ScalarValue::Boolean(b) => match b { - Some(true) => TernarySet::TRUE, - Some(false) => TernarySet::FALSE, - None => TernarySet::UNKNOWN, - }, - _ => { - if let Ok(b) = value.cast_to(&DataType::Boolean) { - Self::try_from(&b) - } else { - TernarySet::empty() - } - } - } - } - /// Returns the set of possible values after applying the `is true` test on all /// values in this set. /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. @@ -188,6 +170,25 @@ impl TernarySet { } } +impl TryFrom<&ScalarValue> for TernarySet { + type Error = DataFusionError; + + fn try_from(value: &ScalarValue) -> Result { + Ok(match value { + ScalarValue::Null => TernarySet::UNKNOWN, + ScalarValue::Boolean(b) => match b { + Some(true) => TernarySet::TRUE, + Some(false) => TernarySet::FALSE, + None => TernarySet::UNKNOWN, + }, + _ => { + let b = value.cast_to(&DataType::Boolean)?; + Self::try_from(&b)? + } + }) + } +} + /// Computes the output interval for the given boolean expression based on statically /// available information. 
/// @@ -233,21 +234,19 @@ impl TernarySet { /// * `NullableInterval::MaybeNull { values: Interval::UNCERTAIN }` - The predicate may /// evaluate to any of TRUE, FALSE, or NULL /// -pub(super) fn evaluate_bounds( +pub(super) fn evaluate_bounds( predicate: &Expr, - is_null: F, + certainly_null_expr: Option<&Expr>, input_schema: &dyn ExprSchema, -) -> NullableInterval -where - F: Fn(&Expr) -> Option, -{ +) -> Result { let evaluator = PredicateBoundsEvaluator { input_schema, - is_null, + certainly_null_expr: certainly_null_expr.map(unwrap_certainly_null_expr), }; - let possible_results = evaluator.evaluate_bounds(predicate); + let possible_results = evaluator.evaluate_bounds(predicate)?; - if possible_results.is_empty() || possible_results == TernarySet::all() { + let interval = if possible_results.is_empty() || possible_results == TernarySet::all() + { NullableInterval::MaybeNull { values: Interval::UNCERTAIN, } @@ -279,77 +278,85 @@ where } else { NullableInterval::NotNull { values } } + }; + + Ok(interval) +} + +/// Returns the innermost [Expr] that is provably null if `expr` is null. 
+fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr { + match expr { + Expr::Not(e) => unwrap_certainly_null_expr(e), + Expr::Negative(e) => unwrap_certainly_null_expr(e), + Expr::Cast(e) => unwrap_certainly_null_expr(e.expr.as_ref()), + _ => expr, } } -pub(super) struct PredicateBoundsEvaluator<'a, F> { +struct PredicateBoundsEvaluator<'a> { input_schema: &'a dyn ExprSchema, - is_null: F, + certainly_null_expr: Option<&'a Expr>, } -impl PredicateBoundsEvaluator<'_, F> -where - F: Fn(&Expr) -> Option, -{ +impl PredicateBoundsEvaluator<'_> { /// Derives the bounds of the given boolean expression - fn evaluate_bounds(&self, predicate: &Expr) -> TernarySet { - match predicate { + fn evaluate_bounds(&self, predicate: &Expr) -> Result { + Ok(match predicate { Expr::Literal(scalar, _) => { // Interpret literals as boolean, coercing if necessary - TernarySet::try_from(scalar) + TernarySet::try_from(scalar)? } - Expr::Negative(e) => self.evaluate_bounds(e), Expr::IsNull(e) => { // If `e` is not nullable, then `e IS NULL` is provably false - if let Ok(false) = e.nullable(self.input_schema) { - return TernarySet::FALSE; - } - - match e.get_type(self.input_schema) { - // If `e` is a boolean expression, try to evaluate it and test for unknown - Ok(DataType::Boolean) => self.evaluate_bounds(e).is_unknown(), - // If `e` is not a boolean expression, check if `e` is provably null - Ok(_) => self.is_null(e), - Err(_) => TernarySet::empty(), + if !e.nullable(self.input_schema)? { + TernarySet::FALSE + } else { + match e.get_type(self.input_schema)? { + // If `e` is a boolean expression, check if `e` is provably 'unknown'. 
+ DataType::Boolean => self.evaluate_bounds(e)?.is_unknown(), + // If `e` is not a boolean expression, check if `e` is provably null + _ => self.is_null(e), + } } } Expr::IsNotNull(e) => { // If `e` is not nullable, then `e IS NOT NULL` is provably true - if let Ok(false) = e.nullable(self.input_schema) { - return TernarySet::TRUE; - } - - match e.get_type(self.input_schema) { - // If `e` is a boolean expression, try to evaluate it and test for not unknown - Ok(DataType::Boolean) => { - TernarySet::not(self.evaluate_bounds(e).is_unknown()) + if !e.nullable(self.input_schema)? { + TernarySet::TRUE + } else { + match e.get_type(self.input_schema)? { + // If `e` is a boolean expression, try to evaluate it and test for not unknown + DataType::Boolean => { + TernarySet::not(self.evaluate_bounds(e)?.is_unknown()) + } + // If `e` is not a boolean expression, check if `e` is provably null + _ => TernarySet::not(self.is_null(e)), } - // If `e` is not a boolean expression, check if `e` is provably null - Ok(_) => TernarySet::not(self.is_null(e)), - Err(_) => TernarySet::empty(), } } - Expr::IsTrue(e) => self.evaluate_bounds(e).is_true(), - Expr::IsNotTrue(e) => TernarySet::not(self.evaluate_bounds(e).is_true()), - Expr::IsFalse(e) => self.evaluate_bounds(e).is_false(), - Expr::IsNotFalse(e) => TernarySet::not(self.evaluate_bounds(e).is_false()), - Expr::IsUnknown(e) => self.evaluate_bounds(e).is_unknown(), + Expr::IsTrue(e) => self.evaluate_bounds(e)?.is_true(), + Expr::IsNotTrue(e) => TernarySet::not(self.evaluate_bounds(e)?.is_true()), + Expr::IsFalse(e) => self.evaluate_bounds(e)?.is_false(), + Expr::IsNotFalse(e) => TernarySet::not(self.evaluate_bounds(e)?.is_false()), + Expr::IsUnknown(e) => self.evaluate_bounds(e)?.is_unknown(), Expr::IsNotUnknown(e) => { - TernarySet::not(self.evaluate_bounds(e).is_unknown()) + TernarySet::not(self.evaluate_bounds(e)?.is_unknown()) } - Expr::Not(e) => TernarySet::not(self.evaluate_bounds(e)), + Expr::Not(e) => 
TernarySet::not(self.evaluate_bounds(e)?), Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right, }) => { - TernarySet::and(self.evaluate_bounds(left), self.evaluate_bounds(right)) + TernarySet::and(self.evaluate_bounds(left)?, self.evaluate_bounds(right)?) } Expr::BinaryExpr(BinaryExpr { left, op: Operator::Or, right, - }) => TernarySet::or(self.evaluate_bounds(left), self.evaluate_bounds(right)), + }) => { + TernarySet::or(self.evaluate_bounds(left)?, self.evaluate_bounds(right)?) + } e => { let mut result = TernarySet::empty(); let is_null = self.is_null(e); @@ -366,7 +373,7 @@ where result } - } + }) } /// Determines if the given expression can evaluate to `NULL`. @@ -387,51 +394,52 @@ where return TernarySet::FALSE; } - // Check if the callback can decide for us - if let Some(expr_is_null) = (self.is_null)(expr) { - return if expr_is_null { - TernarySet::TRUE - } else { - TernarySet::FALSE - }; + // Check if the expression is the `certainly_null_expr` that was passed in. + if let Some(certainly_null_expr) = &self.certainly_null_expr { + if expr.eq(certainly_null_expr) { + return TernarySet::TRUE; + } } // `expr` is nullable, so our default answer for `is null` is going to be `{ TRUE, FALSE }`. // Try to see if we can narrow it down to just one option. match expr { + Expr::BinaryExpr(BinaryExpr { op, .. 
}) if op.returns_null_on_null() => { + self.is_null_if_any_child_null(expr) + } Expr::Alias(_) - | Expr::Between(_) - | Expr::BinaryExpr(_) | Expr::Cast(_) | Expr::Like(_) | Expr::Negative(_) | Expr::Not(_) - | Expr::SimilarTo(_) => { - // These expressions are null if any of their direct children is null - // If any child is inconclusive, the result for this expression is also inconclusive - let mut is_null = TernarySet::FALSE.clone(); - let _ = expr.apply_children(|child| { - let child_is_null = self.is_null(child); - - if child_is_null.contains(TernarySet::TRUE) { - // If a child might be null, then the result may also be null - is_null.insert(TernarySet::TRUE); - } - - if !child_is_null.contains(TernarySet::FALSE) { - // If the child is never not null, then the result can also never be not null - // and we can stop traversing the children - is_null.remove(TernarySet::FALSE); - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - }); - is_null - } + | Expr::SimilarTo(_) => self.is_null_if_any_child_null(expr), _ => TernarySet::TRUE | TernarySet::FALSE, } } + + fn is_null_if_any_child_null(&self, expr: &Expr) -> TernarySet { + // These expressions are null if any of their direct children is null + // If any child is inconclusive, the result for this expression is also inconclusive + let mut is_null = TernarySet::FALSE.clone(); + let _ = expr.apply_children(|child| { + let child_is_null = self.is_null(child); + + if child_is_null.contains(TernarySet::TRUE) { + // If a child might be null, then the result may also be null + is_null.insert(TernarySet::TRUE); + } + + if !child_is_null.contains(TernarySet::FALSE) { + // If the child is never not null, then the result can also never be not null + // and we can stop traversing the children + is_null.remove(TernarySet::FALSE); + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }); + is_null + } } #[cfg(test)] @@ -443,7 +451,7 @@ mod tests { is_not_unknown, is_null, 
is_true, is_unknown, lit, not, Expr, }; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{DFSchema, ExprSchema, ScalarValue}; + use datafusion_common::{DFSchema, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::operator::Operator::{And, Eq, Or}; use datafusion_expr_common::signature::Volatility; @@ -459,14 +467,16 @@ mod tests { (ScalarValue::UInt8(None), TernarySet::UNKNOWN), (ScalarValue::UInt8(Some(0)), TernarySet::FALSE), (ScalarValue::UInt8(Some(1)), TernarySet::TRUE), - ( - ScalarValue::Utf8(Some("abc".to_string())), - TernarySet::empty(), - ), ]; for case in cases { - assert_eq!(TernarySet::try_from(&case.0), case.1); + assert_eq!(TernarySet::try_from(&case.0).unwrap(), case.1); + } + + let error_cases = vec![ScalarValue::Utf8(Some("abc".to_string()))]; + + for case in error_cases { + assert!(TernarySet::try_from(&case).is_err()); } } @@ -649,64 +659,55 @@ mod tests { } } - fn const_eval_predicate( + fn try_eval_predicate_bounds( predicate: &Expr, - evaluates_to_null: F, + evaluates_to_null: Option<&Expr>, input_schema: &dyn ExprSchema, - ) -> Option - where - F: Fn(&Expr) -> Option, - { - let bounds = evaluate_bounds(predicate, evaluates_to_null, input_schema); + ) -> Result> { + let bounds = evaluate_bounds(predicate, evaluates_to_null, input_schema)?; - if bounds.is_certainly_true() { + Ok(if bounds.is_certainly_true() { Some(true) } else if bounds.is_certainly_not_true() { Some(false) } else { None - } - } - - fn const_eval(predicate: &Expr) -> Option { - let schema = DFSchema::try_from(Schema::empty()).unwrap(); - const_eval_predicate(predicate, |_| None, &schema) + }) } - fn const_eval_with_null( + fn eval_predicate_bounds( predicate: &Expr, - schema: &DFSchema, - null_expr: &Expr, + evaluates_to_null: Option<&Expr>, + input_schema: &dyn ExprSchema, ) -> Option { - const_eval_predicate( - predicate, - |e| { - if e.eq(null_expr) { - Some(true) - } else { - 
None - } - }, - schema, - ) + try_eval_predicate_bounds(predicate, evaluates_to_null, input_schema).unwrap() + } + + fn try_eval_bounds(predicate: &Expr) -> Result> { + let schema = DFSchema::try_from(Schema::empty())?; + try_eval_predicate_bounds(predicate, None, &schema) + } + + fn eval_bounds(predicate: &Expr) -> Option { + try_eval_bounds(predicate).unwrap() } #[test] - fn predicate_eval_literal() { - assert_eq!(const_eval(&lit(ScalarValue::Null)), Some(false)); + fn evaluate_bounds_literal() { + assert_eq!(eval_bounds(&lit(ScalarValue::Null)), Some(false)); - assert_eq!(const_eval(&lit(false)), Some(false)); - assert_eq!(const_eval(&lit(true)), Some(true)); + assert_eq!(eval_bounds(&lit(false)), Some(false)); + assert_eq!(eval_bounds(&lit(true)), Some(true)); - assert_eq!(const_eval(&lit(0)), Some(false)); - assert_eq!(const_eval(&lit(1)), Some(true)); + assert_eq!(eval_bounds(&lit(0)), Some(false)); + assert_eq!(eval_bounds(&lit(1)), Some(true)); - assert_eq!(const_eval(&lit("foo")), None); - assert_eq!(const_eval(&lit(ScalarValue::Utf8(None))), Some(false)); + assert_eq!(eval_bounds(&lit(ScalarValue::Utf8(None))), Some(false)); + assert!(try_eval_bounds(&lit("foo")).is_err()); } #[test] - fn predicate_eval_and() { + fn evaluate_bounds_and() { let null = lit(ScalarValue::Null); let zero = lit(0); let one = lit(1); @@ -715,83 +716,89 @@ mod tests { let func = make_scalar_func_expr(); assert_eq!( - const_eval(&binary_expr(null.clone(), And, null.clone())), + eval_bounds(&binary_expr(null.clone(), And, null.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(null.clone(), And, one.clone())), + eval_bounds(&binary_expr(null.clone(), And, one.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(null.clone(), And, zero.clone())), + eval_bounds(&binary_expr(null.clone(), And, zero.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(one.clone(), And, one.clone())), + eval_bounds(&binary_expr(one.clone(), And, one.clone())), 
Some(true) ); assert_eq!( - const_eval(&binary_expr(one.clone(), And, zero.clone())), + eval_bounds(&binary_expr(one.clone(), And, zero.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(null.clone(), And, t.clone())), + eval_bounds(&binary_expr(null.clone(), And, t.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(t.clone(), And, null.clone())), + eval_bounds(&binary_expr(t.clone(), And, null.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(null.clone(), And, f.clone())), + eval_bounds(&binary_expr(null.clone(), And, f.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(f.clone(), And, null.clone())), + eval_bounds(&binary_expr(f.clone(), And, null.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(t.clone(), And, t.clone())), + eval_bounds(&binary_expr(t.clone(), And, t.clone())), Some(true) ); assert_eq!( - const_eval(&binary_expr(t.clone(), And, f.clone())), + eval_bounds(&binary_expr(t.clone(), And, f.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(f.clone(), And, t.clone())), + eval_bounds(&binary_expr(f.clone(), And, t.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(f.clone(), And, f.clone())), + eval_bounds(&binary_expr(f.clone(), And, f.clone())), Some(false) ); - assert_eq!(const_eval(&binary_expr(t.clone(), And, func.clone())), None); - assert_eq!(const_eval(&binary_expr(func.clone(), And, t.clone())), None); assert_eq!( - const_eval(&binary_expr(f.clone(), And, func.clone())), + eval_bounds(&binary_expr(t.clone(), And, func.clone())), + None + ); + assert_eq!( + eval_bounds(&binary_expr(func.clone(), And, t.clone())), + None + ); + assert_eq!( + eval_bounds(&binary_expr(f.clone(), And, func.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(func.clone(), And, f.clone())), + eval_bounds(&binary_expr(func.clone(), And, f.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(null.clone(), And, func.clone())), + 
eval_bounds(&binary_expr(null.clone(), And, func.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(func.clone(), And, null.clone())), + eval_bounds(&binary_expr(func.clone(), And, null.clone())), Some(false) ); } #[test] - fn predicate_eval_or() { + fn evaluate_bounds_or() { let null = lit(ScalarValue::Null); let zero = lit(0); let one = lit(1); @@ -800,83 +807,83 @@ mod tests { let func = make_scalar_func_expr(); assert_eq!( - const_eval(&binary_expr(null.clone(), Or, null.clone())), + eval_bounds(&binary_expr(null.clone(), Or, null.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(null.clone(), Or, one.clone())), + eval_bounds(&binary_expr(null.clone(), Or, one.clone())), Some(true) ); assert_eq!( - const_eval(&binary_expr(null.clone(), Or, zero.clone())), + eval_bounds(&binary_expr(null.clone(), Or, zero.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(one.clone(), Or, one.clone())), + eval_bounds(&binary_expr(one.clone(), Or, one.clone())), Some(true) ); assert_eq!( - const_eval(&binary_expr(one.clone(), Or, zero.clone())), + eval_bounds(&binary_expr(one.clone(), Or, zero.clone())), Some(true) ); assert_eq!( - const_eval(&binary_expr(null.clone(), Or, t.clone())), + eval_bounds(&binary_expr(null.clone(), Or, t.clone())), Some(true) ); assert_eq!( - const_eval(&binary_expr(t.clone(), Or, null.clone())), + eval_bounds(&binary_expr(t.clone(), Or, null.clone())), Some(true) ); assert_eq!( - const_eval(&binary_expr(null.clone(), Or, f.clone())), + eval_bounds(&binary_expr(null.clone(), Or, f.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(f.clone(), Or, null.clone())), + eval_bounds(&binary_expr(f.clone(), Or, null.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(t.clone(), Or, t.clone())), + eval_bounds(&binary_expr(t.clone(), Or, t.clone())), Some(true) ); assert_eq!( - const_eval(&binary_expr(t.clone(), Or, f.clone())), + eval_bounds(&binary_expr(t.clone(), Or, f.clone())), Some(true) ); 
assert_eq!( - const_eval(&binary_expr(f.clone(), Or, t.clone())), + eval_bounds(&binary_expr(f.clone(), Or, t.clone())), Some(true) ); assert_eq!( - const_eval(&binary_expr(f.clone(), Or, f.clone())), + eval_bounds(&binary_expr(f.clone(), Or, f.clone())), Some(false) ); assert_eq!( - const_eval(&binary_expr(t.clone(), Or, func.clone())), + eval_bounds(&binary_expr(t.clone(), Or, func.clone())), Some(true) ); assert_eq!( - const_eval(&binary_expr(func.clone(), Or, t.clone())), + eval_bounds(&binary_expr(func.clone(), Or, t.clone())), Some(true) ); - assert_eq!(const_eval(&binary_expr(f.clone(), Or, func.clone())), None); - assert_eq!(const_eval(&binary_expr(func.clone(), Or, f.clone())), None); + assert_eq!(eval_bounds(&binary_expr(f.clone(), Or, func.clone())), None); + assert_eq!(eval_bounds(&binary_expr(func.clone(), Or, f.clone())), None); assert_eq!( - const_eval(&binary_expr(null.clone(), Or, func.clone())), + eval_bounds(&binary_expr(null.clone(), Or, func.clone())), None ); assert_eq!( - const_eval(&binary_expr(func.clone(), Or, null.clone())), + eval_bounds(&binary_expr(func.clone(), Or, null.clone())), None ); } #[test] - fn predicate_eval_not() { + fn evaluate_bounds_not() { let null = lit(ScalarValue::Null); let zero = lit(0); let one = lit(1); @@ -884,18 +891,18 @@ mod tests { let f = lit(false); let func = make_scalar_func_expr(); - assert_eq!(const_eval(&not(null.clone())), Some(false)); - assert_eq!(const_eval(&not(one.clone())), Some(false)); - assert_eq!(const_eval(&not(zero.clone())), Some(true)); + assert_eq!(eval_bounds(&not(null.clone())), Some(false)); + assert_eq!(eval_bounds(&not(one.clone())), Some(false)); + assert_eq!(eval_bounds(&not(zero.clone())), Some(true)); - assert_eq!(const_eval(&not(t.clone())), Some(false)); - assert_eq!(const_eval(&not(f.clone())), Some(true)); + assert_eq!(eval_bounds(&not(t.clone())), Some(false)); + assert_eq!(eval_bounds(&not(f.clone())), Some(true)); - assert_eq!(const_eval(&not(func.clone())), None); +
assert_eq!(eval_bounds(&not(func.clone())), None); } #[test] - fn predicate_eval_is() { + fn evaluate_bounds_is() { let null = lit(ScalarValue::Null); let zero = lit(0); let one = lit(1); @@ -915,73 +922,77 @@ mod tests { )])) .unwrap(); - assert_eq!(const_eval(&is_null(null.clone())), Some(true)); - assert_eq!(const_eval(&is_null(one.clone())), Some(false)); + assert_eq!(eval_bounds(&is_null(null.clone())), Some(true)); + assert_eq!(eval_bounds(&is_null(one.clone())), Some(false)); + let predicate = &is_null(col.clone()); assert_eq!( - const_eval_with_null(&is_null(col.clone()), &nullable_schema, &col), + eval_predicate_bounds(predicate, Some(&col), &nullable_schema), Some(true) ); + let predicate = &is_null(col.clone()); assert_eq!( - const_eval_with_null(&is_null(col.clone()), &not_nullable_schema, &col), + eval_predicate_bounds(predicate, Some(&col), &not_nullable_schema), Some(false) ); - assert_eq!(const_eval(&is_not_null(null.clone())), Some(false)); - assert_eq!(const_eval(&is_not_null(one.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_null(null.clone())), Some(false)); + assert_eq!(eval_bounds(&is_not_null(one.clone())), Some(true)); + let predicate = &is_not_null(col.clone()); assert_eq!( - const_eval_with_null(&is_not_null(col.clone()), &nullable_schema, &col), + eval_predicate_bounds(predicate, Some(&col), &nullable_schema), Some(false) ); + let predicate = &is_not_null(col.clone()); assert_eq!( - const_eval_with_null(&is_not_null(col.clone()), &not_nullable_schema, &col), + eval_predicate_bounds(predicate, Some(&col), &not_nullable_schema), Some(true) ); - assert_eq!(const_eval(&is_true(null.clone())), Some(false)); - assert_eq!(const_eval(&is_true(t.clone())), Some(true)); - assert_eq!(const_eval(&is_true(f.clone())), Some(false)); - assert_eq!(const_eval(&is_true(zero.clone())), Some(false)); - assert_eq!(const_eval(&is_true(one.clone())), Some(true)); - - assert_eq!(const_eval(&is_not_true(null.clone())), Some(true)); -
assert_eq!(const_eval(&is_not_true(t.clone())), Some(false)); - assert_eq!(const_eval(&is_not_true(f.clone())), Some(true)); - assert_eq!(const_eval(&is_not_true(zero.clone())), Some(true)); - assert_eq!(const_eval(&is_not_true(one.clone())), Some(false)); - - assert_eq!(const_eval(&is_false(null.clone())), Some(false)); - assert_eq!(const_eval(&is_false(t.clone())), Some(false)); - assert_eq!(const_eval(&is_false(f.clone())), Some(true)); - assert_eq!(const_eval(&is_false(zero.clone())), Some(true)); - assert_eq!(const_eval(&is_false(one.clone())), Some(false)); - - assert_eq!(const_eval(&is_not_false(null.clone())), Some(true)); - assert_eq!(const_eval(&is_not_false(t.clone())), Some(true)); - assert_eq!(const_eval(&is_not_false(f.clone())), Some(false)); - assert_eq!(const_eval(&is_not_false(zero.clone())), Some(false)); - assert_eq!(const_eval(&is_not_false(one.clone())), Some(true)); - - assert_eq!(const_eval(&is_unknown(null.clone())), Some(true)); - assert_eq!(const_eval(&is_unknown(t.clone())), Some(false)); - assert_eq!(const_eval(&is_unknown(f.clone())), Some(false)); - assert_eq!(const_eval(&is_unknown(zero.clone())), Some(false)); - assert_eq!(const_eval(&is_unknown(one.clone())), Some(false)); - - assert_eq!(const_eval(&is_not_unknown(null.clone())), Some(false)); - assert_eq!(const_eval(&is_not_unknown(t.clone())), Some(true)); - assert_eq!(const_eval(&is_not_unknown(f.clone())), Some(true)); - assert_eq!(const_eval(&is_not_unknown(zero.clone())), Some(true)); - assert_eq!(const_eval(&is_not_unknown(one.clone())), Some(true)); + assert_eq!(eval_bounds(&is_true(null.clone())), Some(false)); + assert_eq!(eval_bounds(&is_true(t.clone())), Some(true)); + assert_eq!(eval_bounds(&is_true(f.clone())), Some(false)); + assert_eq!(eval_bounds(&is_true(zero.clone())), Some(false)); + assert_eq!(eval_bounds(&is_true(one.clone())), Some(true)); + + assert_eq!(eval_bounds(&is_not_true(null.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_true(t.clone())), 
Some(false)); + assert_eq!(eval_bounds(&is_not_true(f.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_true(zero.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_true(one.clone())), Some(false)); + + assert_eq!(eval_bounds(&is_false(null.clone())), Some(false)); + assert_eq!(eval_bounds(&is_false(t.clone())), Some(false)); + assert_eq!(eval_bounds(&is_false(f.clone())), Some(true)); + assert_eq!(eval_bounds(&is_false(zero.clone())), Some(true)); + assert_eq!(eval_bounds(&is_false(one.clone())), Some(false)); + + assert_eq!(eval_bounds(&is_not_false(null.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_false(t.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_false(f.clone())), Some(false)); + assert_eq!(eval_bounds(&is_not_false(zero.clone())), Some(false)); + assert_eq!(eval_bounds(&is_not_false(one.clone())), Some(true)); + + assert_eq!(eval_bounds(&is_unknown(null.clone())), Some(true)); + assert_eq!(eval_bounds(&is_unknown(t.clone())), Some(false)); + assert_eq!(eval_bounds(&is_unknown(f.clone())), Some(false)); + assert_eq!(eval_bounds(&is_unknown(zero.clone())), Some(false)); + assert_eq!(eval_bounds(&is_unknown(one.clone())), Some(false)); + + assert_eq!(eval_bounds(&is_not_unknown(null.clone())), Some(false)); + assert_eq!(eval_bounds(&is_not_unknown(t.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_unknown(f.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_unknown(zero.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_unknown(one.clone())), Some(true)); } #[test] - fn predicate_eval_udf() { + fn evaluate_bounds_udf() { let func = make_scalar_func_expr(); - assert_eq!(const_eval(&func.clone()), None); - assert_eq!(const_eval(&not(func.clone())), None); + assert_eq!(eval_bounds(&func.clone()), None); + assert_eq!(eval_bounds(&not(func.clone())), None); assert_eq!( - const_eval(&binary_expr(func.clone(), And, func.clone())), + eval_bounds(&binary_expr(func.clone(), And, func.clone())), None ); } @@ -1000,7
+1011,7 @@ mod tests { } #[test] - fn predicate_eval_when_then() { + fn evaluate_bounds_when_then() { let nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new("x", DataType::UInt8, true)])) .unwrap(); @@ -1021,11 +1032,11 @@ mod tests { ); assert_eq!( - const_eval_with_null(&when, &nullable_schema, &x), + eval_predicate_bounds(&when, Some(&x), &nullable_schema), Some(false) ); assert_eq!( - const_eval_with_null(&when, &not_nullable_schema, &x), + eval_predicate_bounds(&when, Some(&x), &not_nullable_schema), Some(true) ); } diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 2bd644d8a8ac..3905575d22dc 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -689,3 +689,17 @@ query I SELECT CASE 0 WHEN 0 THEN NULL WHEN SUM(1) + COUNT(*) THEN 10 ELSE 20 END ---- NULL + +query TT +EXPLAIN SELECT CASE WHEN CASE WHEN a IS NOT NULL THEN a ELSE 1 END IS NOT NULL THEN a ELSE 1 END FROM ( + VALUES (10), (20), (30) + ) t(a); +---- +logical_plan +01)Projection: t.a AS CASE WHEN CASE WHEN t.a IS NOT NULL THEN t.a ELSE Int64(1) END IS NOT NULL THEN t.a ELSE Int64(1) END +02)--SubqueryAlias: t +03)----Projection: column1 AS a +04)------Values: (Int64(10)), (Int64(20)), (Int64(30)) +physical_plan +01)ProjectionExec: expr=[column1@0 as CASE WHEN CASE WHEN t.a IS NOT NULL THEN t.a ELSE Int64(1) END IS NOT NULL THEN t.a ELSE Int64(1) END] +02)--DataSourceExec: partitions=1, partition_sizes=[1] From c5914d63afa13b172f3b3bf9c829fd89e11a9e0c Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 10 Nov 2025 23:09:17 +0100 Subject: [PATCH 22/23] Update bitflags version declaration to match arrow-schema --- Cargo.lock | 2 +- datafusion/expr/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4881b029db1d..d0af7424ee07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2259,7 +2259,7 @@ version = "51.0.0" 
dependencies = [ "arrow", "async-trait", - "bitflags 2.9.4", + "bitflags 2.10.0", "chrono", "ctor", "datafusion-common", diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 924bd570710f..84be57023d9a 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -48,7 +48,7 @@ sql = ["sqlparser"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } -bitflags = "2.9.4" +bitflags = "2.0.0" chrono = { workspace = true } datafusion-common = { workspace = true, default-features = false } datafusion-doc = { workspace = true } From 4b879e4a8d3e2e88683e416677ec921359bb29fb Mon Sep 17 00:00:00 2001 From: Pepijn Van Eeckhoudt Date: Mon, 10 Nov 2025 23:27:10 +0100 Subject: [PATCH 23/23] Silence "needless pass by value" lint --- datafusion/expr/src/predicate_bounds.rs | 60 +++++++++++++------------ 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/datafusion/expr/src/predicate_bounds.rs b/datafusion/expr/src/predicate_bounds.rs index f4cbc53f8ec1..547db3c9b0f0 100644 --- a/datafusion/expr/src/predicate_bounds.rs +++ b/datafusion/expr/src/predicate_bounds.rs @@ -88,7 +88,7 @@ impl TernarySet { /// U | U /// T | F /// ``` - fn not(set: Self) -> Self { + fn not(set: &Self) -> Self { let mut not = Self::empty(); if set.contains(Self::TRUE) { not.toggle(Self::FALSE); @@ -114,7 +114,7 @@ impl TernarySet { /// U │ F U U /// T │ F U T /// ``` - fn and(lhs: Self, rhs: Self) -> Self { + fn and(lhs: &Self, rhs: &Self) -> Self { if lhs.is_empty() || rhs.is_empty() { return Self::empty(); } @@ -149,7 +149,7 @@ impl TernarySet { /// U │ U U T /// T │ T T T /// ``` - fn or(lhs: Self, rhs: Self) -> Self { + fn or(lhs: &Self, rhs: &Self) -> Self { let mut or = Self::empty(); if lhs.contains(Self::TRUE) || rhs.contains(Self::TRUE) { or.toggle(Self::TRUE); @@ -327,36 +327,38 @@ impl PredicateBoundsEvaluator<'_> { match e.get_type(self.input_schema)? 
{ // If `e` is a boolean expression, try to evaluate it and test for not unknown DataType::Boolean => { - TernarySet::not(self.evaluate_bounds(e)?.is_unknown()) + TernarySet::not(&self.evaluate_bounds(e)?.is_unknown()) } // If `e` is not a boolean expression, check if `e` is provably null - _ => TernarySet::not(self.is_null(e)), + _ => TernarySet::not(&self.is_null(e)), } } } Expr::IsTrue(e) => self.evaluate_bounds(e)?.is_true(), - Expr::IsNotTrue(e) => TernarySet::not(self.evaluate_bounds(e)?.is_true()), + Expr::IsNotTrue(e) => TernarySet::not(&self.evaluate_bounds(e)?.is_true()), Expr::IsFalse(e) => self.evaluate_bounds(e)?.is_false(), - Expr::IsNotFalse(e) => TernarySet::not(self.evaluate_bounds(e)?.is_false()), + Expr::IsNotFalse(e) => TernarySet::not(&self.evaluate_bounds(e)?.is_false()), Expr::IsUnknown(e) => self.evaluate_bounds(e)?.is_unknown(), Expr::IsNotUnknown(e) => { - TernarySet::not(self.evaluate_bounds(e)?.is_unknown()) + TernarySet::not(&self.evaluate_bounds(e)?.is_unknown()) } - Expr::Not(e) => TernarySet::not(self.evaluate_bounds(e)?), + Expr::Not(e) => TernarySet::not(&self.evaluate_bounds(e)?), Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right, - }) => { - TernarySet::and(self.evaluate_bounds(left)?, self.evaluate_bounds(right)?) - } + }) => TernarySet::and( + &self.evaluate_bounds(left)?, + &self.evaluate_bounds(right)?, + ), Expr::BinaryExpr(BinaryExpr { left, op: Operator::Or, right, - }) => { - TernarySet::or(self.evaluate_bounds(left)?, self.evaluate_bounds(right)?) 
- } + }) => TernarySet::or( + &self.evaluate_bounds(left)?, + &self.evaluate_bounds(right)?, + ), e => { let mut result = TernarySet::empty(); let is_null = self.is_null(e); @@ -505,7 +507,7 @@ mod tests { ]; for case in cases { - assert_eq!(TernarySet::not(case.0), case.1); + assert_eq!(TernarySet::not(&case.0), case.1); } } @@ -571,7 +573,7 @@ mod tests { for case in cases { assert_eq!( - TernarySet::and(case.0.clone(), case.1.clone()), + TernarySet::and(&case.0, &case.1), case.2.clone(), "{:?} & {:?} = {:?}", case.0.clone(), @@ -579,12 +581,12 @@ mod tests { case.2.clone() ); assert_eq!( - TernarySet::and(case.1.clone(), case.0.clone()), + TernarySet::and(&case.1, &case.0), case.2.clone(), "{:?} & {:?} = {:?}", - case.1.clone(), - case.0.clone(), - case.2.clone() + case.1, + case.0, + case.2 ); } } @@ -641,20 +643,20 @@ mod tests { for case in cases { assert_eq!( - TernarySet::or(case.0.clone(), case.1.clone()), + TernarySet::or(&case.0, &case.1), case.2.clone(), "{:?} | {:?} = {:?}", - case.0.clone(), - case.1.clone(), - case.2.clone() + case.0, + case.1, + case.2 ); assert_eq!( - TernarySet::or(case.1.clone(), case.0.clone()), + TernarySet::or(&case.1, &case.0), case.2.clone(), "{:?} | {:?} = {:?}", - case.1.clone(), - case.0.clone(), - case.2.clone() + case.1, + case.0, + case.2 ); } }