diff --git a/Cargo.lock b/Cargo.lock index 567c0843d31c..7bba669c2982 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6097,8 +6097,7 @@ dependencies = [ [[package]] name = "sqlparser" version = "0.58.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec4b661c54b1e4b603b37873a18c59920e4c51ea8ea2cf527d925424dbd4437c" +source = "git+https://github.com/Embucket/datafusion-sqlparser-rs.git?rev=28012078c09714542c95ddbad8203a282fa37cb2#28012078c09714542c95ddbad8203a282fa37cb2" dependencies = [ "log", "recursive", @@ -6108,8 +6107,7 @@ dependencies = [ [[package]] name = "sqlparser_derive" version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" +source = "git+https://github.com/Embucket/datafusion-sqlparser-rs.git?rev=28012078c09714542c95ddbad8203a282fa37cb2#28012078c09714542c95ddbad8203a282fa37cb2" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index ae004056b420..4ef933eef4d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -173,7 +173,9 @@ recursive = "0.1.1" regex = "1.11" rstest = "0.25.0" serde_json = "1" -sqlparser = { version = "0.58.0", default-features = false, features = ["std", "visitor"] } +sqlparser = { git = "https://github.com/Embucket/datafusion-sqlparser-rs.git", rev = "28012078c09714542c95ddbad8203a282fa37cb2", features = [ + "visitor", +] } tempfile = "3" testcontainers = { version = "0.24", features = ["default"] } testcontainers-modules = { version = "0.12" } diff --git a/FETCH_HEAD b/FETCH_HEAD new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 3ec446c51583..17dfc94e390f 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -321,10 +321,14 @@ fn fixed_len_byte_array_to_string(val: Option<&FixedLenByteArray>) -> Option Result> { + fn call(&self, exprs: &[(Expr, Option)]) -> Result> { + if exprs.is_empty() { + return plan_err!("parquet_metadata requires string argument as its input"); + } + let filename = match exprs.first() { - Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet') - Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") + Some((Expr::Literal(ScalarValue::Utf8(Some(s)), _), _)) => s, // single quote: parquet_metadata('x.parquet') + Some((Expr::Column(Column { name, .. 
}), _)) => name, // double quote: parquet_metadata("x.parquet") _ => { return plan_err!( "parquet_metadata requires string argument as its input" @@ -510,7 +514,7 @@ impl MetadataCacheFunc { } impl TableFunctionImpl for MetadataCacheFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call(&self, exprs: &[(Expr, Option)]) -> Result> { if !exprs.is_empty() { return plan_err!("metadata_cache should have no arguments"); } diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index a5ee571a1476..826c92af00db 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -69,8 +69,8 @@ async fn main() -> Result<()> { write_out(&ctx).await?; register_aggregate_test_data("t1", &ctx).await?; register_aggregate_test_data("t2", &ctx).await?; - where_scalar_subquery(&ctx).await?; - where_in_subquery(&ctx).await?; + Box::pin(where_scalar_subquery(&ctx)).await?; + Box::pin(where_in_subquery(&ctx)).await?; where_exist_subquery(&ctx).await?; Ok(()) } diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index b65ffb8d7174..ab4ecd40fb81 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -132,15 +132,16 @@ impl TableProvider for LocalCsvTable { struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { - fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)), _)) = exprs.first() + fn call(&self, exprs: &[(Expr, Option)]) -> Result> { + let Some((Expr::Literal(ScalarValue::Utf8(Some(ref path)), _), _)) = + exprs.first() else { return plan_err!("read_csv requires at least one string argument"); }; let limit = exprs .get(1) - .map(|expr| { + .map(|(expr, _)| { // try to simplify the expression, so 1+2 becomes 3, for example let execution_props = ExecutionProps::new(); let info = SimplifyContext::new(&execution_props); diff --git a/datafusion/catalog/src/default_table_source.rs b/datafusion/catalog/src/default_table_source.rs index 11963c06c88f..6bd7366db483 100644 --- a/datafusion/catalog/src/default_table_source.rs +++ b/datafusion/catalog/src/default_table_source.rs @@ -83,6 +83,10 @@ impl TableSource for DefaultTableSource { fn get_column_default(&self, column: &str) -> Option<&Expr> { self.table_provider.get_column_default(column) } + + fn statistics(&self) -> Option { + self.table_provider.statistics() + } } /// Wrap TableProvider in TableSource diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index ac2e1884ba92..6039a7d72293 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -316,7 +316,7 @@ pub trait TableProviderFactory: Debug + Sync + Send { /// A trait for table function implementations pub trait TableFunctionImpl: Debug + Sync + Send { /// Create a table provider - fn call(&self, args: &[Expr]) -> Result>; + fn call(&self, args: &[(Expr, Option)]) -> Result>; } /// A table that uses a function to generate data @@ -345,7 +345,10 @@ impl TableFunction { } /// Get the function implementation and generate a table - pub fn create_table_provider(&self, args: &[Expr]) -> Result> { + pub fn create_table_provider( + &self, + args: &[(Expr, Option)], + ) -> Result> { self.fun.call(args) } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index a7b3bdeeace8..99a2282aeb34 100644 --- 
a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -274,7 +274,9 @@ impl Session for SessionState { } impl SessionState { - pub(crate) fn resolve_table_ref( + /// Resolve a [`TableReference`] into a [`ResolvedTableReference`] using + /// the session's configured default catalog and schema. + pub fn resolve_table_ref( &self, table_ref: impl Into, ) -> ResolvedTableReference { @@ -474,7 +476,8 @@ impl SessionState { query.statement_to_plan(statement) } - fn get_parser_options(&self) -> ParserOptions { + /// Get the parser options + pub fn get_parser_options(&self) -> ParserOptions { let sql_parser_options = &self.config.options().sql_parser; ParserOptions { @@ -1634,9 +1637,11 @@ impl From for SessionStateBuilder { /// /// This is used so the SQL planner can access the state of the session without /// having a direct dependency on the [`SessionState`] struct (and core crate) -struct SessionContextProvider<'a> { - state: &'a SessionState, - tables: HashMap>, +pub struct SessionContextProvider<'a> { + /// The session state + pub state: &'a SessionState, + /// The tables available in the session + pub tables: HashMap>, } impl ContextProvider for SessionContextProvider<'_> { @@ -1666,20 +1671,25 @@ impl ContextProvider for SessionContextProvider<'_> { fn get_table_function_source( &self, name: &str, - args: Vec, + args: Vec<(Expr, Option)>, ) -> datafusion_common::Result> { + let name = name.to_ascii_lowercase(); let tbl_func = self .state .table_functions - .get(name) + .get(&name) .cloned() .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; let dummy_schema = DFSchema::empty(); let simplifier = ExprSimplifier::new(SessionSimplifyProvider::new(self.state, &dummy_schema)); + let args = args .into_iter() - .map(|arg| simplifier.simplify(arg)) + .map(|(expr, named_param)| { + // simplify returns Result, map it into Result<(Expr, Option)> + simplifier.simplify(expr).map(|e| (e, named_param)) + }) .collect::>>()?; let provider = tbl_func.create_table_provider(&args)?; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 6618d9495d78..b2b08affcc47 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -78,9 +78,9 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, - Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, + Analyze, BinaryExpr, Cast, DescribeTable, DmlStatement, Explain, ExplainFormat, + Extension, FetchType, Filter, JoinType, LogicalPlanBuilder, RecursiveQuery, SkipType, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::{Column, Literal}; @@ -97,12 +97,15 @@ use datafusion_sql::TableReference; use sqlparser::ast::NullTreatment; use async_trait::async_trait; +use datafusion_expr_common::operator::Operator; use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper}; use futures::{StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::debug; use tokio::sync::Mutex; +use datafusion_physical_plan::collect; + /// Physical query planner that converts a `LogicalPlan` to an /// 
`ExecutionPlan` suitable for execution. #[async_trait] @@ -945,7 +948,60 @@ impl DefaultPhysicalPlanner { options.clone(), )) } + LogicalPlan::Pivot(pivot) => { + return if !pivot.pivot_values.is_empty() { + let agg_plan = transform_pivot_to_aggregate( + Arc::new(pivot.input.as_ref().clone()), + &pivot.aggregate_expr, + &pivot.pivot_column, + pivot.pivot_values.clone(), + pivot.default_on_null_expr.as_ref(), + )?; + + self.create_physical_plan(&agg_plan, session_state).await + } else if let Some(subquery) = &pivot.value_subquery { + let optimized_subquery = session_state.optimize(subquery.as_ref())?; + + let subquery_physical_plan = self + .create_physical_plan(&optimized_subquery, session_state) + .await?; + + let subquery_results = collect( + Arc::clone(&subquery_physical_plan), + session_state.task_ctx(), + ) + .await?; + + let mut pivot_values = Vec::new(); + for batch in subquery_results.iter() { + if batch.num_columns() != 1 { + return plan_err!( + "Pivot subquery must return a single column" + ); + } + + let column = batch.column(0); + for row_idx in 0..batch.num_rows() { + if !column.is_null(row_idx) { + pivot_values + .push(ScalarValue::try_from_array(column, row_idx)?); + } + } + } + let agg_plan = transform_pivot_to_aggregate( + Arc::new(pivot.input.as_ref().clone()), + &pivot.aggregate_expr, + &pivot.pivot_column, + pivot_values, + pivot.default_on_null_expr.as_ref(), + )?; + + self.create_physical_plan(&agg_plan, session_state).await + } else { + plan_err!("PIVOT operation requires at least one value to pivot on") + } + } // 2 Children LogicalPlan::Join(Join { left: original_left, @@ -1815,6 +1871,151 @@ pub fn create_aggregate_expr_and_maybe_filter( ) } +/// Transform a PIVOT operation into a more standard Aggregate + Projection plan +/// For known pivot values, we create a projection that includes "IS NOT DISTINCT FROM" conditions +/// +/// For example, for SUM(amount) PIVOT(quarter FOR quarter in ('2023_Q1', '2023_Q2')), we create: +/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1" +/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2" +/// +/// If DEFAULT ON NULL is specified, each aggregate expression is wrapped with an outer projection that +/// applies COALESCE to the results. 
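+///
+/// Note: every input column that is neither the pivot column nor referenced by the
+/// aggregate expression becomes a GROUP BY key, and each pivot value is CAST to the
+/// pivot column's data type before the IS NOT DISTINCT FROM comparison.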
+pub fn transform_pivot_to_aggregate( + input: Arc, + aggregate_expr: &Expr, + pivot_column: &datafusion_common::Column, + pivot_values: Vec, + default_on_null_expr: Option<&Expr>, +) -> Result { + let df_schema = input.schema(); + + let all_columns: Vec = df_schema.columns(); + + // Filter to include only columns we want for GROUP BY + // (exclude pivot column and aggregate expression columns) + let group_by_columns: Vec = all_columns + .into_iter() + .filter(|col: &datafusion_common::Column| { + col.name != pivot_column.name + && !aggregate_expr + .column_refs() + .iter() + .any(|agg_col| agg_col.name == col.name) + }) + .map(|col: datafusion_common::Column| Expr::Column(col)) + .collect(); + + let builder = LogicalPlanBuilder::from(Arc::unwrap_or_clone(Arc::clone(&input))); + + // Create the aggregate plan with filtered aggregates + let mut aggregate_exprs = Vec::new(); + + let input_schema = input.schema(); + let pivot_col_idx = match input_schema.index_of_column(pivot_column) { + Ok(idx) => idx, + Err(_) => { + return plan_err!( + "Pivot column '{}' does not exist in input schema", + pivot_column + ) + } + }; + let pivot_col_type = input_schema.field(pivot_col_idx).data_type(); + + for value in &pivot_values { + let filter_condition = Expr::BinaryExpr(BinaryExpr::new( + Box::new(Expr::Column(pivot_column.clone())), + Operator::IsNotDistinctFrom, + Box::new(Expr::Cast(Cast::new( + Box::new(Expr::Literal(value.clone(), None)), + pivot_col_type.clone(), + ))), + )); + + let filtered_agg = match aggregate_expr { + Expr::AggregateFunction(agg) => { + let mut new_params = agg.params.clone(); + new_params.filter = Some(Box::new(filter_condition)); + Expr::AggregateFunction(AggregateFunction { + func: Arc::clone(&agg.func), + params: new_params, + }) + } + _ => { + return plan_err!( + "Unsupported aggregate expression should always be AggregateFunction" + ); + } + }; + + // Use the pivot value as the column name + let field_name = value.to_string().trim_matches('\'').to_string(); + let aliased_agg = Expr::Alias(Alias { + expr: Box::new(filtered_agg), + relation: None, + name: field_name, + metadata: None, + }); + + aggregate_exprs.push(aliased_agg); + } + + // Create the plan with the aggregate + let aggregate_plan = builder + .aggregate(group_by_columns, aggregate_exprs)? 
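+        // The aggregate output schema is the group-by keys followed by one column per
+        // pivot value, each named after the (quote-trimmed) value string.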
+ .build()?; + + // If DEFAULT ON NULL is specified, add a projection to apply COALESCE + if let Some(default_expr) = default_on_null_expr { + let schema = aggregate_plan.schema(); + let mut projection_exprs = Vec::new(); + + for field in schema.fields() { + if !pivot_values + .iter() + .any(|v| field.name() == v.to_string().trim_matches('\'')) + { + projection_exprs.push(Expr::Column( + datafusion_common::Column::from_name(field.name()), + )); + } + } + + // Apply COALESCE to aggregate columns + for value in &pivot_values { + let field_name = value.to_string().trim_matches('\'').to_string(); + let aggregate_col = + Expr::Column(datafusion_common::Column::from_name(&field_name)); + + // Create COALESCE expression using CASE: CASE WHEN col IS NULL THEN default_value ELSE col END + let coalesce_expr = Expr::Case(datafusion_expr::expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(Expr::IsNull(Box::new(aggregate_col.clone()))), + Box::new(default_expr.clone()), + )], + else_expr: Some(Box::new(aggregate_col)), + }); + + let aliased_coalesce = Expr::Alias(Alias { + expr: Box::new(coalesce_expr), + relation: None, + name: field_name, + metadata: None, + }); + + projection_exprs.push(aliased_coalesce); + } + + // Apply the projection + LogicalPlanBuilder::from(aggregate_plan) + .project(projection_exprs)? + .build() + } else { + Ok(aggregate_plan) + } +} + impl DefaultPhysicalPlanner { /// Handles capturing the various plans for EXPLAIN queries /// @@ -2186,6 +2387,38 @@ impl DefaultPhysicalPlanner { .collect::>>()?; let num_input_columns = input_exec.schema().fields().len(); + // When we detect a PIVOT-derived plan with a value_subquery, ensure all generated columns are preserved + if let LogicalPlan::Pivot(pivot) = input.as_ref() { + if pivot.value_subquery.is_some() + && input_exec + .as_any() + .downcast_ref::() + .is_some() + { + let agg_exec = + input_exec.as_any().downcast_ref::().unwrap(); + let schema = input_exec.schema(); + let group_by_len = agg_exec.group_expr().expr().len(); + + if group_by_len < schema.fields().len() { + let mut all_exprs = physical_exprs.clone(); + + for (i, field) in + schema.fields().iter().enumerate().skip(group_by_len) + { + if !physical_exprs.iter().any(|(_, name)| name == field.name()) { + all_exprs.push(( + Arc::new(Column::new(field.name(), i)) + as Arc, + field.name().clone(), + )); + } + } + + return Ok(Arc::new(ProjectionExec::try_new(all_exprs, input_exec)?)); + } + } + } match self.try_plan_async_exprs( num_input_columns, diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 2c6611f382ce..9ff798f6d2a2 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -200,10 +200,11 @@ impl SimpleCsvTable { struct SimpleCsvTableFunc {} impl TableFunctionImpl for SimpleCsvTableFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call(&self, args: &[(Expr, Option)]) -> Result> { + dbg!(args); let mut new_exprs = vec![]; let mut filepath = String::new(); - for expr in exprs { + for (expr, _) in args { match expr { Expr::Literal(ScalarValue::Utf8(Some(ref path)), _) => { filepath.clone_from(path); diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index f344a71451d4..93d9ef08e4a0 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ 
b/datafusion/expr-common/src/type_coercion/binary.rs @@ -784,6 +784,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { } } +/// Coercion rules for boolean types: If at least one argument is +/// a boolean type and both arguments can be coerced into a boolean type, coerce +/// to boolean type. +fn boolean_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Boolean, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64) + | (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Boolean) => { + Some(Boolean) + } + _ => None, + } +} + /// Returns the output type of applying mathematics operations such as /// `+` to arguments of `lhs_type` and `rhs_type`. fn mathematics_numerical_coercion( @@ -1204,7 +1219,17 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_value_type, rhs_value_type).or(None) } - _ => None, + _ => { + if can_cast_types(lhs_type, &Utf8) && can_cast_types(rhs_type, &Utf8) { + Some(Utf8) + } else if can_cast_types(lhs_type, &LargeUtf8) + && can_cast_types(rhs_type, &LargeUtf8) + { + Some(LargeUtf8) + } else { + None + } + } }) } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 88d49722a587..a4ede50d7b1a 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -32,7 +32,7 @@ use crate::expr_rewriter::{ }; use crate::logical_plan::{ Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, Pivot, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; @@ -1480,6 +1480,23 @@ impl LogicalPlanBuilder { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) } + + pub fn pivot( + self, + aggregate_expr: Expr, + pivot_column: Column, + pivot_values: Vec, + default_on_null: Option, + ) -> Result { + let pivot_plan = Pivot::try_new( + self.plan, + aggregate_expr, + pivot_column, + pivot_values, + default_on_null, + )?; + Ok(Self::new(LogicalPlan::Pivot(pivot_plan))) + } } impl From for LogicalPlanBuilder { @@ -2787,4 +2804,30 @@ mod tests { Ok(()) } + + #[test] + fn plan_builder_pivot() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("region", DataType::Utf8, false), + Field::new("product", DataType::Utf8, false), + Field::new("sales", DataType::Int32, false), + ]); + + let plan = LogicalPlanBuilder::scan("sales", table_source(&schema), None)? + .pivot( + col("sales"), + Column::from_name("product"), + vec![ + ScalarValue::Utf8(Some("widget".to_string())), + ScalarValue::Utf8(Some("gadget".to_string())), + ], + None, + )? 
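+            // With `product` as the pivot column and `sales` consumed by the aggregate,
+            // the pivot schema keeps `region` plus one column per pivot value.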
+ .build()?; + + let expected = "Pivot: sales FOR product IN (widget, gadget)\n TableScan: sales"; + assert_eq!(expected, format!("{plan}")); + + Ok(()) + } } diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index cc3fbad7b0c2..b1a6fda29ba5 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -22,7 +22,7 @@ use std::fmt; use crate::{ expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr, - Filter, Join, Limit, LogicalPlan, Partitioning, Projection, RecursiveQuery, + Filter, Join, Limit, LogicalPlan, Partitioning, Pivot, Projection, RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan, Unnest, Values, Window, }; @@ -651,6 +651,41 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "StructColumn": expr_vec_fmt!(struct_type_columns), }) } + LogicalPlan::Pivot(Pivot { + aggregate_expr, + pivot_column, + pivot_values, + value_subquery, + default_on_null_expr, + .. + }) => { + let mut object = json!({ + "Node Type": "Pivot", + "Aggregate": format!("{}", aggregate_expr), + "Pivot Column": format!("{}", pivot_column), + }); + + if !pivot_values.is_empty() { + object["Pivot Values"] = serde_json::Value::Array( + pivot_values + .iter() + .map(|v| serde_json::Value::String(v.to_string())) + .collect(), + ); + } + + if value_subquery.is_some() { + object["Value Subquery"] = + serde_json::Value::String("Provided".to_string()); + } + + if default_on_null_expr.is_some() { + object["Default On Null"] = + serde_json::Value::String("Provided".to_string()); + } + + object + } } } } @@ -722,8 +757,11 @@ impl<'n> TreeNodeVisitor<'n> for PgJsonVisitor<'_, '_> { #[cfg(test)] mod tests { - use arrow::datatypes::{DataType, Field}; + use crate::EmptyRelation; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Column, DFSchema, ScalarValue}; use insta::assert_snapshot; + use std::sync::Arc; use super::*; @@ -742,4 +780,84 @@ mod tests { assert_snapshot!(display_schema(&schema), @"[id:Int32, first_name:Utf8;N]"); } + + #[test] + fn test_pivot_to_json_value() { + // Create a mock schema + let schema = Arc::new(DFSchema::empty().to_owned()); + + // Create mock pivot values + let pivot_values = vec![ + ScalarValue::Utf8(Some("A".to_string())), + ScalarValue::Utf8(Some("B".to_string())), + ]; + + // Create a Pivot plan + let pivot = Pivot { + input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&schema), + })), + aggregate_expr: Expr::Column(Column::from_name("sum_value")), + pivot_column: Column::from_name("category"), + pivot_values, + schema: Arc::clone(&schema), + value_subquery: None, + default_on_null_expr: None, + }; + + // Test the to_json_value function + let json_value = PgJsonVisitor::to_json_value(&LogicalPlan::Pivot(pivot)); + + // Check the JSON structure + assert_eq!(json_value["Node Type"], "Pivot"); + assert_eq!(json_value["Aggregate"], "sum_value"); + assert_eq!(json_value["Pivot Column"], "category"); + + // Check the pivot values + let pivot_values = json_value["Pivot Values"].as_array().unwrap(); + assert_eq!(pivot_values.len(), 2); + assert_eq!(pivot_values[0], "A"); + assert_eq!(pivot_values[1], "B"); + + // Check that Value Subquery is not present + assert!(json_value.get("Value Subquery").is_none()); + } + + #[test] + fn test_pivot_with_subquery_to_json_value() { + // Create a mock schema + let schema = Arc::new(DFSchema::empty().to_owned()); + + // 
Create a Pivot plan with a value subquery + let pivot = Pivot { + input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&schema), + })), + aggregate_expr: Expr::Column(Column::from_name("sum_value")), + pivot_column: Column::from_name("category"), + pivot_values: vec![], + schema: Arc::clone(&schema), + value_subquery: Some(Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&schema), + }))), + default_on_null_expr: None, + }; + + // Test the to_json_value function + let json_value = PgJsonVisitor::to_json_value(&LogicalPlan::Pivot(pivot)); + + // Check the JSON structure + assert_eq!(json_value["Node Type"], "Pivot"); + assert_eq!(json_value["Aggregate"], "sum_value"); + assert_eq!(json_value["Pivot Column"], "category"); + + // Check that pivot values are not present + assert!(json_value.get("Pivot Values").is_none()); + + // Check that Value Subquery is present + assert_eq!(json_value["Value Subquery"], "Provided"); + } } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 7de2fd117487..6179f4c37833 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -40,7 +40,7 @@ pub use dml::{DmlStatement, WriteOp}; pub use plan::{ projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, ExplainOption, Extension, FetchType, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, + Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, Pivot, PlanType, Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 7dc750a35c0e..3af45d8daca7 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -291,6 +291,8 @@ pub enum LogicalPlan { Unnest(Unnest), /// A variadic query (e.g. "Recursive CTEs") RecursiveQuery(RecursiveQuery), + /// Pivot + Pivot(Pivot), } impl Default for LogicalPlan { @@ -355,6 +357,7 @@ impl LogicalPlan { // we take the schema of the static term as the schema of the entire recursive query static_term.schema() } + LogicalPlan::Pivot(Pivot { schema, .. }) => schema, } } @@ -471,7 +474,8 @@ impl LogicalPlan { LogicalPlan::Dml(write) => vec![&write.input], LogicalPlan::Copy(copy) => vec![©.input], LogicalPlan::Ddl(ddl) => ddl.inputs(), - LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], + LogicalPlan::Unnest(Unnest { input, .. }) + | LogicalPlan::Pivot(Pivot { input, .. }) => vec![input], LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, recursive_term, @@ -595,7 +599,8 @@ impl LogicalPlan { | LogicalPlan::Copy(_) | LogicalPlan::Ddl(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Unnest(_) => Ok(None), + | LogicalPlan::Unnest(_) + | LogicalPlan::Pivot(_) => Ok(None), } } @@ -752,6 +757,39 @@ impl LogicalPlan { // Update schema with unnested column type. unnest_with_options(Arc::unwrap_or_clone(input), exec_columns, options) } + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + .. 
+ }) => { + // Create Pivot with the same value_subquery + let new_pivot = if let Some(subquery) = value_subquery { + Pivot { + input, + aggregate_expr, + pivot_column: pivot_column.clone(), + pivot_values: pivot_values.clone(), + schema: Arc::clone(&schema), + value_subquery: Some(Arc::clone(&subquery)), + default_on_null_expr: None, + } + } else { + Pivot::try_new( + Arc::clone(&input), + aggregate_expr.clone(), + pivot_column.clone(), + pivot_values.clone(), + default_on_null_expr.clone(), + )? + }; + + Ok(LogicalPlan::Pivot(new_pivot)) + } } } @@ -1145,6 +1183,39 @@ impl LogicalPlan { unnest_with_options(input, columns.clone(), options.clone())?; Ok(new_plan) } + LogicalPlan::Pivot(Pivot { + aggregate_expr: _, + pivot_column, + pivot_values, + schema: _, + value_subquery, + default_on_null_expr, + .. + }) => { + let input = self.only_input(inputs)?; + let new_aggregate_expr = self.only_expr(expr)?; + + // Create Pivot with the same value_subquery + let new_pivot = if let Some(subquery) = value_subquery { + Pivot::try_new_with_subquery( + Arc::new(input), + new_aggregate_expr, + pivot_column.clone(), + Arc::clone(subquery), + default_on_null_expr.clone(), + )? + } else { + Pivot::try_new( + Arc::new(input), + new_aggregate_expr, + pivot_column.clone(), + pivot_values.clone(), + default_on_null_expr.clone(), + )? + }; + + Ok(LogicalPlan::Pivot(new_pivot)) + } } } @@ -1379,7 +1450,8 @@ impl LogicalPlan { | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) | LogicalPlan::Statement(_) - | LogicalPlan::Extension(_) => None, + | LogicalPlan::Extension(_) + | LogicalPlan::Pivot(_) => None, } } @@ -2027,6 +2099,20 @@ impl LogicalPlan { expr_vec_fmt!(list_type_columns), expr_vec_fmt!(struct_type_columns)) } + LogicalPlan::Pivot(Pivot { + aggregate_expr, + pivot_column, + pivot_values, + .. 
+ }) => { + write!( + f, + "Pivot: {} FOR {} IN ({})", + aggregate_expr, + pivot_column, + pivot_values.iter().map(|v| v.to_string()).collect::>().join(", ") + ) + } } } } @@ -2220,6 +2306,155 @@ pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result, + /// Aggregate expression (e.g., SUM(amount)) + pub aggregate_expr: Expr, + /// Column whose values become new columns + pub pivot_column: Column, + /// List of pivot values (distinct values from pivot column) + pub pivot_values: Vec, + /// Output schema after pivot + pub schema: DFSchemaRef, + /// Optional subquery for pivot values + /// When provided, this will be executed during physical planning + /// to dynamically determine the pivot values + pub value_subquery: Option>, + /// Optional default value for replacing NULL values in the pivot result + pub default_on_null_expr: Option, +} + +impl PartialOrd for Pivot { + fn partial_cmp(&self, other: &Self) -> Option { + let self_tuple = ( + &self.input, + &self.aggregate_expr, + &self.pivot_column, + &self.pivot_values, + &self.value_subquery, + &self.default_on_null_expr, + ); + let other_tuple = ( + &other.input, + &other.aggregate_expr, + &other.pivot_column, + &other.pivot_values, + &other.value_subquery, + &other.default_on_null_expr, + ); + self_tuple.partial_cmp(&other_tuple) + } +} + +impl Pivot { + pub fn try_new( + input: Arc, + aggregate_expr: Expr, + pivot_column: Column, + pivot_values: Vec, + default_on_null_expr: Option, + ) -> Result { + let schema = pivot_schema( + input.schema(), + &aggregate_expr, + &pivot_column, + &pivot_values, + )?; + + Ok(Self { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema: Arc::new(schema), + value_subquery: None, + default_on_null_expr, + }) + } + + /// Create a new Pivot with a subquery for pivot values + pub fn try_new_with_subquery( + input: Arc, + aggregate_expr: Expr, + pivot_column: Column, + value_subquery: Arc, + default_on_null_expr: Option, + ) -> Result { + let schema = + pivot_schema_without_values(input.schema(), &aggregate_expr, &pivot_column)?; + + Ok(Self { + input, + aggregate_expr, + pivot_column, + pivot_values: Vec::new(), + schema: Arc::new(schema), + value_subquery: Some(value_subquery), + default_on_null_expr, + }) + } +} + +fn pivot_schema_without_values( + input_schema: &DFSchemaRef, + aggregate_expr: &Expr, + pivot_column: &Column, +) -> Result { + let mut fields = vec![]; + + // Include all fields except pivot and value columns + for field in input_schema.fields() { + if !aggregate_expr + .column_refs() + .iter() + .any(|col| col.name() == field.name()) + && field.name() != pivot_column.name() + { + fields.push(Arc::clone(field)); + } + } + + let fields_with_table_ref: Vec<(Option, Arc)> = + fields.into_iter().map(|field| (None, field)).collect(); + + DFSchema::new_with_metadata(fields_with_table_ref, input_schema.metadata().clone()) +} + +fn pivot_schema( + input_schema: &DFSchemaRef, + aggregate_expr: &Expr, + pivot_column: &Column, + pivot_values: &[ScalarValue], +) -> Result { + let mut fields = vec![]; + + for field in input_schema.fields() { + if !aggregate_expr + .column_refs() + .iter() + .any(|col| col.name() == field.name()) + && field.name() != pivot_column.name() + { + fields.push(Arc::clone(field)); + } + } + + for pivot_value in pivot_values { + let field_name = format!("{pivot_value}"); + let data_type = aggregate_expr.get_type(input_schema)?; + fields.push(Arc::new(Field::new(field_name, data_type, true))); + } + + let fields_with_table_ref: Vec<(Option, Arc)> = + 
fields.into_iter().map(|field| (None, field)).collect(); + + DFSchema::new_with_metadata(fields_with_table_ref, input_schema.metadata().clone()) +} + /// Aliased subquery #[derive(Debug, Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 47088370a1d9..c008a9a6e522 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -40,8 +40,8 @@ use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, - Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, - Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, + Limit, LogicalPlan, Partitioning, Pivot, Prepare, Projection, RecursiveQuery, + Repartition, Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, }; use datafusion_common::tree_node::TreeNodeRefContainer; @@ -322,6 +322,25 @@ impl TreeNode for LogicalPlan { options, }) }), + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + }) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + }) + }), LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term, @@ -461,6 +480,7 @@ impl LogicalPlan { } _ => Ok(TreeNodeRecursion::Continue), }, + LogicalPlan::Pivot(Pivot { aggregate_expr, .. }) => f(aggregate_expr), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) @@ -631,6 +651,25 @@ impl LogicalPlan { LogicalPlan::Limit(Limit { skip, fetch, input }) }) } + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + }) => f(aggregate_expr)?.update_data(|aggregate_expr| { + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + }) + }), LogicalPlan::Statement(stmt) => match stmt { Statement::Execute(e) => { e.parameters.map_elements(f)?.update_data(|parameters| { diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 669931d7bae7..51d15a28bf47 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -52,7 +52,7 @@ pub trait ContextProvider { fn get_table_function_source( &self, _name: &str, - _args: Vec, + _args: Vec<(Expr, Option)>, ) -> Result> { not_impl_err!("Table Functions are not supported") } diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index d3b253c0e102..dbd034ffa736 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -20,7 +20,7 @@ use crate::{Expr, LogicalPlan}; use arrow::datatypes::SchemaRef; -use datafusion_common::{Constraints, Result}; +use datafusion_common::{Constraints, Result, Statistics}; use std::{any::Any, borrow::Cow}; @@ -129,4 +129,12 @@ pub trait TableSource: Sync + Send { fn get_column_default(&self, _column: &str) -> Option<&Expr> { None } + + /// Get statistics for this table source, if available + /// Although not presently used in mainline DataFusion, this allows implementation specific + /// 
behavior for downstream repositories, in conjunction with specialized optimizer rules to + /// perform operations such as re-ordering of joins. + fn statistics(&self) -> Option { + None + } } diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs index ceedec2599a2..b65fb29197ea 100644 --- a/datafusion/ffi/src/udtf.rs +++ b/datafusion/ffi/src/udtf.rs @@ -100,7 +100,11 @@ unsafe extern "C" fn call_fn_wrapper( let args = rresult_return!(parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)); - let table_provider = rresult_return!(udtf.call(&args)); + let args_with_names = args + .into_iter() + .map(|expr| (expr, None)) + .collect::>(); + let table_provider = rresult_return!(udtf.call(&args_with_names)); RResult::ROk(FFI_TableProvider::new(table_provider, false, runtime)) } @@ -176,10 +180,13 @@ impl From for ForeignTableFunction { } impl TableFunctionImpl for ForeignTableFunction { - fn call(&self, args: &[Expr]) -> Result> { + fn call(&self, args: &[(Expr, Option)]) -> Result> { let codec = DefaultLogicalExtensionCodec {}; let expr_list = LogicalExprList { - expr: serialize_exprs(args, &codec)?, + expr: serialize_exprs( + args.iter().map(|(expr, _)| expr).collect::>(), + &codec, + )?, }; let filters_serialized = expr_list.encode_to_vec().into(); @@ -210,10 +217,13 @@ mod tests { struct TestUDTF {} impl TableFunctionImpl for TestUDTF { - fn call(&self, args: &[Expr]) -> Result> { + fn call( + &self, + args: &[(Expr, Option)], + ) -> Result> { let args = args .iter() - .map(|arg| { + .map(|(arg, _)| { if let Expr::Literal(scalar, _) = arg { Ok(scalar) } else { @@ -293,8 +303,12 @@ mod tests { let foreign_udf: ForeignTableFunction = local_udtf.into(); - let table = - foreign_udf.call(&vec![lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?; + let table = foreign_udf.call(&vec![ + (lit(6_u64), None), + (lit("one"), None), + (lit(2.0), None), + (lit(3_u64), None), + ])?; let ctx = SessionContext::default(); let _ = ctx.register_table("test-table", table)?; diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs index 82be31a15837..a36ee075ab90 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -490,13 +490,14 @@ struct GenerateSeriesFuncImpl { } impl TableFunctionImpl for GenerateSeriesFuncImpl { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call(&self, exprs: &[(Expr, Option)]) -> Result> { if exprs.is_empty() || exprs.len() > 3 { return plan_err!("{} function requires 1 to 3 arguments", self.name); } // Determine the data type from the first argument - match &exprs[0] { + + match &exprs[0].0 { Expr::Literal( // Default to int64 for null ScalarValue::Null | ScalarValue::Int64(_), @@ -520,10 +521,13 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { } impl GenerateSeriesFuncImpl { - fn call_int64(&self, exprs: &[Expr]) -> Result> { + fn call_int64( + &self, + exprs: &[(Expr, Option)], + ) -> Result> { let mut normalize_args = Vec::new(); for (expr_index, expr) in exprs.iter().enumerate() { - match expr { + match &expr.0 { Expr::Literal(ScalarValue::Null, _) => {} Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n), other => { @@ -583,7 +587,10 @@ impl GenerateSeriesFuncImpl { })) } - fn call_timestamp(&self, exprs: &[Expr]) -> Result> { + fn call_timestamp( + &self, + exprs: &[(Expr, Option)], + ) -> Result> { if exprs.len() != 3 { return plan_err!( "{} function with timestamps requires exactly 3 arguments", @@ -592,7 
+599,7 @@ impl GenerateSeriesFuncImpl { } // Parse start timestamp - let (start_ts, tz) = match &exprs[0] { + let (start_ts, tz) = match &exprs[0].0 { Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => { (*ts, tz.clone()) } @@ -605,7 +612,7 @@ impl GenerateSeriesFuncImpl { }; // Parse end timestamp - let end_ts = match &exprs[1] { + let end_ts = match &exprs[1].0 { Expr::Literal(ScalarValue::Null, _) => None, Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts, other => { @@ -617,7 +624,7 @@ impl GenerateSeriesFuncImpl { }; // Parse step interval - let step_interval = match &exprs[2] { + let step_interval = match &exprs[2].0 { Expr::Literal(ScalarValue::Null, _) => None, Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval, other => { @@ -659,7 +666,10 @@ impl GenerateSeriesFuncImpl { })) } - fn call_date(&self, exprs: &[Expr]) -> Result> { + fn call_date( + &self, + exprs: &[(Expr, Option)], + ) -> Result> { if exprs.len() != 3 { return plan_err!( "{} function with dates requires exactly 3 arguments", @@ -674,7 +684,7 @@ impl GenerateSeriesFuncImpl { )])); // Parse start date - let start_date = match &exprs[0] { + let start_date = match &exprs[0].0 { Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date, Expr::Literal(ScalarValue::Date32(None), _) | Expr::Literal(ScalarValue::Null, _) => { @@ -692,7 +702,7 @@ impl GenerateSeriesFuncImpl { }; // Parse end date - let end_date = match &exprs[1] { + let end_date = match &exprs[1].0 { Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date, Expr::Literal(ScalarValue::Date32(None), _) | Expr::Literal(ScalarValue::Null, _) => { @@ -710,7 +720,7 @@ impl GenerateSeriesFuncImpl { }; // Parse step interval - let step_interval = match &exprs[2] { + let step_interval = match &exprs[2].0 { Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => { *interval } @@ -756,7 +766,7 @@ impl GenerateSeriesFuncImpl { pub struct GenerateSeriesFunc {} impl TableFunctionImpl for GenerateSeriesFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call(&self, exprs: &[(Expr, Option)]) -> Result> { let impl_func = GenerateSeriesFuncImpl { name: "generate_series", include_end: true, @@ -769,7 +779,7 @@ impl TableFunctionImpl for GenerateSeriesFunc { pub struct RangeFunc {} impl TableFunctionImpl for RangeFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call(&self, exprs: &[(Expr, Option)]) -> Result> { let impl_func = GenerateSeriesFuncImpl { name: "range", include_end: false, diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 3341a5dbb52e..b22a57cf85df 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -143,9 +143,8 @@ impl ScalarUDFImpl for ToDateFunc { if args.len() > 1 { validate_data_types(&args, "to_date")?; } - match args[0].data_type() { - Int32 | Int64 | Null | Float64 | Date32 | Date64 => { + Int32 | Int64 | Null | Float64 | Date32 | Date64 | Timestamp(_, _) => { args[0].cast_to(&Date32, None) } Utf8View | LargeUtf8 | Utf8 => self.to_date(&args), @@ -424,12 +423,12 @@ mod tests { "to_date created wrong value for date with 2 format strings" ); } - _ => panic!("Conversion failed",), + _ => panic!("Conversion failed"), } } #[test] - fn test_to_date_from_timestamp() { + fn test_to_date_from_timestamp_str() { let test_cases = vec![ "2020-09-08T13:42:29Z", "2020-09-08T13:42:29.190855-05:00", @@ -453,6 +452,28 @@ mod tests { } } + #[test] + fn 
test_to_date_from_timestamp() { + let test_cases = vec![ + ScalarValue::TimestampSecond(Some(1736782134), None), + ScalarValue::TimestampMillisecond(Some(1736782134736), None), + ScalarValue::TimestampMicrosecond(Some(1736782134736782), None), + ScalarValue::TimestampNanosecond(Some(1736782134736782134), None), + ]; + for scalar in test_cases { + let timestamp_to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(scalar.clone())], 1); + + match timestamp_to_date_result { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { + let expected = Date32Type::parse_formatted("2025-01-13", "%Y-%m-%d"); + assert_eq!(date_val, expected, "to_date created wrong value"); + } + _ => panic!("Conversion of {scalar}"), + } + } + } + #[test] fn test_to_date_string_with_valid_number() { let date_str = "20241231"; diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index da4e23f91de7..80fc9aebbc44 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -27,6 +27,7 @@ pub mod regexpinstr; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; +pub mod regexpsubstr; // create UDFs make_udf_function!(regexpcount::RegexpCountFunc, regexp_count); @@ -34,6 +35,7 @@ make_udf_function!(regexpinstr::RegexpInstrFunc, regexp_instr); make_udf_function!(regexpmatch::RegexpMatchFunc, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, regexp_like); make_udf_function!(regexpreplace::RegexpReplaceFunc, regexp_replace); +make_udf_function!(regexpsubstr::RegexpSubstrFunc, regexp_substr); pub mod expr_fn { use datafusion_expr::Expr; @@ -93,7 +95,34 @@ pub mod expr_fn { }; super::regexp_instr().call(args) } + /// Returns true if a regex has at least one match in a string, false otherwise. + /// Returns the substring that matches a regular expression within a string. + pub fn regexp_substr( + values: Expr, + regex: Expr, + start: Option, + occurrence: Option, + flags: Option, + group_num: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + if let Some(occurrence) = occurrence { + args.push(occurrence); + }; + if let Some(flags) = flags { + args.push(flags); + }; + if let Some(group_num) = group_num { + args.push(group_num); + }; + super::regexp_substr().call(args) + } + + /// Returns true if a has at least one match in a string, false otherwise. 
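+    /// Illustrative usage (argument values are arbitrary):
+    /// `regexp_like(col("a"), lit("^foo"), Some(lit("i")))`.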
pub fn regexp_like(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; if let Some(flags) = flags { @@ -125,6 +154,7 @@ pub fn functions() -> Vec> { regexp_instr(), regexp_like(), regexp_replace(), + regexp_substr(), ] } diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 0554844d11c1..a6c8047513ad 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -70,6 +70,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo #[derive(Debug, PartialEq, Eq, Hash)] pub struct RegexpLikeFunc { signature: Signature, + aliases: Vec, } impl Default for RegexpLikeFunc { @@ -95,6 +96,7 @@ impl RegexpLikeFunc { ], Volatility::Immutable, ), + aliases: vec![String::from("rlike")], } } } @@ -123,6 +125,10 @@ impl ScalarUDFImpl for RegexpLikeFunc { }) } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn invoke_with_args( &self, args: datafusion_expr::ScalarFunctionArgs, diff --git a/datafusion/functions/src/regex/regexpsubstr.rs b/datafusion/functions/src/regex/regexpsubstr.rs new file mode 100644 index 000000000000..c12e5dcd9689 --- /dev/null +++ b/datafusion/functions/src/regex/regexpsubstr.rs @@ -0,0 +1,613 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Regex expressions +use arrow::array::{ + Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, +}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::error::ArrowError; +use datafusion_common::plan_err; +use datafusion_common::ScalarValue; +use datafusion_common::{ + cast::as_generic_string_array, internal_err, DataFusionError, Result, +}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs, TypeSignature}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use regex::Regex; +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct RegexpSubstrFunc { + signature: Signature, +} + +impl Default for RegexpSubstrFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpSubstrFunc { + pub fn new() -> Self { + use DataType::{Int64, LargeUtf8, Utf8}; + Self { + signature: Signature::one_of( + vec![ + // Planner attempts coercion to the target type starting with the most preferred candidate. + // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`. + // If that fails, it proceeds to `(LargeUtf8, Utf8)`. 
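+                    // The signatures below cover 2 to 6 arguments
+                    // (string, pattern, position, occurrence, flags, group_num)
+                    // for both Utf8 and LargeUtf8 inputs.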
+ TypeSignature::Exact(vec![Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]), + TypeSignature::Exact(vec![ + LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, + ]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]), + TypeSignature::Exact(vec![ + LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64, + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpSubstrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regexp_substr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let len = args + .args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .args + .iter() + .map(|arg| arg.to_array(inferred_length)) + .collect::>>()?; + + let result = regexp_subst_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_substr_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_substr_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder( + DOC_SECTION_REGEX, + "Returns the substring that matches a [regular expression](https://docs.rs/regex/latest/regex/#syntax) within a string.", + "regexp_substr(str, regexp[, position[, occurrence[, flags[, group_num]]]])") + .with_sql_example(r#"```sql + > select regexp_substr('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_substr(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | Köln | + +---------------------------------------------------------+ + SELECT regexp_substr('aBc', '(b|d)', 1, 1, 'i'); + +---------------------------------------------------+ + | regexp_substr(Utf8("aBc"),Utf8("(b|d)"), Int32(1), Int32(1), Utf8("i")) | + +---------------------------------------------------+ + | B | + +---------------------------------------------------+ +``` +Additional examples can be found [here](https://docs.snowflake.com/en/sql-reference/functions/regexp_substr#examples) +"#) + .with_standard_argument("str", Some("String")) + .with_argument("regexp", "Regular expression to match against. + Can be a constant, column, or function.") + .with_argument("position", "Number of characters from the beginning of the string where the function starts searching for matches. Default: 1") + .with_argument("occurrence", "Specifies the first occurrence of the pattern from which to start returning matches.. Default: 1") + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. 
The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **c**: case-sensitive: letters match upper or lower case. Default flag + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **e**: extract submatches (for Snowflake compatibility) + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .with_argument("group_num", "Specifies which group to extract. Groups are specified by using parentheses in the regular expression.") + .build() + }) +} + +fn regexp_subst_func(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => regexp_substr::(args), + DataType::LargeUtf8 => regexp_substr::(args), + other => { + internal_err!("Unsupported data type {other:?} for function regexp_substr") + } + } +} +pub fn regexp_substr(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + let get_int_arg = |index: usize, name: &str| -> Result> { + if args_len > index { + let arg = args[index].as_primitive::(); + if arg.is_empty() { + return plan_err!( + "regexp_substr() requires the {:?} argument to be an integer", + name + ); + } + Ok(Some(arg.value(0))) + } else { + Ok(None) + } + }; + + let values = as_generic_string_array::(&args[0])?; + let regex = Some(as_generic_string_array::(&args[1])?.value(0)); + let start = get_int_arg(2, "position")?; + let occurrence = get_int_arg(3, "occurrence")?; + let flags = if args_len > 4 { + let flags = args[4].as_string::(); + if flags.iter().any(|s| s == Some("g")) { + return plan_err!("regexp_substr() does not support the \"global\" option"); + } + Some(flags.value(0)) + } else { + None + }; + + let group_num = get_int_arg(5, "group_num")?; + + let result = + regexp_substr_inner::(values, regex, start, occurrence, flags, group_num)?; + Ok(Arc::new(result)) +} + +fn regexp_substr_inner( + values: &GenericStringArray, + regex: Option<&str>, + start: Option, + occurrence: Option, + flags: Option<&str>, + group_num: Option, +) -> Result { + let regex = match regex { + None | Some("") => { + return Ok(Arc::new(GenericStringArray::::new_null(values.len()))) + } + Some(regex) => regex, + }; + + // Check for 'e' flag and set group_num to 1 if not provided + let group_num = if flags.is_some_and(|f| f.contains('e')) { + group_num.or(Some(1)) + } else { + group_num + }; + + let regex = compile_regex(regex, flags)?; + let mut builder = GenericStringBuilder::::new(); + + values.iter().try_for_each(|value| { + match value { + Some(value) => { + // Skip characters from the beginning + let cleaned_value = if let Some(start) = start { + if start < 1 { + return Err(DataFusionError::from(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + ))); + } + value.chars().skip(start as usize - 1).collect() + } else { + value.to_string() + }; + + let matches = + get_matches(cleaned_value.as_str(), ®ex, occurrence, group_num); + if matches.is_empty() { + builder.append_null(); + } else { + // Return only first substring that matches the pattern + if let Some(first_match) = matches.first() { + builder.append_value(first_match); + } + } + } + _ => builder.append_null(), + } + Ok(()) + })?; + Ok(Arc::new(builder.finish())) +} + +fn get_matches( + value: &str, + regex: &Regex, + occurrence: Option, + group_num: Option, +) -> Vec { + let mut matches = Vec::new(); + let occurrence = occurrence.unwrap_or(1) as usize; + + for caps in regex.captures_iter(value) { + match group_num { + 
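+            // With an explicit group number, collect only that capture group;
+            // otherwise collect each submatch, skipping the implicit whole-match
+            // entry when the pattern defines capture groups.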
Some(group_num) => { + if let Some(m) = caps.get(group_num as usize) { + matches.push(m.as_str().to_string()); + } + } + None => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + matches.push(m.as_str().to_string()); + } + } + } + } + + if matches.len() > occurrence { + matches = matches.split_off(occurrence - 1); + } + matches +} +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_substr() does not support global flag".to_string(), + )); + } + // Case-sensitive enabled by default + let flags = flags.replace("c", "").replace("e", ""); + if flags.is_empty() { + regex.to_string() + } else { + format!("(?{flags}){regex}") + } + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError( + format!("Regular expression did not compile: {pattern}",), + ) + }) +} + +#[cfg(test)] +mod tests { + use crate::regex::regexpsubstr::{regexp_substr, RegexpSubstrFunc}; + use arrow::array::{Array, ArrayRef, Int64Array, LargeStringArray, StringArray}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use std::sync::Arc; + + #[test] + fn test_regexp_substr() { + let values = [ + "Hellooo Woorld", + "How are you doing today floor?", + "the quick brown fox jumps over the lazy dog door", + "PACK MY BOX WITH FIVE DOZEN LIQUOR JUGS", + ]; + let regex = ["\\b\\S*o\\S*\\b", "(..or)"]; + let expected = [ + ["Hellooo", "How", "brown", ""], + ["Woor", "loor", "door", ""], + ]; + + // Scalar + values.iter().enumerate().for_each(|(pos, &value)| { + regex.iter().enumerate().for_each(|(rpos, regex)| { + let expected = expected.get(rpos).unwrap().get(pos).unwrap().to_string(); + + // Utf8, LargeUtf8 + for (data_type, scalar) in &[ + ( + DataType::Utf8, + ScalarValue::Utf8 as fn(Option) -> ScalarValue, + ), + ( + DataType::LargeUtf8, + ScalarValue::LargeUtf8 as fn(Option) -> ScalarValue, + ), + ] { + let args_vec = vec![ + ColumnarValue::Scalar(scalar(Some(value.to_string()))), + ColumnarValue::Scalar(scalar(Some(regex.to_string()))), + ]; + let arg_fields = args_vec + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("f_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let result = + RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_vec, + arg_fields, + number_rows: 1, + return_field: Arc::new(Field::new( + "f", + data_type.clone(), + true, + )), + config_options: Arc::new( + datafusion_common::config::ConfigOptions::default(), + ), + }); + match result { + Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(ref res) | ScalarValue::LargeUtf8(ref res), + )) => { + if res.is_some() { + assert_eq!( + res.as_ref().unwrap(), + &expected.to_string(), + "regexp_substr scalar test failed" + ); + } else { + assert_eq!( + "", expected, + "regexp_substr scalar utf8 test failed" + ) + } + } + _ => panic!("Unexpected result"), + } + } + }); + }); + + // Array (column) + regex.iter().enumerate().for_each(|(rpos, regex)| { + // Utf8, LargeUtf8 + for data_type in &[DataType::Utf8, DataType::LargeUtf8] { + let (array_values, regex) = match data_type { + DataType::Utf8 => ( + Arc::new(StringArray::from( + values.iter().map(|v| v.to_string()).collect::>(), + )) as ArrayRef, + 
ScalarValue::Utf8(Some(regex.to_string())), + ), + DataType::LargeUtf8 => ( + Arc::new(LargeStringArray::from( + values.iter().map(|v| v.to_string()).collect::>(), + )) as ArrayRef, + ScalarValue::LargeUtf8(Some(regex.to_string())), + ), + _ => unreachable!(), + }; + let args_vec = vec![ + ColumnarValue::Array(Arc::new(array_values)), + ColumnarValue::Scalar(regex), + ]; + let arg_fields = args_vec + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("f_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let result = + RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_vec, + arg_fields, + number_rows: 1, + return_field: Arc::new(Field::new("f", data_type.clone(), true)), + config_options: Arc::new( + datafusion_common::config::ConfigOptions::default(), + ), + }); + match result { + Ok(ColumnarValue::Array(array)) => { + let expected = expected + .get(rpos) + .unwrap() + .iter() + .map(|v| { + if v.is_empty() { + return None; + } + Some(v.to_string()) + }) + .collect::>>(); + + assert_eq!(array.data_type(), data_type, "wrong array datatype"); + match data_type { + DataType::Utf8 => { + let array = + array.as_any().downcast_ref::().unwrap(); + let expected = StringArray::from(expected); + assert_eq!( + array, &expected, + "regexp_substr array Utf8 test failed" + ); + } + DataType::LargeUtf8 => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + let expected = LargeStringArray::from(expected); + assert_eq!( + array, &expected, + "regexp_substr array LargeUtf8 test failed" + ); + } + _ => unreachable!(), + }; + } + _ => panic!("Unexpected result"), + } + } + }); + } + + #[test] + fn test_regexp_substr_with_params() { + let values = [ + "", + "aabc aabca vff ddf", + "abc abca abcD vff", + "Abcab abcD caddd", + "abCab cabcd dasaaabc VfFddd", + "ab dasacabd caBcv dasaaabcdv", + ]; + let regex = ["abc", "(abc\\S)|(bca)", "(abc)|(bca)", "(abc)|(vff)|(d)"]; + let flags = ["i", "ie", "e", "i"]; + let group_num = [0, 1, 0, 2]; + let expected = [ + ["", "abc", "abc", "Abc", "abC", "aBc"], + ["", "abca", "abca", "Abca", "abCa", "aBcv"], + ["", "abc", "abc", "bca", "abc", "abc"], + ["", "vff", "vff", "", "VfF", ""], + ]; + + // Scalar + regex.iter().enumerate().for_each(|(spos, ®ex)| { + values.iter().enumerate().for_each(|(pos, &value)| { + let expected = expected.get(spos).unwrap().get(pos).cloned().unwrap(); + // Utf8, LargeUtf8 + for (data_type, scalar) in &[ + ( + DataType::Utf8, + ScalarValue::Utf8 as fn(Option) -> ScalarValue, + ), + ( + DataType::LargeUtf8, + ScalarValue::LargeUtf8 as fn(Option) -> ScalarValue, + ), + ] { + let args_vec = vec![ + ColumnarValue::Scalar(scalar(Some(value.to_string()))), + ColumnarValue::Scalar(scalar(Some(regex.to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(scalar(Some(flags[spos].to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(group_num[spos]))), + ]; + let arg_fields = args_vec + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("f_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let result = + RegexpSubstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_vec, + arg_fields, + number_rows: 1, + return_field: Arc::new(Field::new( + "f", + data_type.clone(), + true, + )), + config_options: Arc::new( + datafusion_common::config::ConfigOptions::default(), + ), + }); + match result { + Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(ref res) | 
ScalarValue::LargeUtf8(ref res), + )) => { + if res.is_some() { + assert_eq!( + res.as_ref().unwrap(), + &expected.to_string(), + "regexp_substr scalar test failed" + ); + } else { + assert_eq!( + "", expected, + "regexp_substr scalar utf8 test failed" + ) + } + } + _ => panic!("Unexpected result"), + } + } + }) + }); + } + + #[test] + fn test_unsupported_global_flag_regexp_substr() { + let values = StringArray::from(vec!["abc"]); + let patterns = StringArray::from(vec!["^(a)"]); + let position = Int64Array::from(vec![1]); + let occurrence = Int64Array::from(vec![1]); + let flags = StringArray::from(vec!["g"]); + + let re_err = regexp_substr::(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(position), + Arc::new(occurrence), + Arc::new(flags), + ]) + .expect_err("unsupported flag should have failed"); + + assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_substr() does not support the \"global\" option"); + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e26805378141..13d726d58cb9 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1866,11 +1866,11 @@ mod test { " )?; - let empty = empty_with_type(DataType::Int64); + let empty = empty_with_type(DataType::Float64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); assert_type_coercion_error( plan, - "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean" + "Cannot infer common argument type for comparison operation Float64 IS DISTINCT FROM Boolean" )?; // is not true diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d78819c7c315..af4d3eab5313 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -220,12 +220,15 @@ impl CommonSubexprEliminate { .into_iter() .zip(window_schemas) .try_rfold(new_input, |plan, (new_window_expr, schema)| { - Window::try_new_with_schema( - new_window_expr, - Arc::new(plan), + match Window::try_new_with_schema( + new_window_expr.clone(), + Arc::new(plan.clone()), schema, - ) - .map(LogicalPlan::Window) + ) { + Ok(win) => Ok(LogicalPlan::Window(win)), + Err(_) => Window::try_new(new_window_expr, Arc::new(plan)) + .map(LogicalPlan::Window), + } }) } }) @@ -564,7 +567,8 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) => { + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Pivot(_) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like // manner. plan.map_children(|c| self.rewrite(c, config))? @@ -795,12 +799,14 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::logical_plan::{table_scan, JoinType}; + use datafusion_expr::window_frame::WindowFrame; use datafusion_expr::{ grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility, }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_functions_window::row_number::row_number_udwf; use super::*; use crate::assert_optimized_plan_eq_snapshot; @@ -1764,6 +1770,59 @@ mod test { ) } + #[test] + fn test_window_cse_rebuild_preserves_schema() { + // Build a plan similar to SELECT ... 
QUALIFY ROW_NUMBER() + let scan = test_table_scan().unwrap(); + let col0 = col("a"); + let col1 = col("b"); + + let wnd = Expr::WindowFunction(Box::new(datafusion_expr::expr::WindowFunction { + fun: datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( + row_number_udwf(), + ), + params: datafusion_expr::expr::WindowFunctionParams { + partition_by: vec![col0.clone()], + order_by: vec![col1.clone().sort(true, false)], + window_frame: WindowFrame::new(None), + args: vec![], + null_treatment: None, + distinct: false, + filter: None, + }, + })); + + let windowed = LogicalPlanBuilder::from(scan) + .window(vec![wnd.clone()]) + .unwrap() + .project(vec![col0.clone(), col1.clone(), wnd.clone()]) + .unwrap() + .build() + .unwrap(); + + // Simulate QUALIFY as a filter on the window output + let filtered = LogicalPlanBuilder::from(windowed) + .filter(Expr::BinaryExpr(BinaryExpr { + left: Box::new(wnd), + op: Operator::Eq, + right: Box::new(Expr::Literal( + datafusion_common::ScalarValue::UInt64(Some(1)), + None, + )), + })) + .unwrap() + .project(vec![col("a"), col("b")]) + .unwrap() + .build() + .unwrap(); + + let rule = CommonSubexprEliminate::new(); + let cfg = OptimizerContext::new(); + let res = rule.rewrite(filtered, &cfg).unwrap(); + + assert_fields_eq(&res.data, vec!["a", "b"]); + } + /// returns a "random" function that is marked volatile (aka each invocation /// returns a different value) /// diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index a72657bf689d..fffaeb0c8ce8 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -30,7 +30,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, plan_err, Column, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; -use datafusion_expr::logical_plan::{JoinType, Subquery}; +use datafusion_expr::logical_plan::{ + Join as LogicalJoin, JoinType, Projection, Subquery, +}; use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, @@ -66,54 +68,166 @@ impl OptimizerRule for DecorrelatePredicateSubquery { })? 
.data; - let LogicalPlan::Filter(filter) = plan else { - return Ok(Transformed::no(plan)); - }; + // Handle Filters first (existing behavior) + if let LogicalPlan::Filter(filter) = plan.clone() { + if !has_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } - if !has_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); - } + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); - let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = - split_conjunction_owned(filter.predicate) - .into_iter() - .partition(has_subquery); + if with_subqueries.is_empty() { + return internal_err!( + "can not find expected subqueries in DecorrelatePredicateSubquery" + ); + } - if with_subqueries.is_empty() { - return internal_err!( - "can not find expected subqueries in DecorrelatePredicateSubquery" - ); + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top( + &subquery, + &cur_input, + config.alias_generator(), + )? { + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), + } + } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } + } + } + + let expr = conjunction(other_exprs); + if let Some(expr) = expr { + let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; + return Ok(Transformed::yes(LogicalPlan::Filter(new_filter))); + } + return Ok(Transformed::yes(cur_input)); } - // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(filter.input); - for subquery_expr in with_subqueries { - match extract_subquery_info(subquery_expr) { - // The subquery expression is at the top level of the filter - SubqueryPredicate::Top(subquery) => { - match build_join_top(&subquery, &cur_input, config.alias_generator())? 
- { - Some(plan) => cur_input = plan, - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - None => other_exprs.push(subquery.expr()), + // Additionally handle subqueries embedded in Join.filter expressions + if let LogicalPlan::Join(join) = plan { + if let Some(predicate) = &join.filter { + if has_subquery(predicate) { + let (new_left, new_predicate) = rewrite_inner_subqueries( + Arc::unwrap_or_clone(join.left), + predicate.clone(), + config, + )?; + + let new_join = LogicalJoin::try_new( + Arc::new(new_left), + Arc::clone(&join.right), + join.on.clone(), + Some(new_predicate), + join.join_type, + join.join_constraint, + join.null_equality, + )?; + return Ok(Transformed::yes(LogicalPlan::Join(new_join))); + } + } + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + // Handle subqueries embedded in Aggregate group/aggregate expressions + if let LogicalPlan::Aggregate(aggregate) = plan { + let mut needs_rewrite = false; + for e in &aggregate.group_expr { + if has_subquery(e) { + needs_rewrite = true; + break; + } + } + if !needs_rewrite { + for e in &aggregate.aggr_expr { + if has_subquery(e) { + needs_rewrite = true; + break; } } - // The subquery expression is embedded within another expression - SubqueryPredicate::Embedded(expr) => { - let (plan, expr_without_subqueries) = + } + if !needs_rewrite { + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + let mut cur_input = Arc::unwrap_or_clone(aggregate.input); + let mut new_group_exprs = Vec::with_capacity(aggregate.group_expr.len()); + for expr in aggregate.group_expr { + if has_subquery(&expr) { + let (next_input, rewritten_expr) = rewrite_inner_subqueries(cur_input, expr, config)?; - cur_input = plan; - other_exprs.push(expr_without_subqueries); + cur_input = next_input; + new_group_exprs.push(rewritten_expr); + } else { + new_group_exprs.push(expr); } } + let mut new_aggr_exprs = Vec::with_capacity(aggregate.aggr_expr.len()); + for expr in aggregate.aggr_expr { + if has_subquery(&expr) { + let old_name = expr.schema_name().to_string(); + let (next_input, rewritten_expr) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = next_input; + let new_name = rewritten_expr.schema_name().to_string(); + if new_name != old_name { + new_aggr_exprs.push(rewritten_expr.alias(old_name)); + } else { + new_aggr_exprs.push(rewritten_expr); + } + } else { + new_aggr_exprs.push(expr); + } + } + + let new_plan = LogicalPlanBuilder::from(cur_input) + .aggregate(new_group_exprs, new_aggr_exprs)? 
+ .build()?; + return Ok(Transformed::yes(new_plan)); } - let expr = conjunction(other_exprs); - if let Some(expr) = expr { - let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; - cur_input = LogicalPlan::Filter(new_filter); + // Handle Projection nodes with subqueries in expressions + if let LogicalPlan::Projection(proj) = plan { + // Only proceed if any projection expression contains a subquery + if !proj.expr.iter().any(has_subquery) { + return Ok(Transformed::no(LogicalPlan::Projection(proj))); + } + + let mut cur_input = Arc::unwrap_or_clone(proj.input); + let mut new_exprs = Vec::with_capacity(proj.expr.len()); + for e in proj.expr { + let old_name = e.schema_name().to_string(); + let (plan_after, rewritten) = + rewrite_inner_subqueries(cur_input, e, config)?; + cur_input = plan_after; + let new_name = rewritten.schema_name().to_string(); + if new_name != old_name { + new_exprs.push(rewritten.alias(old_name)); + } else { + new_exprs.push(rewritten); + } + } + let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?; + return Ok(Transformed::yes(LogicalPlan::Projection(new_proj))); } - Ok(Transformed::yes(cur_input)) + + // Other plans unchanged + Ok(Transformed::no(plan)) } fn name(&self) -> &str { @@ -455,6 +569,45 @@ mod tests { )) } + /// Aggregation with CASE WHEN ... IN (subquery) should be decorrelated under the Aggregate + #[test] + fn aggregate_case_in_subquery() -> Result<()> { + let table_scan = test_table_scan_with_name("distinct_source")?; + use datafusion_expr::expr_fn::when; + use datafusion_functions_aggregate::expr_fn::max as agg_max; + + let agg_b: Expr = agg_max(col("distinct_source.b")); + let subq = LogicalPlanBuilder::from(table_scan.clone()) + .aggregate(Vec::::new(), vec![agg_b])? + .project(vec![col("max(distinct_source.b)")])? + .build()?; + + let case_expr = when( + in_subquery(col("distinct_source.b"), Arc::new(subq)), + lit(1), + ) + .otherwise(lit(0))?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("distinct_source.a").alias("primary_key")], + vec![ + agg_max(case_expr).alias("is_in_most_recent_task"), + agg_max(col("distinct_source.c")).alias("max_timestamp"), + ], + )? + .build()?; + + use crate::{OptimizerContext, OptimizerRule}; + let optimized = DecorrelatePredicateSubquery::new() + .rewrite(plan, &OptimizerContext::new())? + .data; + let lp = optimized.display_indent().to_string(); + assert!(lp.contains("Aggregate:")); + assert!(lp.contains("Left")); + Ok(()) + } + /// Test for several IN subquery expressions #[test] fn in_subquery_multiple() -> Result<()> { @@ -545,6 +698,40 @@ mod tests { ) } + /// Projection IN (subquery) should be decorrelated via LeftMark join in Projection + #[test] + fn projection_in_subquery_simple() -> Result<()> { + // Build outer values t(a) = (1),(2) + let outer = LogicalPlanBuilder::values(vec![vec![lit(1_i32)], vec![lit(2_i32)]])? + .project(vec![col("column1").alias("a")])? + .build()?; + + // Build subquery u(a) = (2) + let sub = Arc::new( + LogicalPlanBuilder::values(vec![vec![lit(2_i32)]])? + .project(vec![col("column1").alias("ua")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer) + .project(vec![col("a"), in_subquery(col("a"), sub).alias("flag")])? 
+ .build()?; + + // We expect a LeftMark join inserted and the projection keeps columns + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, __correlated_sq_1.mark AS flag [a:Int32;N, flag:Boolean] + LeftMark Join: Filter: a = __correlated_sq_1.ua [a:Int32;N, mark:Boolean] + Projection: column1 AS a [a:Int32;N] + Values: (Int32(1)), (Int32(2)) [column1:Int32;N] + SubqueryAlias: __correlated_sq_1 [ua:Int32;N] + Projection: column1 AS ua [ua:Int32;N] + Values: (Int32(2)) [column1:Int32;N] + " + ) + } + /// Test multiple correlated subqueries /// See subqueries.rs where_in_multiple() #[test] diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 280010e3d92c..19eda3beabbd 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -58,6 +58,7 @@ pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; +pub mod reorder_join; pub mod replace_distinct_aggregate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index d6e3f6051f34..80355e5267de 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -404,6 +404,9 @@ fn optimize_projections( }); vec![required_indices.append(&additional_necessary_child_indices)] } + LogicalPlan::Pivot(_) => { + return Ok(Transformed::no(plan)); + } }; // Required indices are currently ordered (child0, child1, ...) diff --git a/datafusion/optimizer/src/reorder_join/cost.rs b/datafusion/optimizer/src/reorder_join/cost.rs new file mode 100644 index 000000000000..e346c8bc6441 --- /dev/null +++ b/datafusion/optimizer/src/reorder_join/cost.rs @@ -0,0 +1,59 @@ +use datafusion_common::{plan_datafusion_err, plan_err, stats::Precision, Result}; +use datafusion_expr::{Join, JoinType, LogicalPlan}; + +pub trait JoinCostEstimator: std::fmt::Debug { + fn cardinality(&self, plan: &LogicalPlan) -> Option { + estimate_cardinality(plan).ok() + } + + fn selectivity(&self, join: &Join) -> f64 { + match join.join_type { + JoinType::Inner => 0.1, + _ => 1.0, + } + } + + fn cost(&self, selectivity: f64, cardinality: f64) -> f64 { + selectivity * cardinality + } +} + +/// Default implementation of JoinCostEstimator +#[derive(Debug, Clone, Copy)] +pub struct DefaultCostEstimator; + +impl JoinCostEstimator for DefaultCostEstimator {} + +fn estimate_cardinality(plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Filter(filter) => { + let input_cardinality = estimate_cardinality(&filter.input)?; + Ok(0.1 * input_cardinality) + } + LogicalPlan::Aggregate(agg) => { + let input_cardinality = estimate_cardinality(&agg.input)?; + Ok(0.1 * input_cardinality) + } + LogicalPlan::TableScan(scan) => { + let statistics = scan + .source + .statistics() + .ok_or_else(|| plan_datafusion_err!("Table statistics not available"))?; + if let Precision::Exact(num_rows) | Precision::Inexact(num_rows) = + statistics.num_rows + { + Ok(num_rows as f64) + } else { + plan_err!("Number of rows not available") + } + } + x => { + let inputs = x.inputs(); + if inputs.len() == 1 { + estimate_cardinality(inputs[0]) + } else { + plan_err!("Cannot estimate cardinality for plan with multiple inputs") + } + } + } +} diff --git a/datafusion/optimizer/src/reorder_join/left_deep_join_plan.rs b/datafusion/optimizer/src/reorder_join/left_deep_join_plan.rs new file mode 100644 index 000000000000..1ec187ac1d76 
--- /dev/null +++ b/datafusion/optimizer/src/reorder_join/left_deep_join_plan.rs @@ -0,0 +1,882 @@ +use std::{collections::HashSet, fmt::Debug, rc::Rc, sync::Arc}; + +use datafusion_common::{plan_datafusion_err, plan_err, Result}; +use datafusion_expr::LogicalPlan; + +use crate::reorder_join::{ + cost::JoinCostEstimator, + query_graph::{NodeId, QueryGraph}, +}; + +/// Generates an optimized left-deep join plan from a logical plan using the Ibaraki-Kameda algorithm. +/// +/// This function is the main entry point for join reordering optimization. It takes a logical plan +/// that may contain joins along with wrapper operators (filters, sorts, aggregations, etc.) and +/// produces an optimized plan with reordered joins while preserving the wrapper operators. +/// +/// # Algorithm Overview +/// +/// The optimization process consists of several steps: +/// +/// 1. **Extraction**: Separates the join subtree from wrapper operators (filters, sorts, limits, etc.) +/// 2. **Graph Conversion**: Converts the join subtree into a query graph representation where: +/// - Nodes represent base relations (table scans, subqueries, etc.) +/// - Edges represent join conditions between relations +/// 3. **Optimization**: Uses the Ibaraki-Kameda algorithm to find the optimal left-deep join ordering +/// by trying each node as a potential root and selecting the plan with the lowest estimated cost +/// 4. **Reconstruction**: Rebuilds the complete logical plan by applying the wrapper operators +/// to the optimized join plan +/// +/// # Left-Deep Join Plans +/// +/// A left-deep join plan is a join tree where: +/// - Each join has a relation or previous join result on the left side +/// - Each join has a single relation on the right side +/// - This creates a linear "chain" of joins processed left-to-right +/// +/// Example: `((A ⋈ B) ⋈ C) ⋈ D` is left-deep, while `(A ⋈ B) ⋈ (C ⋈ D)` is not. +/// +/// Left-deep plans are preferred because they: +/// - Allow pipelining of intermediate results +/// - Work well with hash join implementations +/// - Have predictable memory usage patterns +/// +/// # Arguments +/// +/// * `plan` - The logical plan to optimize. Must contain at least one join node. +/// * `cost_estimator` - Cost estimator for calculating join costs, cardinality, and selectivity. +/// Used to compare different join orderings and select the optimal one. +/// +/// # Returns +/// +/// Returns a `LogicalPlan` with optimized join ordering. The plan structure is: +/// - Wrapper operators (filters, sorts, etc.) 
in their original positions +/// - Joins reordered to minimize estimated execution cost +/// - Join semantics preserved (same result set as input plan) +/// +/// # Errors +/// +/// Returns an error if: +/// - The plan does not contain any join nodes +/// - Join extraction fails (e.g., joins are not consecutive in the plan tree) +/// - The query graph cannot be constructed from the join subtree +/// - Join reordering optimization fails (no valid join ordering found) +/// - Plan reconstruction fails +/// +/// # Example +/// +/// ```ignore +/// use datafusion_optimizer::reorder_join::{optimal_left_deep_join_plan, cost::JoinCostEstimator}; +/// use std::rc::Rc; +/// +/// // Assume we have a plan with joins: customer ⋈ orders ⋈ lineitem +/// let plan = ...; // Your logical plan +/// let cost_estimator: Rc = Rc::new(MyCostEstimator::new()); +/// +/// // Optimize join ordering +/// let optimized = optimal_left_deep_join_plan(plan, cost_estimator)?; +/// // Result might reorder to: lineitem ⋈ orders ⋈ customer (if this is cheaper) +/// ``` +pub fn optimal_left_deep_join_plan( + plan: LogicalPlan, + cost_estimator: Rc, +) -> Result { + // Extract the join subtree and wrappers + let (join_subtree, wrappers) = + crate::reorder_join::query_graph::extract_join_subtree(plan)?; + + // Convert join subtree to query graph + let query_graph = QueryGraph::try_from(join_subtree)?; + + // Optimize the joins + let optimized_joins = + query_graph_to_optimal_left_deep_join_plan(query_graph, cost_estimator)?; + + // Reconstruct the full plan with wrappers + + crate::reorder_join::query_graph::reconstruct_plan(optimized_joins, wrappers) +} + +/// Generates an optimized linear join plan from a query graph using the Ibaraki-Kameda algorithm. +/// +/// This function finds the optimal join ordering for a query by: +/// 1. Trying each node in the query graph as a potential root +/// 2. For each root, building a precedence tree and optimizing it through normalization/denormalization +/// 3. Selecting the plan with the lowest estimated cost +/// +/// The optimization process uses the Ibaraki-Kameda algorithm, which arranges joins to minimize +/// intermediate result sizes by considering both cardinality and cost estimates. +/// +/// # Algorithm Steps +/// +/// For each candidate root node: +/// 1. **Construction**: Build a precedence tree from the query graph starting at that node +/// 2. **Normalization**: Transform the tree into a chain structure ordered by rank +/// 3. **Denormalization**: Split merged operations back into individual nodes while maintaining chain structure +/// 4. **Cost Comparison**: Compare the resulting plan's cost against the current best +/// +/// # Arguments +/// +/// * `query_graph` - The query graph containing logical plan nodes and join specifications +/// * `cost_estimator` - The cost estimator to use for calculating cardinality, selectivity, and cost +/// +/// # Returns +/// +/// Returns a `LogicalPlan` representing the optimal join ordering with the lowest estimated cost. 
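For context on the cost comparison described above, here is a minimal sketch of a user-supplied estimator built on the `JoinCostEstimator` trait added in `reorder_join/cost.rs`. The struct name and the 0.05 selectivity are illustrative assumptions, not part of this patch; `cardinality` and `cost` keep the trait's default implementations (statistics-based row counts and `selectivity * cardinality`).

```rust
use std::rc::Rc;

use datafusion_common::Result;
use datafusion_expr::{Join, JoinType, LogicalPlan};
use datafusion_optimizer::reorder_join::cost::JoinCostEstimator;
use datafusion_optimizer::reorder_join::left_deep_join_plan::optimal_left_deep_join_plan;

// Hypothetical estimator that treats inner equi-joins as more selective than the
// trait's 0.1 default; all other trait methods are inherited.
#[derive(Debug)]
struct PessimisticEstimator;

impl JoinCostEstimator for PessimisticEstimator {
    fn selectivity(&self, join: &Join) -> f64 {
        match join.join_type {
            JoinType::Inner => 0.05,
            _ => 1.0,
        }
    }
}

// Reorders the joins of `plan` using the custom estimator.
fn reorder(plan: LogicalPlan) -> Result<LogicalPlan> {
    optimal_left_deep_join_plan(plan, Rc::new(PessimisticEstimator))
}
```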
+/// +/// # Errors +/// +/// Returns an error if: +/// - The query graph is empty or invalid +/// - Tree construction, normalization, or denormalization fails +/// - No valid precedence graph can be generated +pub fn query_graph_to_optimal_left_deep_join_plan( + query_graph: QueryGraph, + cost_estimator: Rc, +) -> Result { + let mut best_graph: Option = None; + + for (node_id, _) in query_graph.nodes() { + let mut precedence_graph = PrecedenceTreeNode::from_query_graph( + &query_graph, + node_id, + Rc::clone(&cost_estimator), + )?; + precedence_graph.normalize(); + precedence_graph.denormalize()?; + + best_graph = match best_graph.take() { + Some(current) => { + let new_cost = precedence_graph.cost()?; + if new_cost < current.cost()? { + Some(precedence_graph) + } else { + Some(current) + } + } + None => Some(precedence_graph), + }; + } + + best_graph + .ok_or_else(|| plan_datafusion_err!("No valid precedence graph found"))? + .into_logical_plan(&query_graph) +} + +#[derive(Debug)] +struct QueryNode { + node_id: NodeId, + // T in [IbarakiKameda84] + selectivity: f64, + // C in [IbarakiKameda84] + cost: f64, +} + +impl QueryNode { + fn rank(&self) -> f64 { + (self.selectivity - 1.0) / self.cost + } +} + +/// A node in the precedence tree for query optimization. +/// +/// The precedence tree is a data structure used by the Ibaraki-Kameda algorithm for +/// optimizing join ordering in database queries. It can represent both arbitrary tree +/// structures and linear chain structures (where each node has at most one child). +/// +/// # Lifecycle +/// +/// A typical precedence tree goes through three phases: +/// +/// 1. **Construction** ([`from_query_graph`](Self::from_query_graph)): Build an initial tree +/// from a query graph, creating nodes with cost/cardinality estimates +/// 2. **Normalization** ([`normalize`](Self::normalize)): Transform the tree into a chain +/// where nodes are ordered by rank, potentially merging multiple query operations into +/// single nodes +/// 3. **Denormalization** ([`denormalize`](Self::denormalize)): Split merged operations back +/// into individual nodes while maintaining the optimized chain structure +/// +/// The result is a linear execution order that minimizes intermediate result sizes. +/// +/// # Fields +/// +/// * `query_nodes` - Vector of query operations with cost estimates. In an initial tree, +/// contains one operation. After normalization, may contain multiple merged operations. +/// After denormalization, contains exactly one operation. +/// * `children` - Child nodes in the tree. In a normalized/denormalized chain, contains +/// at most one child. In an arbitrary tree, may contain multiple children. +/// * `query_graph` - Reference to the original query graph, used for accessing node +/// relationships and metadata during tree transformations. +struct PrecedenceTreeNode<'graph> { + query_nodes: Vec, + children: Vec>, + query_graph: &'graph QueryGraph, +} + +impl Debug for PrecedenceTreeNode<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrecedenceTreeNode") + .field("query_nodes", &self.query_nodes) + .field("children", &self.children) + .finish() + } +} + +impl<'graph> PrecedenceTreeNode<'graph> { + /// Creates a precedence tree from a query graph. + /// + /// This is the main entry point for transforming a query graph into a precedence tree + /// structure. The tree represents an initial join ordering with cost and cardinality + /// estimates for query optimization. 
+ /// + /// The function performs a depth-first traversal starting from the root node, + /// building a tree where: + /// - Each node contains cost/cardinality estimates for a query operation + /// - Children represent connected query nodes (joins, filters, etc.) + /// - The root node starts with selectivity of 1.0 (no filtering) + /// + /// # Arguments + /// + /// * `graph` - The query graph to transform into a precedence tree + /// * `root_id` - The ID of the node to use as the root of the tree + /// * `cost_estimator` - The cost estimator to use for calculating cardinality, selectivity, and cost + /// + /// # Returns + /// + /// Returns a `PrecedenceTreeNode` representing the entire query graph as a tree structure, + /// with the specified root node at the top. + /// + /// # Errors + /// + /// Returns an error if: + /// - The `root_id` is not found in the query graph + /// - Any connected node cannot be found during traversal + pub(crate) fn from_query_graph( + graph: &'graph QueryGraph, + root_id: NodeId, + cost_estimator: Rc, + ) -> Result { + let mut remaining: HashSet = graph.nodes().map(|(x, _)| x).collect(); + remaining.remove(&root_id); + PrecedenceTreeNode::from_query_node( + root_id, + 1.0, + graph, + &mut remaining, + cost_estimator, + true, + ) + } + + /// Recursively constructs a precedence tree node from a query graph node. + /// + /// This function builds a tree structure by: + /// 1. Creating a node with cost and cardinality estimates for the current query node + /// 2. Recursively processing all connected unvisited nodes as children + /// 3. Removing visited nodes from the `remaining` set to avoid cycles + /// + /// # Arguments + /// + /// * `node_id` - The ID of the query graph node to process + /// * `selectivity` - The selectivity factor from the parent edge (1.0 for root) + /// * `query_graph` - Reference to the query graph being transformed + /// * `remaining` - Mutable set of node IDs not yet visited (updated during traversal) + /// * `cost_estimator` - The cost estimator to use for calculating cardinality, selectivity, and cost + /// + /// # Returns + /// + /// Returns a `PrecedenceTreeNode` containing: + /// - A single `NodeEstimates` with cardinality and cost based on input cardinality and selectivity + /// - Child nodes for each connected unvisited neighbor in the query graph + /// + /// # Errors + /// + /// Returns an error if the specified `node_id` is not found in the query graph. 
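To make the bookkeeping concrete: for a non-root node the stored `selectivity` field is T = edge selectivity × input cardinality, and the stored `cost` field comes from the estimator's `cost(selectivity, cardinality)`, while the root's cost is pinned to 0. A small arithmetic sketch with made-up numbers (0.1 edge selectivity, 1,000 input rows) under the default `selectivity * cardinality` cost model:

```rust
fn main() {
    // Illustrative numbers only, assuming the default cost model.
    let edge_selectivity = 0.1_f64;
    let input_cardinality = 1_000.0_f64;

    let t = edge_selectivity * input_cardinality; // stored `selectivity` (T): 100.0
    let c = edge_selectivity * input_cardinality; // stored `cost` (C): 100.0 (0.0 for the root)
    let rank = (t - 1.0) / c; // per-node rank used later in the chain ordering: 0.99
    assert!((rank - 0.99).abs() < 1e-9);
}
```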
+ fn from_query_node( + node_id: NodeId, + selectivity: f64, + query_graph: &'graph QueryGraph, + remaining: &mut HashSet, + cost_estimator: Rc, + is_root: bool, + ) -> Result { + let node = query_graph + .get_node(node_id) + .ok_or_else(|| plan_datafusion_err!("Root node not found"))?; + let input_cardinality = cost_estimator.cardinality(&node.plan).unwrap_or(1.0); + + let children = node + .connections() + .iter() + .filter_map(|edge_id| { + let edge = query_graph.get_edge(*edge_id)?; + let other = edge + .nodes + .into_iter() + .find(|x| *x != node_id && remaining.contains(x))?; + + remaining.remove(&other); + let child_selectivity = cost_estimator.selectivity(&edge.join); + Some(PrecedenceTreeNode::from_query_node( + other, + child_selectivity, + query_graph, + remaining, + Rc::clone(&cost_estimator), + false, + )) + }) + .collect::>>()?; + + Ok(PrecedenceTreeNode { + query_nodes: vec![QueryNode { + node_id, + selectivity: (selectivity * input_cardinality), + cost: if is_root { + 0.0 + } else { + cost_estimator.cost(selectivity, input_cardinality) + }, + }], + children, + query_graph, + }) + } + + /// Rank function according to IbarakiKameda84 + fn rank(&self) -> f64 { + let (cardinality, cost) = + self.query_nodes + .iter() + .fold((1.0, 0.0), |(cardinality, cost), node| { + let cost = cost + cardinality * node.cost; + let cardinality = cardinality * node.selectivity; + (cardinality, cost) + }); + if cost == 0.0 { + 0.0 + } else { + (cardinality - 1.0) as f64 / cost + } + } + + /// Normalizes the precedence tree into a linear chain structure. + /// + /// This transformation converts the tree into a normalized form where each node + /// has at most one child, creating a linear sequence of query nodes. The normalization + /// process uses the rank function to determine optimal ordering according to the + /// Ibaraki-Kameda algorithm. + /// + /// The normalization handles three cases: + /// - **Leaf nodes (0 children)**: Already normalized, no action needed + /// - **Single child (1 child)**: If the child has lower rank than current node, merge + /// the child's query nodes into the current node, creating a sequence. Otherwise, + /// recursively normalize the child. + /// - **Multiple children (2+ children)**: Recursively normalize all children into chains, + /// then merge all child chains into a single chain using the merge operation. + /// + /// After normalization, the tree becomes a chain where nodes are ordered by their + /// rank values, with each node containing one or more query operations in sequence. + /// + /// # Algorithm + /// + /// Based on the Ibaraki-Kameda join ordering algorithm, which optimizes query + /// execution by arranging operations to minimize intermediate result sizes. + fn normalize(&mut self) { + match self.children.len() { + 0 => (), + 1 => { + // If child has lower rank, merge it into current node + if self.children[0].rank() < self.rank() { + let mut child = self.children.pop().unwrap(); + self.query_nodes.append(&mut child.query_nodes); + self.children = child.children; + self.normalize(); + } else { + self.children[0].normalize(); + } + } + _ => { + // Normalize all child trees into chains, then merge them + for child in &mut self.children { + child.normalize(); + } + let child = std::mem::take(&mut self.children) + .into_iter() + .reduce(Self::merge) + .unwrap(); + self.children = vec![child]; + } + } + } + + /// Merges two precedence tree chains into a single chain. 
+ /// + /// This operation combines two normalized tree chains (each with at most one child) + /// into a single chain, preserving rank ordering. The chain with the lower rank becomes + /// the parent, and the higher-ranked chain is attached as a descendant. + /// + /// The merge strategy depends on whether the lower-ranked chain has children: + /// - **No children**: The higher-ranked chain becomes the direct child + /// - **Has child**: Recursively merge the higher-ranked chain with the child, + /// maintaining the chain structure + /// + /// This ensures the resulting chain maintains proper rank ordering from root to leaf, + /// which is essential for the Ibaraki-Kameda optimization algorithm. + /// + /// # Arguments + /// + /// * `self` - The first tree chain to merge + /// * `other` - The second tree chain to merge + /// + /// # Returns + /// + /// Returns a merged `PrecedenceTreeNode` chain with both input chains combined, + /// ordered by rank values. + /// + /// # Panics + /// + /// May panic if called on non-normalized trees (trees with multiple children). + fn merge(self, other: PrecedenceTreeNode<'graph>) -> Self { + let (mut first, second) = if self.rank() < other.rank() { + (self, other) + } else { + (other, self) + }; + if first.children.is_empty() { + first.children = vec![second]; + } else { + first.children = vec![first.children.pop().unwrap().merge(second)]; + } + first + } + + /// Denormalizes a normalized precedence tree by splitting merged query nodes. + /// + /// This is the inverse operation of normalization, but with a critical property: + /// **the result is still a chain structure** (each node has at most one child). + /// It converts a normalized chain where nodes contain multiple query operations + /// into a longer chain where each node contains exactly one query operation. + /// + /// The denormalization process: + /// 1. **Validates input**: Ensures the tree is normalized (0 or 1 children per node) + /// 2. **Recursively processes children**: Denormalizes the child chain first + /// 3. **Splits merged nodes**: For nodes with multiple query operations, iteratively + /// extracts operations one at a time based on neighbor relationships with the child + /// 4. **Maintains ordering**: Uses rank-based selection to determine which query node + /// to extract next, choosing the highest-ranked neighbor of the child node + /// + /// **Key property**: After denormalization, the result remains a chain (not a tree with + /// branches). Each node contains exactly one query operation, but the chain structure + /// is preserved. This is the essence of the normalize-denormalize algorithm: transforming + /// an arbitrary tree into an optimized chain while respecting query dependencies. + /// + /// # Errors + /// + /// Returns an error if: + /// - The tree is not normalized (has more than one child at any level) + /// + /// # Algorithm + /// + /// The splitting process uses the query graph's neighbor relationships to determine + /// which nodes should be adjacent in the chain, maintaining logical dependencies + /// between query operations while producing a linear execution order. 
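The quantity that drives `normalize` and `merge` is the Ibaraki-Kameda rank (T − 1) / C; chains with lower rank are placed earlier so that size-reducing operations run as soon as possible. A self-contained sketch of that comparison (the numbers are made up and the helper is not part of this patch):

```rust
// Rank as used by the ordering: (T - 1) / C, where T is the combined selectivity
// of a (possibly merged) node and C its accumulated cost.
fn rank(t: f64, c: f64) -> f64 {
    if c == 0.0 {
        0.0
    } else {
        (t - 1.0) / c
    }
}

fn main() {
    // A node that shrinks its input (T = 0.2) ranks lower than one that grows it
    // (T = 1.5) at equal cost, so the shrinking node is scheduled first.
    let shrinking = rank(0.2, 10.0); // -0.08
    let growing = rank(1.5, 10.0); //  0.05
    assert!(shrinking < growing);
}
```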
+ fn denormalize(&mut self) -> Result<()> { + // Normalized trees must have 0 or 1 children + match self.children.len() { + 0 => (), + 1 => self.children[0].denormalize()?, + _ => return plan_err!("Tree is not normalized"), + } + + // Split query nodes into a chain based on neighbor relationships + while self.query_nodes.len() > 1 { + if self.children.is_empty() { + let highest_rank_idx = self + .query_nodes + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.rank().partial_cmp(&b.rank()).unwrap()) + .map(|(idx, _)| idx) + .unwrap(); + + let node = self.query_nodes.remove(highest_rank_idx); + + self.children.push(PrecedenceTreeNode { + query_nodes: vec![node], + children: Vec::new(), + query_graph: self.query_graph, + }); + } else { + let child_id = self.children[0].query_nodes[0].node_id; + let child_node = self.query_graph.get_node(child_id).unwrap(); + let neighbours = child_node.neighbours(child_id, self.query_graph); + + // Find the highest-ranked neighbor node + let highest_rank_idx = self + .query_nodes + .iter() + .enumerate() + .filter(|(_, node)| neighbours.contains(&node.node_id)) + .max_by(|(_, a), (_, b)| a.rank().partial_cmp(&b.rank()).unwrap()) + .map(|(idx, _)| idx) + .unwrap(); + + let node = self.query_nodes.remove(highest_rank_idx); + + let child = std::mem::replace( + &mut self.children[0], + PrecedenceTreeNode { + query_nodes: vec![node], + children: Vec::new(), + query_graph: self.query_graph, + }, + ); + self.children[0].children = vec![child]; + }; + + // Insert the node between current and its child + } + Ok(()) + } + + /// Converts the precedence tree chain into a DataFusion `LogicalPlan`. + /// + /// This method walks down the optimized chain structure, building a left-deep join tree + /// by repeatedly joining the accumulated result with the next node in the chain. + /// + /// # Algorithm + /// + /// 1. Start with the first node's `LogicalPlan` from the query graph + /// 2. For each subsequent node in the chain: + /// - Get the node's `LogicalPlan` from the query graph + /// - Find the edge connecting the current and next nodes + /// - Create a join using the edge's join specification + /// - The accumulated plan becomes the left side of the join + /// 3. Return the final joined `LogicalPlan` + /// + /// # Arguments + /// + /// * `query_graph` - The query graph containing the logical plans and join specifications + /// + /// # Returns + /// + /// Returns a `LogicalPlan` representing the optimized join execution order. + /// + /// # Errors + /// + /// Returns an error if: + /// - A node or edge is missing from the query graph + /// - The precedence tree is not in the expected chain format + pub(crate) fn into_logical_plan( + self, + query_graph: &QueryGraph, + ) -> Result { + // Get the first node's logical plan + let current_node_id = self.query_nodes[0].node_id; + let mut current_plan = query_graph + .get_node(current_node_id) + .ok_or_else(|| plan_datafusion_err!("Node {:?} not found", current_node_id))? + .plan + .as_ref() + .clone(); + + // Track all processed nodes in order + let mut processed_nodes = vec![current_node_id]; + + // Walk down the chain, joining each subsequent node + let mut current_chain = &self; + + while !current_chain.children.is_empty() { + let child = ¤t_chain.children[0]; + let next_node_id = child.query_nodes[0].node_id; + + // Get the next node's logical plan + let next_plan = query_graph + .get_node(next_node_id) + .ok_or_else(|| plan_datafusion_err!("Node {:?} not found", next_node_id))? 
+ .plan + .as_ref() + .clone(); + + // Find the edge connecting next_node to any processed node + let next_node = query_graph.get_node(next_node_id).ok_or_else(|| { + plan_datafusion_err!("Node {:?} not found", next_node_id) + })?; + + let edge = processed_nodes + .iter() + .rev() + .find_map(|&processed_id| { + next_node.connection_with(processed_id, query_graph) + }) + .ok_or_else(|| { + plan_datafusion_err!( + "No edge found between {:?} and any processed nodes {:?}", + next_node_id, + processed_nodes + ) + })?; + + // Determine if the join order was swapped compared to the original edge. + // We check if the qualified columns (relation + name) from the join expressions + // match the schemas. This handles all cases including when multiple tables + // have columns with the same name. + let current_schema = current_plan.schema(); + let next_schema = next_plan.schema(); + + let join_order_swapped = if !edge.join.on.is_empty() { + // Helper to check if a qualified column exists in a schema + let column_in_schema = |col: &datafusion_common::Column, + schema: &datafusion_common::DFSchema| + -> bool { + if let Some(relation) = &col.relation { + // Column has a table qualifier - must match exactly (relation + name) + schema.iter().any(|(qualifier, field)| { + qualifier == Some(relation) && field.name() == col.name() + }) + } else { + // Unqualified column - check if the name exists anywhere in schema + schema.field_with_unqualified_name(&col.name).is_ok() + } + }; + + // Collect all columns from all join conditions + let mut all_left_columns = vec![]; + let mut all_right_columns = vec![]; + + for (left_expr, right_expr) in &edge.join.on { + all_left_columns.extend(left_expr.column_refs()); + all_right_columns.extend(right_expr.column_refs()); + } + + // Check which schema each expression's columns belong to + let left_in_current = + all_left_columns.iter().all(|c| column_in_schema(c, current_schema.as_ref())); + let right_in_next = + all_right_columns.iter().all(|c| column_in_schema(c, next_schema.as_ref())); + let left_in_next = + all_left_columns.iter().all(|c| column_in_schema(c, next_schema.as_ref())); + let right_in_current = + all_right_columns.iter().all(|c| column_in_schema(c, current_schema.as_ref())); + + // Determine swap based on where the qualified columns are found + if left_in_current && right_in_next { + // Left expression belongs to current, right to next → no swap + false + } else if left_in_next && right_in_current { + // Left expression belongs to next, right to current → swap + true + } else { + // Ambiguous or error case - default to no swap to preserve original order + // This shouldn't happen with properly qualified columns + false + } + } else { + // If there are no join conditions, we can't determine swap status + // This shouldn't happen in practice for equi-joins + false + }; + + // When the join order is swapped, we need to adjust the on conditions and join type + // to maintain correct semantics. 
For example: + // - Original: A LeftSemi B ON A.x = B.y + // - After swap: B RightSemi A ON B.y = A.x + let (on, join_type) = if join_order_swapped { + let swapped_on = edge + .join + .on + .iter() + .map(|(left, right)| (right.clone(), left.clone())) + .collect(); + (swapped_on, edge.join.join_type.swap()) + } else { + (edge.join.on.clone(), edge.join.join_type) + }; + + // Create the join plan + current_plan = LogicalPlan::Join(datafusion_expr::Join { + left: Arc::new(current_plan), + right: Arc::new(next_plan), + on, + filter: edge.join.filter.clone(), + join_type, + join_constraint: edge.join.join_constraint, + schema: Arc::clone(&edge.join.schema), + null_equality: edge.join.null_equality, + }); + + // Move to the next node in the chain + processed_nodes.push(next_node_id); + current_chain = child; + } + + Ok(current_plan) + } + + fn cost(&self) -> Result { + self.cost_recursive(self.query_nodes[0].selectivity, 0.0) + } + + fn cost_recursive(&self, cardinality: f64, cost: f64) -> Result { + let cost = match self.children.len() { + 0 => cost + cardinality * self.query_nodes[0].cost, + 1 => self.children[0].cost_recursive( + cardinality * self.query_nodes[0].selectivity, + cost + cardinality * self.query_nodes[0].cost, + )?, + _ => { + return plan_err!( + "Cost calculation requires normalized tree with 0 or 1 children" + ) + } + }; + Ok(cost) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; + use crate::eliminate_filter::EliminateFilter; + use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; + use crate::filter_null_join_keys::FilterNullJoinKeys; + use crate::optimizer::{Optimizer, OptimizerContext}; + use crate::push_down_filter::PushDownFilter; + use crate::reorder_join::cost::JoinCostEstimator; + use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; + use crate::simplify_expressions::SimplifyExpressions; + use crate::test::*; + use datafusion_expr::logical_plan::JoinType; + use datafusion_expr::LogicalPlanBuilder; + + /// A simple cost estimator for testing + #[derive(Debug)] + struct TestCostEstimator; + + impl JoinCostEstimator for TestCostEstimator {} + + /// A simple TableSource implementation for testing join ordering with statistics + #[derive(Debug)] + struct JoinSource { + schema: arrow::datatypes::SchemaRef, + num_rows: usize, + } + + impl datafusion_expr::TableSource for JoinSource { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + Arc::clone(&self.schema) + } + + fn statistics(&self) -> Option { + use datafusion_common::stats::Precision; + Some( + datafusion_common::Statistics::new_unknown(&self.schema) + .with_num_rows(Precision::Exact(self.num_rows)), + ) + } + } + + /// Create a table scan with statistics for testing join ordering + fn scan_tpch_table_with_stats(table: &str, num_rows: usize) -> LogicalPlan { + let schema = Arc::new(get_tpch_table_schema(table)); + let table_source: Arc = Arc::new(JoinSource { + schema: Arc::clone(&schema), + num_rows, + }); + LogicalPlanBuilder::scan(table, table_source, None) + .unwrap() + .build() + .unwrap() + } + + /// Test three-way join: customer -> orders -> lineitem + #[test] + fn test_three_way_join_customer_orders_lineitem() -> Result<()> { + use datafusion_expr::test::function_stub::sum; + use datafusion_expr::{col, in_subquery, lit}; + // Create the base table scans with statistics + // Create the base table scans with statistics + let customer = 
scan_tpch_table_with_stats("customer", 150); + let orders = scan_tpch_table_with_stats("orders", 1_500); + let lineitem = scan_tpch_table_with_stats("lineitem", 6_000); + + // Step 1: Build the subquery + // SELECT l_orderkey FROM lineitem + // GROUP BY l_orderkey + // HAVING sum(l_quantity) > 300 + let subquery = LogicalPlanBuilder::from(lineitem.clone()) + .aggregate(vec![col("l_orderkey")], vec![sum(col("l_quantity"))])? + .filter(sum(col("l_quantity")).gt(lit(300)))? + .project(vec![col("l_orderkey")])? + .build()?; + + // Step 2: Build the main query with joins + let plan = LogicalPlanBuilder::from(customer.clone()) + .join( + orders.clone(), + JoinType::Inner, + (vec!["c_custkey"], vec!["o_custkey"]), + None, + )? + .join( + lineitem.clone(), + JoinType::Inner, + (vec!["o_orderkey"], vec!["l_orderkey"]), + None, + )? + // Step 3: Apply the IN subquery filter + .filter(in_subquery(col("o_orderkey"), Arc::new(subquery)))? + // Step 4: Aggregate + .aggregate( + vec![ + col("c_name"), + col("c_custkey"), + col("o_orderkey"), + col("o_totalprice"), + ], + vec![sum(col("l_quantity"))], + )? + // Step 5: Sort + .sort(vec![col("o_totalprice").sort(false, true)])? + // Step 6: Limit + .limit(0, Some(100))? + .build()?; + + println!("{}", plan.display_indent()); + + // Optimize the plan with custom optimizer before join reordering + // We exclude OptimizeProjections to keep joins consecutive + let config = OptimizerContext::new().with_skip_failing_rules(false); + let optimizer = Optimizer::with_rules(vec![ + Arc::new(SimplifyExpressions::new()), + Arc::new(DecorrelatePredicateSubquery::new()), + Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateFilter::new()), + Arc::new(FilterNullJoinKeys::default()), + Arc::new(PushDownFilter::new()), + // Note: OptimizeProjections is intentionally excluded to keep joins consecutive + ]); + let plan = optimizer.optimize(plan, &config, |_, _| {}).unwrap(); + + println!("After standard optimization:"); + println!("{}", plan.display_indent()); + + let optimized_plan = + optimal_left_deep_join_plan(plan, Rc::new(TestCostEstimator)).unwrap(); + + println!("Optimized Plan:"); + println!("{}", optimized_plan.display_indent()); + + // Verify the plan structure + assert!(matches!(optimized_plan, LogicalPlan::Limit(_))); + + Ok(()) + } +} diff --git a/datafusion/optimizer/src/reorder_join/mod.rs b/datafusion/optimizer/src/reorder_join/mod.rs new file mode 100644 index 000000000000..758f1daeacb1 --- /dev/null +++ b/datafusion/optimizer/src/reorder_join/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Optimizer rule for reordering joins to minimize query execution cost + +pub mod cost; +pub mod left_deep_join_plan; +pub mod query_graph; diff --git a/datafusion/optimizer/src/reorder_join/query_graph.rs b/datafusion/optimizer/src/reorder_join/query_graph.rs new file mode 100644 index 000000000000..ce5e67f29fe4 --- /dev/null +++ b/datafusion/optimizer/src/reorder_join/query_graph.rs @@ -0,0 +1,492 @@ +use std::sync::Arc; + +use datafusion_common::{ + plan_err, + tree_node::{TreeNode, TreeNodeRecursion}, + DataFusionError, Result, +}; +use datafusion_expr::{utils::check_all_columns_from_schema, Join, LogicalPlan}; + +pub type NodeId = usize; + +pub struct Node { + pub plan: Arc, + pub(crate) connections: Vec, +} + +impl Node { + pub(crate) fn connections(&self) -> &[EdgeId] { + &self.connections + } + + pub(crate) fn connection_with<'graph>( + &self, + node_id: NodeId, + query_graph: &'graph QueryGraph, + ) -> Option<&'graph Edge> { + self.connections + .iter() + .filter_map(|edge_id| query_graph.get_edge(*edge_id)) + .find(move |x| x.nodes.contains(&node_id)) + } + + pub(crate) fn neighbours( + &self, + node_id: NodeId, + query_graph: &QueryGraph, + ) -> Vec { + self.connections + .iter() + .filter_map(|edge_id| query_graph.get_edge(*edge_id)) + .flat_map(|edge| edge.nodes) + .filter(|&id| id != node_id) + .collect() + } +} + +pub type EdgeId = usize; + +pub struct Edge { + pub nodes: [NodeId; 2], + pub join: Join, +} + +pub struct QueryGraph { + pub(crate) nodes: VecMap, + edges: VecMap, +} + +impl QueryGraph { + pub(crate) fn new() -> Self { + Self { + nodes: VecMap::new(), + edges: VecMap::new(), + } + } + + pub(crate) fn add_node(&mut self, node_data: Arc) -> NodeId { + self.nodes.insert(Node { + plan: node_data, + connections: Vec::new(), + }) + } + + pub(crate) fn add_node_with_edge( + &mut self, + other: NodeId, + node_data: Arc, + edge_data: Join, + ) -> Option { + if self.nodes.contains_key(other) { + let new_id = self.nodes.insert(Node { + plan: node_data, + connections: Vec::new(), + }); + self.add_edge(new_id, other, edge_data); + Some(new_id) + } else { + None + } + } + + fn add_edge(&mut self, from: NodeId, to: NodeId, data: Join) -> Option { + if self.nodes.contains_key(from) && self.nodes.contains_key(to) { + let edge_id = self.edges.insert(Edge { + nodes: [from, to], + join: data, + }); + if let Some(from) = self.nodes.get_mut(from) { + from.connections.push(edge_id); + } + if let Some(to) = self.nodes.get_mut(to) { + to.connections.push(edge_id); + } + Some(edge_id) + } else { + None + } + } + + pub(crate) fn remove_node(&mut self, node_id: NodeId) -> Option> { + if let Some(node) = self.nodes.remove(node_id) { + // Remove all edges connected to this node + for edge_id in &node.connections { + if let Some(edge) = self.edges.remove(*edge_id) { + // Remove the edge from the other node's connections + for other_node_id in edge.nodes { + if other_node_id != node_id { + if let Some(other_node) = self.nodes.get_mut(other_node_id) { + other_node.connections.retain(|id| id != edge_id); + } + } + } + } + } + Some(node.plan) + } else { + None + } + } + + fn remove_edge(&mut self, edge_id: EdgeId) -> Option { + if let Some(edge) = self.edges.remove(edge_id) { + // Remove the edge from both nodes' connections + for node_id in edge.nodes { + if let Some(node) = self.nodes.get_mut(node_id) { + node.connections.retain(|id| *id != edge_id); + } + } + Some(edge.join) + } else { + None + } + } + + pub(crate) fn nodes(&self) -> impl Iterator { + self.nodes.iter() + } + + pub(crate) fn 
get_node(&self, key: NodeId) -> Option<&Node> { + self.nodes.get(key) + } + + pub(crate) fn get_edge(&self, key: EdgeId) -> Option<&Edge> { + self.edges.get(key) + } +} + +/// Extracts the join subtree from a logical plan, separating it from wrapper operators. +/// +/// This function traverses the plan tree from the root downward, collecting all non-join +/// operators until it finds the topmost join node. The join subtree (all consecutive joins) +/// is extracted and returned separately from the wrapper operators. +/// +/// # Arguments +/// +/// * `plan` - The logical plan to extract from +/// +/// # Returns +/// +/// Returns a tuple of (join_subtree, wrapper_operators) where: +/// - `join_subtree` is the topmost join and all joins beneath it +/// - `wrapper_operators` is a vector of non-join operators above the joins, in order from root to join +/// +/// # Errors +/// +/// Returns an error if the plan doesn't contain any joins. +pub(crate) fn extract_join_subtree( + plan: LogicalPlan, +) -> Result<(LogicalPlan, Vec)> { + let mut wrappers = Vec::new(); + let mut current = plan; + + // Descend through non-join nodes until we find a join + loop { + match current { + LogicalPlan::Join(_) => { + // Found the join subtree root + return Ok((current, wrappers)); + } + other => { + // Check if this node contains joins in its children + if !contains_join(&other) { + return plan_err!( + "Plan does not contain any join nodes: {}", + other.display() + ); + } + + // This node is a wrapper - store it and descend to its child + // For now, we only support single-child wrappers (Filter, Sort, Limit, Aggregate, etc.) + let inputs = other.inputs(); + if inputs.len() != 1 { + return plan_err!( + "Join extraction only supports single-input operators, found {} inputs in: {}", + inputs.len(), + other.display() + ); + } + + wrappers.push(other.clone()); + current = (*inputs[0]).clone(); + } + } + } +} + +/// Reconstructs a logical plan by wrapping an optimized join plan with the original wrapper operators. +/// +/// This function takes an optimized join plan and re-applies the wrapper operators (Filter, Sort, +/// Aggregate, etc.) that were removed during extraction. The wrappers are applied in reverse order +/// (innermost to outermost) to reconstruct the original plan structure. +/// +/// # Arguments +/// +/// * `join_plan` - The optimized join plan to wrap +/// * `wrappers` - Vector of wrapper operators in order from outermost to innermost (root to join) +/// +/// # Returns +/// +/// Returns the fully reconstructed logical plan with all wrapper operators reapplied. +/// +/// # Errors +/// +/// Returns an error if reconstructing any wrapper operator fails. 
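Because `extract_join_subtree` and `reconstruct_plan` are `pub(crate)`, they only compose from inside the optimizer crate. A rough sketch of that composition, mirroring the public entry point; the function name is illustrative and not part of this patch:

```rust
use std::rc::Rc;

use datafusion_common::Result;
use datafusion_expr::LogicalPlan;

use crate::reorder_join::cost::JoinCostEstimator;
use crate::reorder_join::left_deep_join_plan::query_graph_to_optimal_left_deep_join_plan;
use crate::reorder_join::query_graph::{extract_join_subtree, reconstruct_plan, QueryGraph};

// e.g. Limit -> Sort -> Aggregate -> Join(...) splits into
// wrappers = [Limit, Sort, Aggregate] plus the Join(...) subtree; the wrappers
// are reapplied innermost-first once the joins have been reordered.
fn reorder_joins_only(
    plan: LogicalPlan,
    estimator: Rc<dyn JoinCostEstimator>,
) -> Result<LogicalPlan> {
    let (join_subtree, wrappers) = extract_join_subtree(plan)?;
    let graph = QueryGraph::try_from(join_subtree)?;
    let reordered = query_graph_to_optimal_left_deep_join_plan(graph, estimator)?;
    reconstruct_plan(reordered, wrappers)
}
```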
+pub(crate) fn reconstruct_plan( + join_plan: LogicalPlan, + wrappers: Vec, +) -> Result { + let mut current = join_plan; + + // Apply wrappers in reverse order (from innermost to outermost) + for wrapper in wrappers.into_iter().rev() { + // Use with_new_exprs to reconstruct the wrapper with the new input + current = wrapper.with_new_exprs(wrapper.expressions(), vec![current])?; + } + + Ok(current) +} + +impl TryFrom for QueryGraph { + type Error = DataFusionError; + + fn try_from(value: LogicalPlan) -> Result { + // First, extract the join subtree from any wrapper operators + let (join_subtree, _wrappers) = extract_join_subtree(value)?; + + // Now convert only the join subtree to a query graph + let mut query_graph = QueryGraph::new(); + flatten_joins_recursive(join_subtree, &mut query_graph)?; + Ok(query_graph) + } +} + +fn flatten_joins_recursive( + plan: LogicalPlan, + query_graph: &mut QueryGraph, +) -> Result<()> { + match plan { + LogicalPlan::Join(join) => { + flatten_joins_recursive( + Arc::unwrap_or_clone(Arc::clone(&join.left)), + query_graph, + )?; + flatten_joins_recursive( + Arc::unwrap_or_clone(Arc::clone(&join.right)), + query_graph, + )?; + + // Process each equijoin predicate to find which nodes it connects + for (left_key, right_key) in &join.on { + // Extract column references from both join keys + let left_columns = left_key.column_refs(); + let right_columns = right_key.column_refs(); + + // Filter nodes by checking which ones contain the columns from each expression + let matching_nodes: Vec = query_graph + .nodes() + .filter_map(|(node_id, node)| { + let schema = node.plan.schema(); + // Check if this node's schema contains columns from either left or right key + let has_left = + check_all_columns_from_schema(&left_columns, schema.as_ref()) + .unwrap_or(false); + let has_right = check_all_columns_from_schema( + &right_columns, + schema.as_ref(), + ) + .unwrap_or(false); + + // Include node if it contains columns from either key (but not both, as that would be invalid) + if (has_left && !has_right) || (!has_left && has_right) { + Some(node_id) + } else { + None + } + }) + .collect(); + + // We should have exactly two nodes: one with left_key columns, one with right_key columns + if matching_nodes.len() != 2 { + return plan_err!( + "Could not find exactly two nodes for join predicate: {} = {} (found {} nodes)", + left_key, + right_key, + matching_nodes.len() + ); + } + + let node_id_a = matching_nodes[0]; + let node_id_b = matching_nodes[1]; + + // Add an edge if one doesn't exist yet + if let Some(node_a) = query_graph.get_node(node_id_a) { + if node_a.connection_with(node_id_b, query_graph).is_none() { + // No edge exists yet, create one with this join + query_graph.add_edge(node_id_a, node_id_b, join.clone()); + } + } + } + + Ok(()) + } + x => { + if contains_join(&x) { + plan_err!( + "Join reordering requires joins to be consecutive in the plan tree. \ + Found a non-join node that contains nested joins: {}", + x.display() + ) + } else { + query_graph.add_node(Arc::new(x)); + Ok(()) + } + } + } +} + +/// Checks if a LogicalPlan contains any join nodes +/// +/// Uses a TreeNode visitor to traverse the plan tree and detect the presence +/// of any `LogicalPlan::Join` nodes. 
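+/// For example, a plan shaped like `Filter -> Join -> [TableScan, TableScan]`
+/// returns `true`, while a plain `TableScan` returns `false`. Traversal stops
+/// as soon as the first join is found.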
+/// +/// # Arguments +/// +/// * `plan` - The logical plan to check +/// +/// # Returns +/// +/// `true` if the plan contains at least one join node, `false` otherwise +pub(crate) fn contains_join(plan: &LogicalPlan) -> bool { + let mut has_join = false; + + // Use TreeNode's apply method to traverse the plan + let _ = plan.apply(|node| { + if matches!(node, LogicalPlan::Join(_)) { + has_join = true; + // Stop traversal once we find a join + Ok(TreeNodeRecursion::Stop) + } else { + // Continue traversal + Ok(TreeNodeRecursion::Continue) + } + }); + + has_join +} + +/// A simple Vec-based map that uses Option for sparse storage +/// Keys are never reused once removed +pub(crate) struct VecMap(Vec>); + +impl VecMap { + pub(crate) fn new() -> Self { + Self(Vec::new()) + } + + pub(crate) fn insert(&mut self, value: V) -> usize { + let idx = self.0.len(); + self.0.push(Some(value)); + idx + } + + pub(crate) fn get(&self, key: usize) -> Option<&V> { + self.0.get(key)?.as_ref() + } + + pub(crate) fn get_mut(&mut self, key: usize) -> Option<&mut V> { + self.0.get_mut(key)?.as_mut() + } + + pub(crate) fn remove(&mut self, key: usize) -> Option { + self.0.get_mut(key)?.take() + } + + pub(crate) fn contains_key(&self, key: usize) -> bool { + self.0.get(key).and_then(|v| v.as_ref()).is_some() + } + + pub(crate) fn iter(&self) -> impl Iterator { + self.0 + .iter() + .enumerate() + .filter_map(|(idx, slot)| slot.as_ref().map(|v| (idx, v))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use datafusion_expr::logical_plan::JoinType; + use datafusion_expr::{col, LogicalPlanBuilder}; + + /// Test converting a three-way join with filter into a QueryGraph + #[test] + fn test_try_from_three_way_join_with_filter() -> Result<(), DataFusionError> { + // Create three-way join: customer JOIN orders JOIN lineitem + // with a filter on the orders-lineitem join + let customer = scan_tpch_table("customer"); + let orders = scan_tpch_table("orders"); + let lineitem = scan_tpch_table("lineitem"); + + let plan = LogicalPlanBuilder::from(customer.clone()) + .join( + orders.clone(), + JoinType::Inner, + (vec!["c_custkey"], vec!["o_custkey"]), + None, + ) + .unwrap() + .join_with_expr_keys( + lineitem.clone(), + JoinType::Inner, + (vec![col("o_orderkey")], vec![col("l_orderkey")]), + Some(col("l_quantity").gt(datafusion_expr::lit(10.0))), + ) + .unwrap() + .build() + .unwrap(); + + // Convert to QueryGraph + let query_graph = QueryGraph::try_from(plan)?; + + // Verify structure: 3 nodes, 2 edges + assert_eq!(query_graph.nodes().count(), 3); + assert_eq!(query_graph.edges.iter().count(), 2); + + // Verify connectivity: one node has 2 connections (orders), two nodes have 1 + let mut connections: Vec = query_graph + .nodes() + .map(|(_, node)| node.connections().len()) + .collect(); + connections.sort(); + assert_eq!(connections, vec![1, 1, 2]); + + // Verify edges have correct join predicates + let edges: Vec<&Edge> = query_graph.edges.iter().map(|(_, e)| e).collect(); + + // One edge should have c_custkey = o_custkey + let has_customer_orders = edges.iter().any(|e| { + e.join.on.iter().any(|(l, r)| { + let s = format!("{l}{r}"); + s.contains("c_custkey") && s.contains("o_custkey") + }) + }); + assert!(has_customer_orders, "Missing customer-orders join"); + + // One edge should have o_orderkey = l_orderkey with a filter + let has_orders_lineitem = edges.iter().any(|e| { + e.join.on.iter().any(|(l, r)| { + let s = format!("{l}{r}"); + s.contains("o_orderkey") && s.contains("l_orderkey") + }) && 
e.join.filter.is_some() + }); + assert!( + has_orders_lineitem, + "Missing orders-lineitem join with filter" + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 65a210826664..39968d1a8fd4 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1087,7 +1087,7 @@ mod tests { #[test] fn case_test_incompatible() -> Result<()> { - // 1 then is int64 + // 1 then is float64 // 2 then is boolean let batch = case_test_batch()?; let schema = batch.schema(); @@ -1099,7 +1099,7 @@ mod tests { lit("foo"), &batch.schema(), )?; - let then1 = lit(123i32); + let then1 = lit(1.23f64); let when2 = binary( col("a", &schema)?, Operator::Eq, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 4f411a4a9332..70d6caf7642b 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -62,6 +62,7 @@ message LogicalPlanNode { RecursiveQueryNode recursive_query = 31; CteWorkTableScanNode cte_work_table_scan = 32; DmlNode dml = 33; + PivotNode pivot = 34; } } @@ -1364,4 +1365,15 @@ message SortMergeJoinExecNode { JoinFilter filter = 5; repeated SortExprNode sort_options = 6; datafusion_common.NullEquality null_equality = 7; -} \ No newline at end of file +} + +message PivotNode { + LogicalPlanNode input = 1; + LogicalExprNode aggregate_expr = 2; + datafusion_common.Column pivot_column = 3; + repeated datafusion_common.ScalarValue pivot_values = 4; + datafusion_common.DfSchema schema = 5; + LogicalPlanNode value_subquery = 6; + LogicalExprNode default_on_null_expr = 7; + +} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index ff7519aa5df2..83f662e61112 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -12055,6 +12055,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::Dml(v) => { struct_ser.serialize_field("dml", v)?; } + logical_plan_node::LogicalPlanType::Pivot(v) => { + struct_ser.serialize_field("pivot", v)?; + } } } struct_ser.end() @@ -12114,6 +12117,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "cte_work_table_scan", "cteWorkTableScan", "dml", + "pivot", ]; #[allow(clippy::enum_variant_names)] @@ -12150,6 +12154,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { RecursiveQuery, CteWorkTableScan, Dml, + Pivot, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12203,6 +12208,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "recursiveQuery" | "recursive_query" => Ok(GeneratedField::RecursiveQuery), "cteWorkTableScan" | "cte_work_table_scan" => Ok(GeneratedField::CteWorkTableScan), "dml" => Ok(GeneratedField::Dml), + "pivot" => Ok(GeneratedField::Pivot), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12447,6 +12453,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("dml")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Dml) +; + } + GeneratedField::Pivot => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("pivot")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Pivot) ; } } @@ -18311,6 +18324,204 @@ 
impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { deserializer.deserialize_struct("datafusion.PhysicalWindowExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PivotNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.aggregate_expr.is_some() { + len += 1; + } + if self.pivot_column.is_some() { + len += 1; + } + if !self.pivot_values.is_empty() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if self.value_subquery.is_some() { + len += 1; + } + if self.default_on_null_expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PivotNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.aggregate_expr.as_ref() { + struct_ser.serialize_field("aggregateExpr", v)?; + } + if let Some(v) = self.pivot_column.as_ref() { + struct_ser.serialize_field("pivotColumn", v)?; + } + if !self.pivot_values.is_empty() { + struct_ser.serialize_field("pivotValues", &self.pivot_values)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if let Some(v) = self.value_subquery.as_ref() { + struct_ser.serialize_field("valueSubquery", v)?; + } + if let Some(v) = self.default_on_null_expr.as_ref() { + struct_ser.serialize_field("defaultOnNullExpr", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PivotNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "aggregate_expr", + "aggregateExpr", + "pivot_column", + "pivotColumn", + "pivot_values", + "pivotValues", + "schema", + "value_subquery", + "valueSubquery", + "default_on_null_expr", + "defaultOnNullExpr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + AggregateExpr, + PivotColumn, + PivotValues, + Schema, + ValueSubquery, + DefaultOnNullExpr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), + "pivotColumn" | "pivot_column" => Ok(GeneratedField::PivotColumn), + "pivotValues" | "pivot_values" => Ok(GeneratedField::PivotValues), + "schema" => Ok(GeneratedField::Schema), + "valueSubquery" | "value_subquery" => Ok(GeneratedField::ValueSubquery), + "defaultOnNullExpr" | "default_on_null_expr" => Ok(GeneratedField::DefaultOnNullExpr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PivotNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct 
datafusion.PivotNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut aggregate_expr__ = None; + let mut pivot_column__ = None; + let mut pivot_values__ = None; + let mut schema__ = None; + let mut value_subquery__ = None; + let mut default_on_null_expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::AggregateExpr => { + if aggregate_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregateExpr")); + } + aggregate_expr__ = map_.next_value()?; + } + GeneratedField::PivotColumn => { + if pivot_column__.is_some() { + return Err(serde::de::Error::duplicate_field("pivotColumn")); + } + pivot_column__ = map_.next_value()?; + } + GeneratedField::PivotValues => { + if pivot_values__.is_some() { + return Err(serde::de::Error::duplicate_field("pivotValues")); + } + pivot_values__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::ValueSubquery => { + if value_subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("valueSubquery")); + } + value_subquery__ = map_.next_value()?; + } + GeneratedField::DefaultOnNullExpr => { + if default_on_null_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("defaultOnNullExpr")); + } + default_on_null_expr__ = map_.next_value()?; + } + } + } + Ok(PivotNode { + input: input__, + aggregate_expr: aggregate_expr__, + pivot_column: pivot_column__, + pivot_values: pivot_values__.unwrap_or_default(), + schema: schema__, + value_subquery: value_subquery__, + default_on_null_expr: default_on_null_expr__, + }) + } + } + deserializer.deserialize_struct("datafusion.PivotNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PlaceholderNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ffb73086650f..cc19add6fbe9 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -5,7 +5,7 @@ pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34" )] pub logical_plan_type: ::core::option::Option, } @@ -77,6 +77,8 @@ pub mod logical_plan_node { CteWorkTableScan(super::CteWorkTableScanNode), #[prost(message, tag = "33")] Dml(::prost::alloc::boxed::Box), + #[prost(message, tag = "34")] + Pivot(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -2067,6 +2069,25 @@ pub struct SortMergeJoinExecNode { #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] pub null_equality: i32, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PivotNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub aggregate_expr: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub 
pivot_column: ::core::option::Option, + #[prost(message, repeated, tag = "4")] + pub pivot_values: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "5")] + pub schema: ::core::option::Option, + #[prost(message, optional, boxed, tag = "6")] + pub value_subquery: ::core::option::Option< + ::prost::alloc::boxed::Box, + >, + #[prost(message, optional, tag = "7")] + pub default_on_null_expr: ::core::option::Option, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index cc3e805ed1df..98843166932e 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -65,7 +65,7 @@ use datafusion_expr::{ logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, DdlStatement, Distinct, EmptyRelation, - Extension, Join, JoinConstraint, Prepare, Projection, Repartition, Sort, + Extension, Join, JoinConstraint, Pivot, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, @@ -970,6 +970,61 @@ impl AsLogicalPlan for LogicalPlanNode { Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?), ), )), + LogicalPlanType::Pivot(pivot) => { + let aggregate_expr = pivot + .aggregate_expr + .as_ref() + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) + .transpose()? + .ok_or_else(|| { + DataFusionError::Internal("aggregate_expr required".to_string()) + })?; + let pivot_column = pivot + .pivot_column + .as_ref() + .map(|col| col.clone().into()) + .ok_or_else(|| { + DataFusionError::Internal("pivot_column required".to_string()) + })?; + let pivot_values = pivot + .pivot_values + .iter() + .map(|val| val.try_into()) + .collect::, _>>( + )?; + let schema = Arc::new(convert_required!(pivot.schema)?); + let value_subquery = if pivot.value_subquery.is_some() { + Some(Arc::new(into_logical_plan!( + pivot.value_subquery, + ctx, + extension_codec + )?)) + } else { + None + }; + let default_on_null_expr = if pivot.default_on_null_expr.is_some() { + pivot + .default_on_null_expr + .as_ref() + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) + .transpose()? 
+ } else { + None + }; + Ok(LogicalPlan::Pivot(Pivot { + input: Arc::new(into_logical_plan!( + pivot.input, + ctx, + extension_codec + )?), + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + })) + } } } @@ -1787,6 +1842,9 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::Pivot(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for Statement", + )), } } } diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 7e11f160a397..dc5dafb5de4b 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -51,7 +51,7 @@ pub mod resolve; mod select; mod set_expr; mod stack; -mod statement; +pub mod statement; #[cfg(feature = "unparser")] pub mod unparser; pub mod utils; diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 2c673162ec9c..1abd718bf31b 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -489,6 +489,12 @@ impl<'a> DFParser<'a> { Token::Word(w) => { match w.keyword { Keyword::CREATE => { + if let Token::Word(w) = self.parser.peek_nth_token(2).token { + // use native parser for CREATE EXTERNAL VOLUME + if w.keyword == Keyword::VOLUME { + return self.parse_and_handle_statement(); + } + } self.parser.next_token(); // CREATE self.parse_create() } @@ -730,7 +736,11 @@ impl<'a> DFParser<'a> { self.parser.expect_keyword(Keyword::EXTERNAL)?; self.parse_create_external_table(true) } else { - Ok(Statement::Statement(Box::from(self.parser.parse_create()?))) + // Push back CREATE + self.parser.prev_token(); + Ok(Statement::Statement(Box::from( + self.parser.parse_statement()?, + ))) } } @@ -1094,6 +1104,26 @@ mod tests { } } + #[test] + fn skip_create_stage_snowflake() -> Result<(), DataFusionError> { + let sql = + "CREATE OR REPLACE STAGE stage URL='s3://data.csv' FILE_FORMAT=(TYPE=csv)"; + let dialect = Box::new(SnowflakeDialect); + let statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?; + + assert_eq!( + statements.len(), + 1, + "Expected to parse exactly one statement" + ); + match &statements[0] { + Statement::Statement(stmt) => { + assert_eq!(stmt.to_string(), sql); + } + _ => panic!("Expected statement type"), + } + Ok(()) + } #[test] fn create_external_table() -> Result<(), DataFusionError> { // positive case @@ -1600,6 +1630,24 @@ mod tests { Ok(()) } + #[test] + fn skip_external_volume() -> Result<(), DataFusionError> { + let sql = "CREATE OR REPLACE EXTERNAL VOLUME exvol STORAGE_LOCATIONS = + ((NAME = 's3' STORAGE_PROVIDER = 'S3' STORAGE_BASE_URL = 's3://my-example-bucket/' ))"; + let dialect = Box::new(SnowflakeDialect); + let statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?; + + assert_eq!( + statements.len(), + 1, + "Expected to parse exactly one statement" + ); + if let Statement::CreateExternalTable(_) = &statements[0] { + panic!("Expected non CREATE EXTERNAL TABLE statement, but was successful: {statements:?}"); + } + Ok(()) + } + #[test] fn explain_copy_to_table_to_table() -> Result<(), DataFusionError> { let cases = vec![ diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index fc678a8f8711..3332a45c511a 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -444,7 +444,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } /// Returns a vector of (column_name, default_expr) pairs - pub(super) fn build_column_defaults( + pub fn build_column_defaults( &self, columns: &Vec, planner_context: &mut PlannerContext, @@ -587,7 +587,7 
@@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) } - pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { + pub fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { // First check if any of the registered type_planner can handle this type if let Some(type_planner) = self.context_provider.get_type_planner() { if let Some(data_type) = type_planner.plan_type(sql_type)? { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index aa37d74fd4d8..dd5e3aaa55fd 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -21,12 +21,14 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - not_impl_err, plan_err, DFSchema, Diagnostic, Result, Span, Spans, TableReference, + not_impl_err, plan_err, Column, DFSchema, Diagnostic, Result, ScalarValue, Span, + Spans, TableReference, }; +use datafusion_expr::binary::comparison_coercion; use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::{Subquery, SubqueryAlias}; -use sqlparser::ast::{FunctionArg, FunctionArgExpr, Spanned, TableFactor}; +use sqlparser::ast::{FunctionArg, FunctionArgExpr, NullInclusion, Spanned, TableFactor}; mod join; @@ -56,6 +58,21 @@ impl SqlToRel<'_, S> { &DFSchema::empty(), planner_context, ) + .map(|expr| (expr, None)) + } else if let FunctionArg::Named { name, arg, .. } = arg { + if let FunctionArgExpr::Expr(expr) = arg { + self.sql_expr_to_logical_expr( + expr, + &DFSchema::empty(), + planner_context, + ) + .map(|expr| (expr, Some(name.to_string()))) + } else { + plan_err!( + "Unsupported function argument type: {:?}", + arg + ) + } } else { plan_err!("Unsupported function argument type: {:?}", arg) } @@ -165,25 +182,266 @@ impl SqlToRel<'_, S> { let func_args = args .into_iter() .map(|arg| match arg { - FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) - | FunctionArg::Named { + FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => { + // No alias + self.sql_expr_to_logical_expr(expr, &schema, planner_context) + .map(|e| (e, None)) + } + FunctionArg::Named { + name, arg: FunctionArgExpr::Expr(expr), .. } => { + // Named arg → alias comes from `name.value` self.sql_expr_to_logical_expr(expr, &schema, planner_context) + .map(|e| (e, Some(name.value))) } _ => plan_err!("Unsupported function argument: {arg:?}"), }) - .collect::>>()?; + .collect::)>>>()?; + let provider = self .context_provider .get_table_function_source(tbl_func_ref.table(), func_args)?; + let plan = LogicalPlanBuilder::scan(tbl_func_ref.table(), provider, None)? .build()?; (plan, alias) } // @todo Support TableFactory::TableFunction? 
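+            // The PIVOT and UNPIVOT table factors are planned directly below.
+            // PIVOT resolves its value list in one of three ways, matching the
+            // `PivotValueSource` variants handled in the arm that follows: an
+            // explicit IN list of literals, ANY (a DISTINCT subquery over the
+            // pivot column, optionally ordered), or a user-supplied subquery.
+            // UNPIVOT is rewritten as a UNION of per-column projections, with a
+            // NOT NULL filter on the value column unless INCLUDE NULLS is given.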
+ TableFactor::Pivot { + table, + aggregate_functions, + value_column, + value_source, + default_on_null, + alias, + } => { + let input_plan = self.create_relation(*table, planner_context)?; + + if aggregate_functions.len() != 1 { + return plan_err!("PIVOT requires exactly one aggregate function"); + } + + let agg_expr = self.sql_expr_to_logical_expr( + aggregate_functions[0].expr.clone(), + input_plan.schema(), + planner_context, + )?; + + if value_column.is_empty() { + return plan_err!("PIVOT value column is required"); + } + + let column_name = value_column.last().unwrap().value.clone(); + let pivot_column = Column::new(None::<&str>, column_name); + + let default_on_null_expr = default_on_null + .map(|expr| { + self.sql_expr_to_logical_expr( + expr, + input_plan.schema(), // Default expression should be context-independent or use input schema + planner_context, + ) + }) + .transpose()?; + + match value_source { + sqlparser::ast::PivotValueSource::List(exprs) => { + let pivot_values = exprs + .iter() + .map(|expr| { + let logical_expr = self.sql_expr_to_logical_expr( + expr.expr.clone(), + input_plan.schema(), + planner_context, + )?; + + match logical_expr { + Expr::Literal(scalar, _) => Ok(scalar), + _ => plan_err!("PIVOT values must be literals"), + } + }) + .collect::>>()?; + + let input_arc = Arc::new(input_plan); + + let pivot_plan = datafusion_expr::Pivot::try_new( + input_arc, + agg_expr, + pivot_column, + pivot_values, + default_on_null_expr.clone(), + )?; + + (LogicalPlan::Pivot(pivot_plan), alias) + } + sqlparser::ast::PivotValueSource::Any(order_by) => { + let input_arc = Arc::new(input_plan); + + let mut subquery_builder = + LogicalPlanBuilder::from(input_arc.as_ref().clone()) + .project(vec![Expr::Column(pivot_column.clone())])? + .distinct()?; + + if !order_by.is_empty() { + let sort_exprs = order_by + .iter() + .map(|item| { + let input_schema = subquery_builder.schema(); + + let expr = self.sql_expr_to_logical_expr( + item.expr.clone(), + input_schema, + planner_context, + ); + + expr.map(|e| { + e.sort( + item.options.asc.unwrap_or(true), + item.options.nulls_first.unwrap_or(false), + ) + }) + }) + .collect::>>()?; + + subquery_builder = subquery_builder.sort(sort_exprs)?; + } + + let subquery_plan = subquery_builder.build()?; + + let pivot_plan = datafusion_expr::Pivot::try_new_with_subquery( + input_arc, + agg_expr, + pivot_column, + Arc::new(subquery_plan), + default_on_null_expr.clone(), + )?; + + (LogicalPlan::Pivot(pivot_plan), alias) + } + sqlparser::ast::PivotValueSource::Subquery(subquery) => { + let subquery_plan = + self.query_to_plan(*subquery.clone(), planner_context)?; + + let input_arc = Arc::new(input_plan); + + let pivot_plan = datafusion_expr::Pivot::try_new_with_subquery( + input_arc, + agg_expr, + pivot_column, + Arc::new(subquery_plan), + default_on_null_expr.clone(), + )?; + + (LogicalPlan::Pivot(pivot_plan), alias) + } + } + } + TableFactor::Unpivot { + table, + null_inclusion, + value, + name, + columns, + alias, + } => { + let base_plan = self.create_relation(*table, planner_context)?; + let base_schema = base_plan.schema(); + + let value_column = value.value.clone(); + let name_column = name.value.clone(); + + let mut unpivot_column_indices = Vec::new(); + let mut unpivot_column_names = Vec::new(); + + let mut common_type = None; + + for column_ident in &columns { + let column_name = column_ident.value.clone(); + + let idx = if let Some(i) = + base_schema.index_of_column_by_name(None, &column_name) + { + i + } else { + return plan_err!("Column 
'{}' not found in input", column_name); + }; + + let field = base_schema.field(idx); + let field_type = field.data_type(); + + // Verify all unpivot columns have compatible types + if let Some(current_type) = &common_type { + if comparison_coercion(current_type, field_type).is_none() { + return plan_err!( + "The type of column '{}' conflicts with the type of other columns in the UNPIVOT list.", + column_name.to_uppercase() + ); + } + } else { + common_type = Some(field_type.clone()); + } + + unpivot_column_indices.push(idx); + unpivot_column_names.push(column_name); + } + + if unpivot_column_names.is_empty() { + return plan_err!("UNPIVOT requires at least one column to unpivot"); + } + + let non_pivot_exprs: Vec = base_schema + .fields() + .iter() + .enumerate() + .filter(|(i, _)| !unpivot_column_indices.contains(i)) + .map(|(_, f)| Expr::Column(Column::new(None::<&str>, f.name()))) + .collect(); + + let mut union_inputs = Vec::with_capacity(unpivot_column_names.len()); + + for col_name in &unpivot_column_names { + let mut projection_exprs = non_pivot_exprs.clone(); + + let name_expr = Expr::Literal( + ScalarValue::Utf8(Some(col_name.to_uppercase())), + None, + ) + .alias(name_column.clone()); + + let value_expr = + Expr::Column(Column::new(None::<&str>, col_name.clone())) + .alias(value_column.clone()); + + projection_exprs.push(name_expr); + projection_exprs.push(value_expr); + + let mut builder = LogicalPlanBuilder::from(base_plan.clone()) + .project(projection_exprs)?; + + if let Some(NullInclusion::ExcludeNulls) | None = null_inclusion { + let col = Column::new(None::<&str>, value_column.clone()); + builder = builder + .filter(Expr::IsNotNull(Box::new(Expr::Column(col))))?; + } + + union_inputs.push(builder.build()?); + } + + let first = union_inputs.remove(0); + let mut union_builder = LogicalPlanBuilder::from(first); + + for plan in union_inputs { + union_builder = union_builder.union(plan)?; + } + + let unpivot_plan = union_builder.build()?; + + (unpivot_plan, alias) + } + // @todo: Support TableFactory::TableFunction _ => { return not_impl_err!( "Unsupported ast node {relation:?} in create_relation" diff --git a/datafusion/sql/src/resolve.rs b/datafusion/sql/src/resolve.rs index 9e909f66fa97..ea3dbdd92d0c 100644 --- a/datafusion/sql/src/resolve.rs +++ b/datafusion/sql/src/resolve.rs @@ -117,6 +117,9 @@ impl Visitor for RelationVisitor { | Statement::ShowColumns { .. } | Statement::ShowTables { .. } | Statement::ShowCollation { .. } + | Statement::ShowSchemas { .. } + | Statement::ShowDatabases { .. } + | Statement::ShowObjects { .. 
} ); if requires_information_schema { for s in INFORMATION_SCHEMA_TABLES { @@ -172,17 +175,17 @@ fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) { /// assert_eq!(ctes.len(), 0); /// ``` /// -/// ## Example with CTEs -/// -/// ``` -/// # use datafusion_sql::parser::DFParser; +/// ## Example with CTEs +/// +/// ``` +/// # use datafusion_sql::parser::DFParser; /// # use datafusion_sql::resolve::resolve_table_references; -/// let query = "with my_cte as (values (1), (2)) SELECT * from my_cte;"; -/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); -/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); +/// let query = "with my_cte as (values (1), (2)) SELECT * from my_cte;"; +/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); +/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); /// assert_eq!(table_refs.len(), 0); -/// assert_eq!(ctes.len(), 1); -/// assert_eq!(ctes[0].to_string(), "my_cte"); +/// assert_eq!(ctes.len(), 1); +/// assert_eq!(ctes[0].to_string(), "my_cte"); /// ``` pub fn resolve_table_references( statement: &crate::parser::Statement, diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 7f94fce7bd9f..4c0cefda4cdb 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -72,7 +72,7 @@ fn ident_to_string(ident: &Ident) -> String { normalize_ident(ident.to_owned()) } -fn object_name_to_string(object_name: &ObjectName) -> String { +pub fn object_name_to_string(object_name: &ObjectName) -> String { object_name .0 .iter() @@ -101,7 +101,9 @@ fn get_schema_name(schema_name: &SchemaName) -> String { /// Construct `TableConstraint`(s) for the given columns by iterating over /// `columns` and extracting individual inline constraint definitions. 
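+/// For example, an inline `PRIMARY KEY` on a column definition is returned as a
+/// table-level primary-key `TableConstraint` for that column.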
-fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { +pub fn calc_inline_constraints_from_columns( + columns: &[ColumnDef], +) -> Vec { let mut constraints = vec![]; for column in columns { for ast::ColumnOptionDef { name, option } in &column.options { @@ -535,6 +537,7 @@ impl SqlToRel<'_, S> { to, params, or_alter, + secure, } => { if materialized { return not_impl_err!("Materialized views not supported")?; @@ -572,6 +575,7 @@ impl SqlToRel<'_, S> { to, params, or_alter, + secure, }; let sql = stmt.to_string(); let Statement::CreateView { @@ -2280,7 +2284,7 @@ ON p.function_name = r.routine_name } /// Return true if there is a table provider available for "schema.table" - fn has_table(&self, schema: &str, table: &str) -> bool { + pub fn has_table(&self, schema: &str, table: &str) -> bool { let tables_reference = TableReference::Partial { schema: schema.into(), table: table.into(), diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 3826ef9feab2..ffca82229c48 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -118,6 +118,7 @@ impl Unparser<'_> { LogicalPlan::Extension(extension) => { self.extension_to_statement(extension.node.as_ref()) } + LogicalPlan::Pivot(_) => not_impl_err!("Unsupported plan Pivot: {plan:?}"), LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) | LogicalPlan::Ddl(_) diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index c18c70d16e3f..c311ecbc1898 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -281,10 +281,7 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr /// Returns a validated `DataType` for the specified precision and /// scale -pub(crate) fn make_decimal_type( - precision: Option, - scale: Option, -) -> Result { +pub fn make_decimal_type(precision: Option, scale: Option) -> Result { // postgres like behavior let (precision, scale) = match (precision, scale) { (Some(p), Some(s)) => (p as u8, s as i8), @@ -312,7 +309,7 @@ pub(crate) fn make_decimal_type( } /// Normalize an owned identifier to a lowercase string, unless the identifier is quoted. -pub(crate) fn normalize_ident(id: Ident) -> String { +pub fn normalize_ident(id: Ident) -> String { match id.quote_style { Some(_) => id.value, None => id.value.to_ascii_lowercase(), diff --git a/datafusion/sqllogictest/test_files/pivot.slt b/datafusion/sqllogictest/test_files/pivot.slt new file mode 100644 index 000000000000..cd80138633ee --- /dev/null +++ b/datafusion/sqllogictest/test_files/pivot.slt @@ -0,0 +1,444 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
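+
+# These tests exercise the PIVOT table factor. Conceptually,
+# PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', ...)) groups by the
+# remaining input columns and emits one aggregated output column per listed
+# value, i.e. SUM(amount) over the rows where quarter equals that value.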
+ +####### +# Setup test data table +####### +statement ok +CREATE TABLE quarterly_sales( + empid INT, + amount INT, + quarter TEXT) + AS SELECT * FROM VALUES + (1, 10000, '2023_Q1'), + (1, 400, '2023_Q1'), + (2, 4500, '2023_Q1'), + (2, 35000, '2023_Q1'), + (1, 5000, '2023_Q2'), + (1, 3000, '2023_Q2'), + (2, 200, '2023_Q2'), + (2, 90500, '2023_Q2'), + (1, 6000, '2023_Q3'), + (1, 5000, '2023_Q3'), + (2, 2500, '2023_Q3'), + (2, 9500, '2023_Q3'), + (3, 2700, '2023_Q3'), + (1, 8000, '2023_Q4'), + (1, 10000, '2023_Q4'), + (2, 800, '2023_Q4'), + (2, 4500, '2023_Q4'), + (3, 2700, '2023_Q4'), + (3, 16000, '2023_Q4'), + (3, 10200, '2023_Q4'); + +query IIIII +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q3', '2023_Q4')) +ORDER BY empid; +---- +1 10400 8000 11000 18000 +2 39500 90700 12000 5300 +3 NULL NULL 2700 28900 + +# PIVOT with NULL handling +query III +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) + FOR quarter IN ('2023_Q1', '2023_Q2') + DEFAULT ON NULL (1001)) +ORDER BY empid; +---- +1 10400 8000 +2 39500 90700 +3 1001 1001 + +# PIVOT with cast to pivot column type +query TIII +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) FOR empid IN (1,2,3)) +ORDER BY quarter; +---- +2023_Q1 10400 39500 NULL +2023_Q2 8000 90700 NULL +2023_Q3 11000 12000 2700 +2023_Q4 18000 5300 28900 + + +# PIVOT with automatic detection of all distinct column values using ANY +query TIII +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) FOR empid IN (ANY ORDER BY empid)) +ORDER BY quarter; +---- +2023_Q1 10400 39500 NULL +2023_Q2 8000 90700 NULL +2023_Q3 11000 12000 2700 +2023_Q4 18000 5300 28900 + +# PIVOT with ANY that includes output column reordering +query IIIII +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) FOR quarter IN (ANY ORDER BY quarter DESC)) +ORDER BY empid; +---- +1 18000 11000 8000 10400 +2 5300 12000 90700 39500 +3 28900 2700 NULL NULL + +# PIVOT with a subquery to specify the values +query III +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) + FOR quarter IN ( + SELECT DISTINCT quarter FROM quarterly_sales WHERE quarter LIKE '%Q1' OR quarter LIKE '%Q3' + )) +ORDER BY empid; +---- +1 10400 11000 +2 39500 12000 +3 NULL 2700 + +query IIIII +WITH sales_without_discount AS + (SELECT empid, amount, quarter FROM quarterly_sales) +SELECT * +FROM sales_without_discount +PIVOT(SUM(amount) FOR quarter IN (ANY ORDER BY quarter)) +ORDER BY empid; +---- +1 10400 8000 11000 18000 +2 39500 90700 12000 5300 +3 NULL NULL 2700 28900 + + +# Non-existent column in the FOR clause +query error DataFusion error: Error during planning: Pivot column 'non_existent_column' does not exist in input schema +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) FOR non_existent_column IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +# Non-existent column in the aggregate function +query error DataFusion error: Schema error: No field named non_existent_column\. Valid fields are quarterly_sales\.empid, quarterly_sales\.amount, quarterly_sales\.quarter\. 
+SELECT * +FROM quarterly_sales +PIVOT(SUM(non_existent_column) FOR quarter IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +# Trying to use non-aggregate function +query error DataFusion error: Error during planning: Unsupported aggregate expression should always be AggregateFunction +SELECT * +FROM quarterly_sales +PIVOT(ABS(amount) FOR quarter IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +# Invalid subquery in the IN list - multiple columns +query error DataFusion error: Error during planning: Pivot subquery must return a single column +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) + FOR quarter IN (SELECT quarter, empid FROM quarterly_sales LIMIT 2)) +ORDER BY empid; + +# Invalid DEFAULT ON NULL value (dependent on pivot/aggregation columns) +query error DataFusion error: Schema error: No field named quarterly_sales\.amount\. Valid fields are quarterly_sales\.empid, "2023_Q1", "2023_Q2"\. +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) + FOR quarter IN ('2023_Q1', '2023_Q2') + DEFAULT ON NULL (amount)) +ORDER BY empid; + +# PIVOT after a PIVOT +query error DataFusion error: Schema error: No field named empid\. Valid fields are "2023_Q2", "0", "10000", "20000", "2023_Q2", "0", "10000", "20000"\. +SELECT * +FROM ( + SELECT * + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2')) +) +PIVOT(AVG(empid) FOR "2023_Q1" IN (0, 10000, 20000)) +ORDER BY empid; + +# PIVOT with window functions in the pivot expression +query error DataFusion error: Schema error: No field named empid\. Valid fields are "2023_Q1", "2023_Q2", "2023_Q1", "2023_Q2"\. +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) OVER (PARTITION BY empid) FOR quarter IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +# PIVOT with ORDER BY in the aggregate function +query error DataFusion error: Schema error: No field named empid\. Valid fields are "2023_Q1", "2023_Q2", "2023_Q1", "2023_Q2"\. 
+SELECT * +FROM quarterly_sales +PIVOT(SUM(amount ORDER BY empid) FOR quarter IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +statement ok +CREATE TABLE employees( + empid INT, + name TEXT, + department TEXT, + hire_date DATE) + AS SELECT * FROM VALUES + (1, 'Alice', 'Sales', '2020-01-15'), + (2, 'Bob', 'Sales', '2021-03-10'), + (3, 'Charlie', 'Marketing', '2022-06-22'), + (4, 'David', 'Engineering', '2019-11-08'), + (5, 'Eve', 'Marketing', '2023-02-01'); + +statement ok +CREATE TABLE product_sales( + product_id INT, + category TEXT, + sale_amount INT, + sale_date DATE) + AS SELECT * FROM VALUES + (101, 'Electronics', 1200, '2023-01-10'), + (102, 'Clothing', 500, '2023-01-15'), + (103, 'Home', 800, '2023-01-20'), + (104, 'Electronics', 1500, '2023-02-05'), + (105, 'Clothing', 600, '2023-02-12'), + (106, 'Home', 900, '2023-02-25'), + (107, 'Electronics', 2000, '2023-03-08'), + (108, 'Clothing', 700, '2023-03-15'), + (109, 'Home', 1100, '2023-03-22'), + (110, 'Electronics', 1800, '2023-04-05'), + (111, 'Clothing', 550, '2023-04-14'), + (112, 'Home', 950, '2023-04-28'); + +query TIIIII +SELECT e.name, s.* +FROM employees e +JOIN ( + SELECT empid, "2023_Q1", "2023_Q2", "2023_Q3", "2023_Q4" + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q3', '2023_Q4')) +) s ON e.empid = s.empid +ORDER BY e.empid; +---- +Alice 1 10400 8000 11000 18000 +Bob 2 39500 90700 12000 5300 +Charlie 3 NULL NULL 2700 28900 + +# PIVOT with filtered subquery +query III +SELECT * +FROM ( + SELECT empid, amount, quarter + FROM quarterly_sales + WHERE amount > 5000 +) +PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) +ORDER BY empid; +---- +1 10000 18000 +2 35000 NULL +3 NULL 26200 + +query TII +SELECT + category, + "Q1", + "Q2" +FROM ( + SELECT + category, + CASE + WHEN EXTRACT(QUARTER FROM sale_date) = 1 THEN 'Q1' + WHEN EXTRACT(QUARTER FROM sale_date) = 2 THEN 'Q2' + END AS quarter, + sale_amount + FROM product_sales +) +PIVOT( + SUM(sale_amount) AS total + FOR quarter IN ('Q1' AS "Q1", 'Q2' AS "Q2") +) +ORDER BY category; +---- +Clothing 1800 550 +Electronics 4700 1800 +Home 2800 950 + +# PIVOT with arithmetic operations on the aggregated values +query TIIIR +SELECT + e.name, + p."2023_Q1", + p."2023_Q4", + p."2023_Q4" - p."2023_Q1" AS q4_minus_q1, + CASE + WHEN p."2023_Q1" = 0 THEN NULL + ELSE (p."2023_Q4" - p."2023_Q1")*100.0/p."2023_Q1" + END AS percent_change +FROM employees e +LEFT JOIN ( + SELECT empid, "2023_Q1", "2023_Q4" + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) +) p ON e.empid = p.empid +ORDER BY e.name; +---- +Alice 10400 18000 7600 73.076923076923 +Bob 39500 5300 -34200 -86.582278481013 +Charlie NULL 28900 NULL NULL +David NULL NULL NULL NULL +Eve NULL NULL NULL NULL + +# PIVOT with HAVING clause +query TII +WITH dept_pivot AS ( + SELECT + e.department, + q."2023_Q1", + q."2023_Q4" + FROM employees e + LEFT JOIN ( + SELECT empid, "2023_Q1", "2023_Q4" + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) + ) q ON e.empid = q.empid +) +SELECT department, SUM("2023_Q1") as q1_total, SUM("2023_Q4") as q4_total +FROM dept_pivot +GROUP BY department +HAVING SUM("2023_Q4") > 0 +ORDER BY department; +---- +Marketing NULL 28900 +Sales 49900 23300 + +# PIVOT with CASE expressions for custom grouping +query III +SELECT * +FROM ( + SELECT + empid, + amount, + CASE + WHEN quarter IN ('2023_Q1', '2023_Q2') THEN 'H1' + WHEN quarter IN ('2023_Q3', '2023_Q4') THEN 'H2' + END AS half_year + FROM quarterly_sales +) 
+PIVOT(SUM(amount) FOR half_year IN ('H1', 'H2')) +ORDER BY empid; +---- +1 18400 29000 +2 130200 17300 +3 NULL 31600 + +# PIVOT WITH UNION +query TIRRR +SELECT 'Average sale amount' AS aggregate, * + FROM quarterly_sales + PIVOT(AVG(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +UNION +SELECT 'Highest value sale' AS aggregate, * + FROM quarterly_sales + PIVOT(MAX(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +UNION +SELECT 'Lowest value sale' AS aggregate, * + FROM quarterly_sales + PIVOT(MIN(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +UNION +SELECT 'Number of sales' AS aggregate, * + FROM quarterly_sales + PIVOT(COUNT(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +UNION +SELECT 'Total amount' AS aggregate, * + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +ORDER BY aggregate, empid; +---- +Average sale amount 1 5200 4000 9000 +Average sale amount 2 19750 45350 2650 +Average sale amount 3 NULL NULL 9633.333333333334 +Highest value sale 1 10000 5000 10000 +Highest value sale 2 35000 90500 4500 +Highest value sale 3 NULL NULL 16000 +Lowest value sale 1 400 3000 8000 +Lowest value sale 2 4500 200 800 +Lowest value sale 3 NULL NULL 2700 +Number of sales 1 2 2 2 +Number of sales 2 2 2 2 +Number of sales 3 0 0 3 +Total amount 1 10400 8000 18000 +Total amount 2 39500 90700 5300 +Total amount 3 NULL NULL 28900 + + +query TIIII +WITH sales_sum AS ( + SELECT + empid, + "2023_Q1" AS "Q1_Sales", + "2023_Q4" AS "Q4_Sales" + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) +), +sales_count AS ( + SELECT + empid, + "2023_Q1" AS "Q1_Count", + "2023_Q4" AS "Q4_Count" + FROM quarterly_sales + PIVOT(COUNT(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) +), +combined_sales AS ( + SELECT + ss.empid, + ss."Q1_Sales", + ss."Q4_Sales", + sc."Q1_Count", + sc."Q4_Count" + FROM sales_sum ss + JOIN sales_count sc ON ss.empid = sc.empid +) +SELECT dept.* +FROM ( + SELECT + e.department, + s."Q1_Sales", + s."Q4_Sales", + s."Q1_Count", + s."Q4_Count" + FROM employees e JOIN combined_sales s ON e.empid = s.empid +) dept +WHERE dept.department IN ('Sales', 'Marketing') +ORDER BY dept.department +---- +Marketing NULL 28900 0 3 +Sales 39500 5300 2 2 +Sales 10400 18000 2 2 + +# Test PIVOT subquery with projection +query TIRRRR +SELECT 'Average sale amount' AS aggregate, * + FROM quarterly_sales + PIVOT(AVG(amount) FOR quarter IN (ANY ORDER BY quarter)) ORDER by empid +---- +Average sale amount 1 5200 4000 5500 9000 +Average sale amount 2 19750 45350 6000 2650 +Average sale amount 3 NULL NULL 2700 9633.333333333334 diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index 77ee3e4f05a0..e27c5e61db63 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -857,12 +857,14 @@ explain select x from t where x NOT IN (1,2,3,4,5) AND x IN (1,2,3); logical_plan EmptyRelation: rows=0 physical_plan EmptyExec -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression InSubquery\(InSubquery \{ expr: Literal\(Int64\(NULL\), None\), subquery: , negated: false \}\) +query BB WITH empty AS (SELECT 10 WHERE false) SELECT NULL IN (SELECT * FROM empty), -- should be false, as the right side is empty relation NULL NOT IN (SELECT * FROM empty) -- should be true, as the right side is empty relation FROM (SELECT 1) t; +---- +false true query I 
WITH empty AS (SELECT 10 WHERE false) diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt index 97ebe2340dc2..9f840e7bdc2f 100644 --- a/datafusion/sqllogictest/test_files/projection.slt +++ b/datafusion/sqllogictest/test_files/projection.slt @@ -253,7 +253,7 @@ physical_plan statement ok drop table t; -# Regression test for +# Regression test for # https://github.com/apache/datafusion/issues/17513 query I diff --git a/datafusion/sqllogictest/test_files/qualify.slt b/datafusion/sqllogictest/test_files/qualify.slt index d53b56ce58de..5007c0e9f4b8 100644 --- a/datafusion/sqllogictest/test_files/qualify.slt +++ b/datafusion/sqllogictest/test_files/qualify.slt @@ -39,8 +39,8 @@ CREATE TABLE users ( # Basic QUALIFY with ROW_NUMBER query ITI -SELECT id, name, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rn -FROM users +SELECT id, name, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rn +FROM users QUALIFY rn = 1 ORDER BY dept, id; ---- @@ -49,8 +49,8 @@ ORDER BY dept, id; # QUALIFY with RANK query ITI -SELECT id, name, RANK() OVER (ORDER BY salary DESC) as rank -FROM users +SELECT id, name, RANK() OVER (ORDER BY salary DESC) as rank +FROM users QUALIFY rank <= 3 ORDER BY rank, id; ---- @@ -60,8 +60,8 @@ ORDER BY rank, id; # QUALIFY with DENSE_RANK query ITI -SELECT id, name, DENSE_RANK() OVER (PARTITION BY dept ORDER BY age) as dense_rank -FROM users +SELECT id, name, DENSE_RANK() OVER (PARTITION BY dept ORDER BY age) as dense_rank +FROM users QUALIFY dense_rank <= 2 ORDER BY dept, dense_rank, id; ---- @@ -78,7 +78,7 @@ ORDER BY dept, dense_rank, id; query ITII SELECT id, name, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rn, RANK() OVER (ORDER BY age) as age_rank -FROM users +FROM users QUALIFY rn <= 2 AND age_rank <= 5 ORDER BY dept, rn, id; ---- @@ -88,7 +88,7 @@ ORDER BY dept, rn, id; # QUALIFY with LAG function query ITRR SELECT id, name, salary, LAG(salary) OVER (PARTITION BY dept ORDER BY id) as prev_salary -FROM users +FROM users QUALIFY prev_salary IS NOT NULL AND salary > prev_salary ORDER BY dept, id; ---- @@ -99,7 +99,7 @@ ORDER BY dept, id; # QUALIFY with LEAD function query ITRR SELECT id, name, salary, LEAD(salary) OVER (PARTITION BY dept ORDER BY id) as next_salary -FROM users +FROM users QUALIFY next_salary IS NOT NULL AND salary < next_salary ORDER BY dept, id; ---- @@ -110,7 +110,7 @@ ORDER BY dept, id; # QUALIFY with NTILE query ITI SELECT id, name, NTILE(3) OVER (PARTITION BY dept ORDER BY salary DESC) as tile -FROM users +FROM users QUALIFY tile = 1 ORDER BY dept, id; ---- @@ -121,7 +121,7 @@ ORDER BY dept, id; # QUALIFY with PERCENT_RANK query ITR SELECT id, name, PERCENT_RANK() OVER (PARTITION BY dept ORDER BY salary) as pct_rank -FROM users +FROM users QUALIFY pct_rank >= 0.5 ORDER BY dept, pct_rank, id; ---- @@ -134,7 +134,7 @@ ORDER BY dept, pct_rank, id; # QUALIFY with CUME_DIST query ITR SELECT id, name, CUME_DIST() OVER (PARTITION BY dept ORDER BY age) as cume_dist -FROM users +FROM users QUALIFY cume_dist >= 0.75 ORDER BY dept, cume_dist, id; ---- @@ -145,11 +145,11 @@ ORDER BY dept, cume_dist, id; # QUALIFY with multiple window functions query ITIII -SELECT id, name, +SELECT id, name, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rn, RANK() OVER (ORDER BY age) as age_rank, DENSE_RANK() OVER (PARTITION BY dept ORDER BY age) as dept_age_rank -FROM users +FROM users QUALIFY rn <= 2 AND age_rank <= 4 AND dept_age_rank <= 2 ORDER BY dept, rn, 
id; ---- @@ -158,9 +158,9 @@ ORDER BY dept, rn, id; # QUALIFY with arithmetic expressions query ITRI -SELECT id, name, salary, +SELECT id, name, salary, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rn -FROM users +FROM users QUALIFY rn = 1 AND salary > 60000 ORDER BY dept, id; ---- @@ -169,9 +169,9 @@ ORDER BY dept, id; # QUALIFY with string functions query ITI -SELECT id, name, +SELECT id, name, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY name) as rn -FROM users +FROM users QUALIFY rn = 1 ORDER BY dept, id; ---- @@ -181,7 +181,7 @@ ORDER BY dept, id; # window function with aggregate function query ITI SELECT id, name, COUNT(*) OVER (PARTITION BY dept) as cnt -FROM users +FROM users QUALIFY cnt > 4 ORDER BY dept, id; ---- @@ -198,7 +198,7 @@ FROM users WHERE salary > 5000 GROUP BY dept, salary HAVING SUM(salary) > 20000 -QUALIFY r > 60000 +QUALIFY r > 60000 ---- Marketing 70000 Marketing 70000 @@ -368,6 +368,35 @@ physical_plan 14)--------------------------AggregateExec: mode=Partial, gby=[dept@1 as dept], aggr=[sum(users.salary)] 15)----------------------------DataSourceExec: partitions=1, partition_sizes=[1] +# WHERE with scalar aggregate subquery + QUALIFY +statement ok +CREATE TABLE bulk_import_entities ( + id INT, + _task_instance INT, + _uploaded_at TIMESTAMP +) AS VALUES + (1, 1, '2025-01-01 10:00:00'::timestamp), + (1, 2, '2025-01-02 09:00:00'::timestamp), + (1, 2, '2025-01-03 08:00:00'::timestamp), + (2, 1, '2025-01-01 11:00:00'::timestamp), + (2, 2, '2025-01-02 12:00:00'::timestamp), + (3, 1, '2025-01-01 13:00:00'::timestamp); + +query II +SELECT id, _task_instance +FROM bulk_import_entities +WHERE _task_instance = ( + SELECT MAX(_task_instance) FROM bulk_import_entities +) +QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY _uploaded_at) = 1 +ORDER BY id; +---- +1 2 +2 2 + # Clean up statement ok -DROP TABLE users; +DROP TABLE users; + +statement ok +DROP TABLE bulk_import_entities diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index b0e200015dfd..75949391ef0d 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1536,8 +1536,10 @@ SELECT not(true), not(false) ---- false true -query error type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean +query BB SELECT not(1), not(0) +---- +false true query ?B SELECT null, not(null) diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 43f85d1e2014..dc42ab6f0e72 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -322,6 +322,156 @@ physical_plan 14)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 15)--------DataSourceExec: partitions=1, partition_sizes=[2] +# in_subquery_in_join_on_with_equijoin +query II rowsort +SELECT ds.t1_id, o.t2_id +FROM t1 ds +LEFT JOIN t2 o + ON (ds.t1_id IN (SELECT t3_id FROM t3)) AND (ds.t1_id = o.t2_id); +---- +11 11 +22 22 +33 NULL +44 44 + +# not_in_subquery_in_join_on_with_equijoin +query II rowsort +SELECT ds.t1_id, o.t2_id +FROM t1 ds +JOIN t2 o + ON (ds.t1_id NOT IN (SELECT t3_id FROM t3 WHERE t3_int = 3)) AND (ds.t1_id = o.t2_id); +---- +22 22 + +# explain subquery with join +query TT +EXPLAIN SELECT ds.t1_id, o.t2_id +FROM t1 ds +JOIN t2 o + ON (ds.t1_id NOT IN (SELECT t3_id FROM t3 WHERE t3_int = 3)) AND (ds.t1_id = o.t2_id); 
+---- +logical_plan +01)Inner Join: ds.t1_id = o.t2_id +02)--Projection: ds.t1_id +03)----Filter: NOT __correlated_sq_1.mark +04)------LeftMark Join: ds.t1_id = __correlated_sq_1.t3_id +05)--------SubqueryAlias: ds +06)----------TableScan: t1 projection=[t1_id] +07)--------SubqueryAlias: __correlated_sq_1 +08)----------Projection: t3.t3_id +09)------------Filter: t3.t3_int = Int32(3) +10)--------------TableScan: t3 projection=[t3_id, t3_int] +11)--SubqueryAlias: o +12)----TableScan: t2 projection=[t2_id] +physical_plan +01)CoalesceBatchesExec: target_batch_size=2 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)] +03)----CoalescePartitionsExec +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------FilterExec: NOT mark@1, projection=[t1_id@0] +06)----------CoalesceBatchesExec: target_batch_size=2 +07)------------HashJoinExec: mode=CollectLeft, join_type=LeftMark, on=[(t1_id@0, t3_id@0)] +08)--------------DataSourceExec: partitions=1, partition_sizes=[2] +09)--------------CoalesceBatchesExec: target_batch_size=2 +10)----------------FilterExec: t3_int@1 = 3, projection=[t3_id@0] +11)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)--------------------DataSourceExec: partitions=1, partition_sizes=[2] +13)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------DataSourceExec: partitions=1, partition_sizes=[2] + +# aggregate_case_in_subquery +query III rowsort +WITH distinct_source AS ( + SELECT 1 AS namespace_id, 10 AS max_task_instance, 100 AS max_uploaded_at + UNION ALL + SELECT 2, 20, 200 + UNION ALL + SELECT 2, 15, 150 +) +SELECT + namespace_id AS primary_key, + MAX(CASE WHEN max_task_instance IN ( + SELECT MAX(max_task_instance) + FROM distinct_source + ) THEN 1 ELSE 0 END) AS is_in_most_recent_task, + MAX(max_uploaded_at) AS max_timestamp +FROM distinct_source +GROUP BY 1; +---- +1 0 100 +2 1 200 + +# explain subquery with aggregate +query TT +EXPLAIN WITH distinct_source AS ( + SELECT 1 AS namespace_id, 10 AS max_task_instance, 100 AS max_uploaded_at + UNION ALL + SELECT 2, 20, 200 + UNION ALL + SELECT 2, 15, 150 +) +SELECT + namespace_id AS primary_key, + MAX(CASE WHEN max_task_instance IN ( + SELECT MAX(max_task_instance) + FROM distinct_source + ) THEN 1 ELSE 0 END) AS is_in_most_recent_task, + MAX(max_uploaded_at) AS max_timestamp +FROM distinct_source +GROUP BY 1; +---- +logical_plan +01)Projection: distinct_source.namespace_id AS primary_key, max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END) AS is_in_most_recent_task, max(distinct_source.max_uploaded_at) AS max_timestamp +02)--Aggregate: groupBy=[[distinct_source.namespace_id]], aggr=[[max(CASE WHEN __correlated_sq_1.mark THEN Int64(1) ELSE Int64(0) END) AS max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END), max(distinct_source.max_uploaded_at)]] +03)----Projection: distinct_source.namespace_id, distinct_source.max_uploaded_at, __correlated_sq_1.mark +04)------LeftMark Join: distinct_source.max_task_instance = __correlated_sq_1.max(distinct_source.max_task_instance) +05)--------SubqueryAlias: distinct_source +06)----------Union +07)------------Projection: Int64(1) AS namespace_id, Int64(10) AS max_task_instance, Int64(100) AS max_uploaded_at +08)--------------EmptyRelation: rows=1 +09)------------Projection: Int64(2) AS namespace_id, Int64(20) AS max_task_instance, Int64(200) AS max_uploaded_at +10)--------------EmptyRelation: rows=1 +11)------------Projection: Int64(2) AS namespace_id, Int64(15) AS max_task_instance, Int64(150) 
AS max_uploaded_at +12)--------------EmptyRelation: rows=1 +13)--------SubqueryAlias: __correlated_sq_1 +14)----------Aggregate: groupBy=[[]], aggr=[[max(distinct_source.max_task_instance)]] +15)------------SubqueryAlias: distinct_source +16)--------------Union +17)----------------Projection: Int64(10) AS max_task_instance +18)------------------EmptyRelation: rows=1 +19)----------------Projection: Int64(20) AS max_task_instance +20)------------------EmptyRelation: rows=1 +21)----------------Projection: Int64(15) AS max_task_instance +22)------------------EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[namespace_id@0 as primary_key, max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END)@1 as is_in_most_recent_task, max(distinct_source.max_uploaded_at)@2 as max_timestamp] +02)--AggregateExec: mode=FinalPartitioned, gby=[namespace_id@0 as namespace_id], aggr=[max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END), max(distinct_source.max_uploaded_at)] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: partitioning=Hash([namespace_id@0], 4), input_partitions=4 +05)--------AggregateExec: mode=Partial, gby=[namespace_id@0 as namespace_id], aggr=[max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END), max(distinct_source.max_uploaded_at)] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------CoalesceBatchesExec: target_batch_size=2 +08)--------------HashJoinExec: mode=CollectLeft, join_type=LeftMark, on=[(max_task_instance@1, max(distinct_source.max_task_instance)@0)], projection=[namespace_id@0, max_uploaded_at@2, mark@3] +09)----------------CoalescePartitionsExec +10)------------------UnionExec +11)--------------------ProjectionExec: expr=[1 as namespace_id, 10 as max_task_instance, 100 as max_uploaded_at] +12)----------------------PlaceholderRowExec +13)--------------------ProjectionExec: expr=[2 as namespace_id, 20 as max_task_instance, 200 as max_uploaded_at] +14)----------------------PlaceholderRowExec +15)--------------------ProjectionExec: expr=[2 as namespace_id, 15 as max_task_instance, 150 as max_uploaded_at] +16)----------------------PlaceholderRowExec +17)----------------AggregateExec: mode=Final, gby=[], aggr=[max(distinct_source.max_task_instance)] +18)------------------CoalescePartitionsExec +19)--------------------AggregateExec: mode=Partial, gby=[], aggr=[max(distinct_source.max_task_instance)] +20)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 +21)------------------------UnionExec +22)--------------------------ProjectionExec: expr=[10 as max_task_instance] +23)----------------------------PlaceholderRowExec +24)--------------------------ProjectionExec: expr=[20 as max_task_instance] +25)----------------------------PlaceholderRowExec +26)--------------------------ProjectionExec: expr=[15 as max_task_instance] +27)----------------------------PlaceholderRowExec + query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 ---- @@ -1480,3 +1630,32 @@ logical_plan statement count 0 drop table person; + + +# Projection IN (subquery) decorrelation +query IB rowsort +WITH t(a) AS (VALUES (1),(2),(3),(4),(5)), + u(a) AS (VALUES (2),(4),(6)) +SELECT t.a, (t.a IN (SELECT u.a FROM u)) AS flag FROM t; +---- +1 false +2 true +3 false +4 true +5 false + +query TT +EXPLAIN WITH t(a) AS (VALUES (1),(2),(3),(4),(5)), + u(a) AS (VALUES (2),(4),(6)) +SELECT t.a, (t.a IN (SELECT u.a FROM u)) AS flag FROM t; +---- 
+logical_plan +01)Projection: t.a, __correlated_sq_1.mark AS flag +02)--LeftMark Join: t.a = __correlated_sq_1.a +03)----SubqueryAlias: t +04)------Projection: column1 AS a +05)--------Values: (Int64(1)), (Int64(2)), (Int64(3)), (Int64(4)), (Int64(5)) +06)----SubqueryAlias: __correlated_sq_1 +07)------SubqueryAlias: u +08)--------Projection: column1 AS a +09)----------Values: (Int64(2)), (Int64(4)), (Int64(6)) diff --git a/datafusion/sqllogictest/test_files/unpivot.slt b/datafusion/sqllogictest/test_files/unpivot.slt new file mode 100644 index 000000000000..500e15ae77d0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/unpivot.slt @@ -0,0 +1,269 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +####### +# Setup test data table +####### +statement ok +CREATE TABLE monthly_sales( + empid INT, + dept TEXT, + jan INT, + feb INT, + mar INT, + apr INT) + AS SELECT * FROM VALUES + (1, 'electronics', 100, 200, 300, 100), + (2, 'clothes', 100, 300, 150, 200), + (3, 'cars', 200, 400, 100, 50), + (4, 'appliances', 100, NULL, 100, 50); + +# Basic UNPIVOT excluding nulls (default behavior) +query ITTI +SELECT * + FROM monthly_sales + UNPIVOT (sales FOR month IN (jan, feb, mar, apr)) + ORDER BY empid; +---- +1 electronics JAN 100 +1 electronics FEB 200 +1 electronics MAR 300 +1 electronics APR 100 +2 clothes JAN 100 +2 clothes FEB 300 +2 clothes MAR 150 +2 clothes APR 200 +3 cars JAN 200 +3 cars FEB 400 +3 cars MAR 100 +3 cars APR 50 +4 appliances JAN 100 +4 appliances MAR 100 +4 appliances APR 50 + +# UNPIVOT with INCLUDE NULLS option +query ITTI +SELECT * + FROM monthly_sales + UNPIVOT INCLUDE NULLS (sales FOR month IN (jan, feb, mar, apr)) + ORDER BY empid; +---- +1 electronics JAN 100 +1 electronics FEB 200 +1 electronics MAR 300 +1 electronics APR 100 +2 clothes JAN 100 +2 clothes FEB 300 +2 clothes MAR 150 +2 clothes APR 200 +3 cars JAN 200 +3 cars FEB 400 +3 cars MAR 100 +3 cars APR 50 +4 appliances JAN 100 +4 appliances FEB NULL +4 appliances MAR 100 +4 appliances APR 50 + +query TTI +SELECT dept, month, sales + FROM monthly_sales + UNPIVOT (sales FOR month IN (jan, feb, mar, apr)) + ORDER BY dept; +---- +appliances JAN 100 +appliances MAR 100 +appliances APR 50 +cars JAN 200 +cars FEB 400 +cars MAR 100 +cars APR 50 +clothes JAN 100 +clothes FEB 300 +clothes MAR 150 +clothes APR 200 +electronics JAN 100 +electronics FEB 200 +electronics MAR 300 +electronics APR 100 + +# UNPIVOT with filtering +query ITTI +SELECT * + FROM monthly_sales + UNPIVOT (sales FOR month IN (jan, feb, mar, apr)) + WHERE sales > 100 + ORDER BY empid; +---- +1 electronics FEB 200 +1 electronics MAR 300 +2 clothes FEB 300 +2 clothes MAR 150 +2 clothes APR 200 +3 cars JAN 200 +3 cars FEB 400 + +# UNPIVOT with aggregation +query TI +SELECT month, SUM(sales) as total_sales + FROM 
monthly_sales + UNPIVOT (sales FOR month IN (jan, feb, mar, apr)) + GROUP BY month + ORDER BY month; +---- +APR 400 +FEB 900 +JAN 500 +MAR 650 + +# UNPIVOT with JOIN +query ITTI +SELECT e.empid, e.dept, u.month, u.sales + FROM monthly_sales e + JOIN ( + SELECT empid, month, sales + FROM monthly_sales + UNPIVOT (sales FOR month IN (jan, feb, mar, apr)) + ) u ON e.empid = u.empid + WHERE u.sales > 200 + ORDER BY e.empid, u.month; +---- +1 electronics MAR 300 +2 clothes FEB 300 +3 cars FEB 400 + +query ITIITI +SELECT * + FROM monthly_sales + UNPIVOT (sales FOR month IN (jan, mar)) + ORDER BY empid; +---- +1 electronics 200 100 JAN 100 +1 electronics 200 100 MAR 300 +2 clothes 300 200 JAN 100 +2 clothes 300 200 MAR 150 +3 cars 400 50 JAN 200 +3 cars 400 50 MAR 100 +4 appliances NULL 50 JAN 100 +4 appliances NULL 50 MAR 100 + +# UNPIVOT with HAVING clause +query TI +SELECT month, SUM(sales) as total_sales + FROM monthly_sales + UNPIVOT (sales FOR month IN (jan, feb, mar, apr)) + GROUP BY month + HAVING SUM(sales) > 400 + ORDER BY month; +---- +FEB 900 +JAN 500 +MAR 650 + +# UNPIVOT with subquery +query ITTI +SELECT * + FROM ( + SELECT empid, dept, jan, feb, mar + FROM monthly_sales + WHERE dept IN ('electronics', 'clothes') + ) + UNPIVOT (sales FOR month IN (jan, feb, mar)) + ORDER BY empid; +---- +1 electronics JAN 100 +1 electronics FEB 200 +1 electronics MAR 300 +2 clothes JAN 100 +2 clothes FEB 300 +2 clothes MAR 150 + +# Non-existent column in the column list +query error DataFusion error: Error during planning: Column 'non_existent' not found in input +SELECT * + FROM monthly_sales + UNPIVOT (sales FOR month IN (non_existent, feb, mar)) + ORDER BY empid; + +statement ok +CREATE TABLE mixed_types( + id INT, + col1 INT, + col2 TEXT, + col3 FLOAT) + AS SELECT * FROM VALUES + (1, 100, 'abc', 10.5), + (2, 200, 'def', 20.5); + +query ITT +SELECT * + FROM mixed_types + UNPIVOT (val FOR col_name IN (col1, col2, col3)) + ORDER BY id; +---- +1 COL1 100 +1 COL2 abc +1 COL3 10.5 +2 COL1 200 +2 COL2 def +2 COL3 20.5 + +# UNPIVOT with CTE +query ITTI +WITH sales_data AS ( + SELECT * FROM monthly_sales WHERE empid < 3 +) +SELECT * + FROM sales_data + UNPIVOT (sales FOR month IN (jan, feb, mar, apr)) + ORDER BY empid; +---- +1 electronics JAN 100 +1 electronics FEB 200 +1 electronics MAR 300 +1 electronics APR 100 +2 clothes JAN 100 +2 clothes FEB 300 +2 clothes MAR 150 +2 clothes APR 200 + +# UNPIVOT with UNION +query ITIITI +SELECT * + FROM monthly_sales + UNPIVOT (sales FOR month IN (jan, feb)) + UNION ALL +SELECT * + FROM monthly_sales + UNPIVOT (sales FOR month IN (mar, apr)) + ORDER BY empid, month; +---- +1 electronics 100 200 APR 100 +1 electronics 300 100 FEB 200 +1 electronics 300 100 JAN 100 +1 electronics 100 200 MAR 300 +2 clothes 100 300 APR 200 +2 clothes 150 200 FEB 300 +2 clothes 150 200 JAN 100 +2 clothes 100 300 MAR 150 +3 cars 200 400 APR 50 +3 cars 100 50 FEB 400 +3 cars 100 50 JAN 200 +3 cars 200 400 MAR 100 +4 appliances 100 NULL APR 50 +4 appliances 100 50 JAN 100 +4 appliances 100 NULL MAR 100 diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs index c3599a2635ff..7cac8cd1ae5e 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -74,5 +74,6 @@ pub fn to_substrait_rel( LogicalPlan::RecursiveQuery(plan) => { not_impl_err!("Unsupported plan type: {plan:?}")? 
} + LogicalPlan::Pivot(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, } } diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 2335105882a1..5f53ea0f074a 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -1312,8 +1312,8 @@ use datafusion_expr::Expr; pub struct EchoFunction {} impl TableFunctionImpl for EchoFunction {
- fn call(&self, exprs: &[Expr]) -> Result> {
- let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else {
+ fn call(&self, exprs: &[(datafusion_expr::Expr, Option)]) -> Result> {
+ let Some((Expr::Literal(ScalarValue::Int64(Some(value)), _), _)) = exprs.get(0) else {
return plan_err!("First argument must be an integer"); }; @@ -1353,8 +1353,8 @@ With the UDTF implemented, you can register it with the `SessionContext`: # pub struct EchoFunction {} # # impl TableFunctionImpl for EchoFunction {
-# fn call(&self, exprs: &[Expr]) -> Result> {
-# let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else {
+# fn call(&self, exprs: &[(datafusion_expr::Expr, Option)]) -> Result> {
+# let Some((Expr::Literal(ScalarValue::Int64(Some(value)), _), _)) = exprs.get(0) else {
# return plan_err!("First argument must be an integer"); # }; # diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index b0811ab7811b..6050f37f0ae1 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1797,6 +1797,8 @@ The following regular expression functions are supported: - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace)
+- [regexp_substr](#regexp_substr)
+- [rlike](#rlike)
### `regexp_count` @@ -1900,6 +1902,10 @@ SELECT regexp_like('aBc', '(b|d)', 'i'); Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs)
+#### Aliases
+
+- rlike
+
### `regexp_match` Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. @@ -1980,6 +1986,54 @@ SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs)
+### `regexp_substr`
+
+Returns the substring that matches a [regular expression](https://docs.rs/regex/latest/regex/#syntax) within a string.
+
+```sql
+regexp_substr(str, regexp[, position[, occurrence[, flags[, group_num]]]])
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **regexp**: Regular expression to match against. + Can be a constant, column, or function.
+- **position**: Number of characters from the beginning of the string where the function starts searching for matches. Default: 1
+- **occurrence**: Specifies the first occurrence of the pattern from which to start returning matches. Default: 1
+- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
 - **i**: case-insensitive: letters match both upper and lower case
 - **c**: case-sensitive: letters match only the same case. Default flag
 - **m**: multi-line mode: ^ and $ match begin/end of line
 - **s**: allow . to match \n
 - **e**: extract submatches (for Snowflake compatibility)
 - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
 - **U**: swap the meaning of x* and x*?
+- **group_num**: Specifies which group to extract. Groups are specified by using parentheses in the regular expression.
+
+#### Example
+
+```sql
+ > select regexp_substr('Köln', '[a-zA-Z]ö[a-zA-Z]{2}');
+ +---------------------------------------------------------+
+ | regexp_substr(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) |
+ +---------------------------------------------------------+
+ | Köln |
+ +---------------------------------------------------------+
+ > SELECT regexp_substr('aBc', '(b|d)', 1, 1, 'i');
+ +---------------------------------------------------+
+ | regexp_substr(Utf8("aBc"),Utf8("(b|d)"), Int32(1), Int32(1), Utf8("i")) |
+ +---------------------------------------------------+
+ | B |
+ +---------------------------------------------------+
+```
+
+Additional examples can be found [here](https://docs.snowflake.com/en/sql-reference/functions/regexp_substr#examples)
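The documented examples above do not exercise the `group_num` argument, so here is a brief, hypothetical sketch of how it is meant to combine with `occurrence` and the `'e'` flag, assuming Snowflake-compatible capture-group semantics (a group number returns the text captured by that group rather than the whole match); the input string and expected result are illustrative assumptions, not verified output.

```sql
-- Occurrences of '[a-z](\d)' in 'a1_b2_c3' are 'a1', 'b2', and 'c3';
-- asking for occurrence 2 and capture group 1 should return the digit of 'b2'.
SELECT regexp_substr('a1_b2_c3', '[a-z](\d)', 1, 2, 'e', 1);
-- expected (assumed): '2'
```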
+ +### `rlike` + +_Alias of [regexp_like](#regexp_like)._ + ## Time and Date Functions - [current_date](#current_date)