diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 6f61c164f41d..811fca6a4dde 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -66,8 +66,8 @@ async fn main() -> Result<()> { write_out(&ctx).await?; register_aggregate_test_data("t1", &ctx).await?; register_aggregate_test_data("t2", &ctx).await?; - where_scalar_subquery(&ctx).await?; - where_in_subquery(&ctx).await?; + Box::pin(where_scalar_subquery(&ctx)).await?; + Box::pin(where_in_subquery(&ctx)).await?; where_exist_subquery(&ctx).await?; Ok(()) } diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 3f9429bff71d..c3af4d9ba221 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -734,6 +734,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { } } +/// Coercion rules for boolean types: If at least one argument is +/// a boolean type and both arguments can be coerced into a boolean type, coerce +/// to boolean type. +fn boolean_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Boolean, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64) + | (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Boolean) => { + Some(Boolean) + } + _ => None, + } +} + /// Returns the output type of applying mathematics operations such as /// `+` to arguments of `lhs_type` and `rhs_type`. fn mathematics_numerical_coercion( @@ -2434,6 +2449,32 @@ mod tests { DataType::List(Arc::clone(&inner_field)) ); + // boolean + let int_types = vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + ]; + for int_type in int_types { + test_coercion_binary_rule!( + DataType::Boolean, + int_type, + Operator::Eq, + DataType::Boolean + ); + test_coercion_binary_rule!( + int_type, + DataType::Boolean, + Operator::Eq, + DataType::Boolean + ); + } + // Negative test: inner_timestamp_field and inner_field are not compatible because their inner types are not compatible let inner_timestamp_field = Arc::new(Field::new_list_field( DataType::Timestamp(TimeUnit::Microsecond, None), diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 07a069cbb400..5abc375510e5 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -797,12 +797,12 @@ mod tests { let pivot = Pivot { input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: schema.clone(), + schema: Arc::clone(&schema), })), aggregate_expr: Expr::Column(Column::from_name("sum_value")), pivot_column: Column::from_name("category"), pivot_values, - schema: schema.clone(), + schema: Arc::clone(&schema), value_subquery: None, default_on_null_expr: None, }; @@ -834,15 +834,15 @@ mod tests { let pivot = Pivot { input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: schema.clone(), + schema: Arc::clone(&schema), })), aggregate_expr: Expr::Column(Column::from_name("sum_value")), pivot_column: Column::from_name("category"), pivot_values: vec![], - schema: schema.clone(), + schema: Arc::clone(&schema), value_subquery: Some(Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: schema.clone(), + schema: Arc::clone(&schema), }))), default_on_null_expr: None, }; diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index ccea816ccf78..142fdf815a7e 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -166,7 +166,7 @@ mod tests { use arrow::datatypes::DataType; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; #[test] diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index d47f7ea6ce68..df9738c70d24 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1572,11 +1572,18 @@ mod test { let expected = "Projection: a IS TRUE\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - let empty = empty_with_type(DataType::Int64); + let empty = empty_with_type(DataType::Float64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, ""); let err = ret.unwrap_err().to_string(); - assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); + assert!(err.contains("Cannot infer common argument type for comparison operation Float64 IS DISTINCT FROM Boolean"), "{err}"); + + // integer + let expr = col("a").is_true(); + let empty = empty_with_type(DataType::Int64); + let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + let expected = "Projection: CAST(a AS Boolean) IS TRUE\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // is not true let expr = col("a").is_not_true(); diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 854c715eb0a2..494f9fe0d211 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1086,7 +1086,7 @@ mod tests { #[test] fn case_test_incompatible() -> Result<()> { - // 1 then is int64 + // 1 then is float64 // 2 then is boolean let batch = case_test_batch()?; let schema = batch.schema(); @@ -1098,7 +1098,7 @@ mod tests { lit("foo"), &batch.schema(), )?; - let then1 = lit(123i32); + let then1 = lit(1.23f64); let when2 = binary( col("a", &schema)?, Operator::Eq, diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index f583d659fd4f..dae5fc68b3c4 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1536,8 +1536,10 @@ SELECT not(true), not(false) ---- false true -query error type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean +query BB SELECT not(1), not(0) +---- +false true query ?B SELECT null, not(null) diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 8fb8a59fb860..4f5b4b779d73 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -1075,8 +1075,8 @@ use datafusion_expr::Expr; pub struct EchoFunction {} impl TableFunctionImpl for EchoFunction { - fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { + fn call(&self, exprs: &[(datafusion_expr::Expr, Option)]) -> Result> { + let Some((Expr::Literal(ScalarValue::Int64(Some(value))), _)) = exprs.get(0) else { return plan_err!("First argument must be an integer"); }; @@ -1116,8 +1116,8 @@ With the UDTF implemented, you can register it with the `SessionContext`: # pub struct EchoFunction {} # # impl TableFunctionImpl for EchoFunction { -# fn call(&self, exprs: &[Expr]) -> Result> { -# let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { +# fn call(&self, exprs: &[(datafusion_expr::Expr, Option)]) -> Result> { +# let Some((Expr::Literal(ScalarValue::Int64(Some(value))), _)) = exprs.get(0) else { # return plan_err!("First argument must be an integer"); # }; #