From d13b7a9f6de140d13e5e89246b1b1d692c3a1202 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Wed, 3 Sep 2025 17:44:03 +0300 Subject: [PATCH] Decorellate subqueries in IN inside JOIN filter and Aggregates --- .../src/decorrelate_predicate_subquery.rs | 245 +++++++++++++----- .../sqllogictest/test_files/subquery.slt | 150 +++++++++++ 2 files changed, 335 insertions(+), 60 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index cfb6b206ad87..c1fcb1f21f6d 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -30,7 +30,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, plan_err, Column, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; -use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; +use datafusion_expr::logical_plan::{ + Join as LogicalJoin, JoinType, Projection, Subquery, +}; use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, @@ -66,82 +68,166 @@ impl OptimizerRule for DecorrelatePredicateSubquery { })? .data; - match plan { - LogicalPlan::Filter(filter) => { - if !has_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); - } + // Handle Filters first (existing behavior) + if let LogicalPlan::Filter(filter) = plan.clone() { + if !has_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } - let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = - split_conjunction_owned(filter.predicate) - .into_iter() - .partition(has_subquery); + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); - if with_subqueries.is_empty() { - return internal_err!( - "can not find expected subqueries in DecorrelatePredicateSubquery" - ); - } + if with_subqueries.is_empty() { + return internal_err!( + "can not find expected subqueries in DecorrelatePredicateSubquery" + ); + } - // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(filter.input); - for subquery_expr in with_subqueries { - match extract_subquery_info(subquery_expr) { - // The subquery expression is at the top level of the filter - SubqueryPredicate::Top(subquery) => { - match build_join_top( - &subquery, - &cur_input, - config.alias_generator(), - )? { - Some(plan) => cur_input = plan, - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - None => other_exprs.push(subquery.expr()), - } - } - // The subquery expression is embedded within another expression - SubqueryPredicate::Embedded(expr) => { - let (plan, expr_without_subqueries) = - rewrite_inner_subqueries(cur_input, expr, config)?; - cur_input = plan; - other_exprs.push(expr_without_subqueries); + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top( + &subquery, + &cur_input, + config.alias_generator(), + )? { + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), } } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } } + } + + let expr = conjunction(other_exprs); + if let Some(expr) = expr { + let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; + return Ok(Transformed::yes(LogicalPlan::Filter(new_filter))); + } + return Ok(Transformed::yes(cur_input)); + } - let expr = conjunction(other_exprs); - let mut new_plan = cur_input; - if let Some(expr) = expr { - let new_filter = Filter::try_new(expr, Arc::new(new_plan))?; - new_plan = LogicalPlan::Filter(new_filter); + // Additionally handle subqueries embedded in Join.filter expressions + if let LogicalPlan::Join(join) = plan { + if let Some(predicate) = &join.filter { + if has_subquery(predicate) { + let (new_left, new_predicate) = rewrite_inner_subqueries( + Arc::unwrap_or_clone(join.left), + predicate.clone(), + config, + )?; + + let new_join = LogicalJoin::try_new( + Arc::new(new_left), + Arc::clone(&join.right), + join.on.clone(), + Some(new_predicate), + join.join_type, + join.join_constraint, + join.null_equals_null, + )?; + return Ok(Transformed::yes(LogicalPlan::Join(new_join))); } - Ok(Transformed::yes(new_plan)) } - LogicalPlan::Projection(proj) => { - // Only proceed if any projection expression contains a subquery - if !proj.expr.iter().any(has_subquery) { - return Ok(Transformed::no(LogicalPlan::Projection(proj))); + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + // Handle subqueries embedded in Aggregate group/aggregate expressions + if let LogicalPlan::Aggregate(aggregate) = plan { + let mut needs_rewrite = false; + for e in &aggregate.group_expr { + if has_subquery(e) { + needs_rewrite = true; + break; + } + } + if !needs_rewrite { + for e in &aggregate.aggr_expr { + if has_subquery(e) { + needs_rewrite = true; + break; + } } + } + if !needs_rewrite { + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } - let mut cur_input = Arc::unwrap_or_clone(proj.input); - let mut new_exprs = Vec::with_capacity(proj.expr.len()); - for e in proj.expr { - let old_name = e.schema_name().to_string(); - let (plan_after, rewritten) = - rewrite_inner_subqueries(cur_input, e, config)?; - cur_input = plan_after; - let new_name = rewritten.schema_name().to_string(); + let mut cur_input = Arc::unwrap_or_clone(aggregate.input); + let mut new_group_exprs = Vec::with_capacity(aggregate.group_expr.len()); + for expr in aggregate.group_expr { + if has_subquery(&expr) { + let (next_input, rewritten_expr) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = next_input; + new_group_exprs.push(rewritten_expr); + } else { + new_group_exprs.push(expr); + } + } + let mut new_aggr_exprs = Vec::with_capacity(aggregate.aggr_expr.len()); + for expr in aggregate.aggr_expr { + if has_subquery(&expr) { + let old_name = expr.schema_name().to_string(); + let (next_input, rewritten_expr) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = next_input; + let new_name = rewritten_expr.schema_name().to_string(); if new_name != old_name { - new_exprs.push(rewritten.alias(old_name)); + new_aggr_exprs.push(rewritten_expr.alias(old_name)); } else { - new_exprs.push(rewritten); + new_aggr_exprs.push(rewritten_expr); } + } else { + new_aggr_exprs.push(expr); } - let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?; - Ok(Transformed::yes(LogicalPlan::Projection(new_proj))) } - other => Ok(Transformed::no(other)), + + let new_plan = LogicalPlanBuilder::from(cur_input) + .aggregate(new_group_exprs, new_aggr_exprs)? + .build()?; + return Ok(Transformed::yes(new_plan)); } + + // Handle Projection nodes with subqueries in expressions + if let LogicalPlan::Projection(proj) = plan { + // Only proceed if any projection expression contains a subquery + if !proj.expr.iter().any(has_subquery) { + return Ok(Transformed::no(LogicalPlan::Projection(proj))); + } + + let mut cur_input = Arc::unwrap_or_clone(proj.input); + let mut new_exprs = Vec::with_capacity(proj.expr.len()); + for e in proj.expr { + let old_name = e.schema_name().to_string(); + let (plan_after, rewritten) = + rewrite_inner_subqueries(cur_input, e, config)?; + cur_input = plan_after; + let new_name = rewritten.schema_name().to_string(); + if new_name != old_name { + new_exprs.push(rewritten.alias(old_name)); + } else { + new_exprs.push(rewritten); + } + } + let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?; + return Ok(Transformed::yes(LogicalPlan::Projection(new_proj))); + } + + // Other plans unchanged + Ok(Transformed::no(plan)) } fn name(&self) -> &str { @@ -477,6 +563,45 @@ mod tests { )) } + /// Aggregation with CASE WHEN ... IN (subquery) should be decorrelated under the Aggregate + #[test] + fn aggregate_case_in_subquery() -> Result<()> { + let table_scan = test_table_scan_with_name("distinct_source")?; + use datafusion_expr::expr_fn::when; + use datafusion_functions_aggregate::expr_fn::max as agg_max; + + let agg_b: Expr = agg_max(col("distinct_source.b")); + let subq = LogicalPlanBuilder::from(table_scan.clone()) + .aggregate(Vec::::new(), vec![agg_b])? + .project(vec![col("max(distinct_source.b)")])? + .build()?; + + let case_expr = when( + in_subquery(col("distinct_source.b"), Arc::new(subq)), + lit(1), + ) + .otherwise(lit(0))?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("distinct_source.a").alias("primary_key")], + vec![ + agg_max(case_expr).alias("is_in_most_recent_task"), + agg_max(col("distinct_source.c")).alias("max_timestamp"), + ], + )? + .build()?; + + use crate::{OptimizerContext, OptimizerRule}; + let optimized = DecorrelatePredicateSubquery::new() + .rewrite(plan, &OptimizerContext::new())? + .data; + let lp = optimized.display_indent().to_string(); + assert!(lp.contains("Aggregate:")); + assert!(lp.contains("Left")); + Ok(()) + } + /// Test for several IN subquery expressions #[test] fn in_subquery_multiple() -> Result<()> { diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index f32a23e26ff5..09d9cc9897ee 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -322,6 +322,156 @@ physical_plan 14)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 15)--------DataSourceExec: partitions=1, partition_sizes=[1] +# in_subquery_in_join_on_with_equijoin +query II rowsort +SELECT ds.t1_id, o.t2_id +FROM t1 ds +LEFT JOIN t2 o + ON (ds.t1_id IN (SELECT t3_id FROM t3)) AND (ds.t1_id = o.t2_id); +---- +11 11 +22 22 +33 NULL +44 44 + +# not_in_subquery_in_join_on_with_equijoin +query II rowsort +SELECT ds.t1_id, o.t2_id +FROM t1 ds +JOIN t2 o + ON (ds.t1_id NOT IN (SELECT t3_id FROM t3 WHERE t3_int = 3)) AND (ds.t1_id = o.t2_id); +---- +22 22 + +# explain subquery with join +query TT +EXPLAIN SELECT ds.t1_id, o.t2_id +FROM t1 ds +JOIN t2 o + ON (ds.t1_id NOT IN (SELECT t3_id FROM t3 WHERE t3_int = 3)) AND (ds.t1_id = o.t2_id); +---- +logical_plan +01)Inner Join: ds.t1_id = o.t2_id +02)--Projection: ds.t1_id +03)----Filter: NOT __correlated_sq_1.mark +04)------LeftMark Join: ds.t1_id = __correlated_sq_1.t3_id +05)--------SubqueryAlias: ds +06)----------TableScan: t1 projection=[t1_id] +07)--------SubqueryAlias: __correlated_sq_1 +08)----------Projection: t3.t3_id +09)------------Filter: t3.t3_int = Int32(3) +10)--------------TableScan: t3 projection=[t3_id, t3_int] +11)--SubqueryAlias: o +12)----TableScan: t2 projection=[t2_id] +physical_plan +01)CoalesceBatchesExec: target_batch_size=2 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)] +03)----CoalescePartitionsExec +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------FilterExec: NOT mark@1, projection=[t1_id@0] +06)----------CoalesceBatchesExec: target_batch_size=2 +07)------------HashJoinExec: mode=CollectLeft, join_type=LeftMark, on=[(t1_id@0, t3_id@0)] +08)--------------DataSourceExec: partitions=1, partition_sizes=[1] +09)--------------CoalesceBatchesExec: target_batch_size=2 +10)----------------FilterExec: t3_int@1 = 3, projection=[t3_id@0] +11)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)--------------------DataSourceExec: partitions=1, partition_sizes=[1] +13)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------DataSourceExec: partitions=1, partition_sizes=[1] + +# aggregate_case_in_subquery +query III rowsort +WITH distinct_source AS ( + SELECT 1 AS namespace_id, 10 AS max_task_instance, 100 AS max_uploaded_at + UNION ALL + SELECT 2, 20, 200 + UNION ALL + SELECT 2, 15, 150 +) +SELECT + namespace_id AS primary_key, + MAX(CASE WHEN max_task_instance IN ( + SELECT MAX(max_task_instance) + FROM distinct_source + ) THEN 1 ELSE 0 END) AS is_in_most_recent_task, + MAX(max_uploaded_at) AS max_timestamp +FROM distinct_source +GROUP BY 1; +---- +1 0 100 +2 1 200 + +# explain subquery with aggregate +query TT +EXPLAIN WITH distinct_source AS ( + SELECT 1 AS namespace_id, 10 AS max_task_instance, 100 AS max_uploaded_at + UNION ALL + SELECT 2, 20, 200 + UNION ALL + SELECT 2, 15, 150 +) +SELECT + namespace_id AS primary_key, + MAX(CASE WHEN max_task_instance IN ( + SELECT MAX(max_task_instance) + FROM distinct_source + ) THEN 1 ELSE 0 END) AS is_in_most_recent_task, + MAX(max_uploaded_at) AS max_timestamp +FROM distinct_source +GROUP BY 1; +---- +logical_plan +01)Projection: distinct_source.namespace_id AS primary_key, max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END) AS is_in_most_recent_task, max(distinct_source.max_uploaded_at) AS max_timestamp +02)--Aggregate: groupBy=[[distinct_source.namespace_id]], aggr=[[max(CASE WHEN __correlated_sq_1.mark THEN Int64(1) ELSE Int64(0) END) AS max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END), max(distinct_source.max_uploaded_at)]] +03)----Projection: distinct_source.namespace_id, distinct_source.max_uploaded_at, __correlated_sq_1.mark +04)------LeftMark Join: distinct_source.max_task_instance = __correlated_sq_1.max(distinct_source.max_task_instance) +05)--------SubqueryAlias: distinct_source +06)----------Union +07)------------Projection: Int64(1) AS namespace_id, Int64(10) AS max_task_instance, Int64(100) AS max_uploaded_at +08)--------------EmptyRelation +09)------------Projection: Int64(2) AS namespace_id, Int64(20) AS max_task_instance, Int64(200) AS max_uploaded_at +10)--------------EmptyRelation +11)------------Projection: Int64(2) AS namespace_id, Int64(15) AS max_task_instance, Int64(150) AS max_uploaded_at +12)--------------EmptyRelation +13)--------SubqueryAlias: __correlated_sq_1 +14)----------Aggregate: groupBy=[[]], aggr=[[max(distinct_source.max_task_instance)]] +15)------------SubqueryAlias: distinct_source +16)--------------Union +17)----------------Projection: Int64(10) AS max_task_instance +18)------------------EmptyRelation +19)----------------Projection: Int64(20) AS max_task_instance +20)------------------EmptyRelation +21)----------------Projection: Int64(15) AS max_task_instance +22)------------------EmptyRelation +physical_plan +01)ProjectionExec: expr=[namespace_id@0 as primary_key, max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END)@1 as is_in_most_recent_task, max(distinct_source.max_uploaded_at)@2 as max_timestamp] +02)--AggregateExec: mode=FinalPartitioned, gby=[namespace_id@0 as namespace_id], aggr=[max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END), max(distinct_source.max_uploaded_at)] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: partitioning=Hash([namespace_id@0], 4), input_partitions=4 +05)--------AggregateExec: mode=Partial, gby=[namespace_id@0 as namespace_id], aggr=[max(CASE WHEN IN THEN Int64(1) ELSE Int64(0) END), max(distinct_source.max_uploaded_at)] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------CoalesceBatchesExec: target_batch_size=2 +08)--------------HashJoinExec: mode=CollectLeft, join_type=LeftMark, on=[(max_task_instance@1, max(distinct_source.max_task_instance)@0)], projection=[namespace_id@0, max_uploaded_at@2, mark@3] +09)----------------CoalescePartitionsExec +10)------------------UnionExec +11)--------------------ProjectionExec: expr=[1 as namespace_id, 10 as max_task_instance, 100 as max_uploaded_at] +12)----------------------PlaceholderRowExec +13)--------------------ProjectionExec: expr=[2 as namespace_id, 20 as max_task_instance, 200 as max_uploaded_at] +14)----------------------PlaceholderRowExec +15)--------------------ProjectionExec: expr=[2 as namespace_id, 15 as max_task_instance, 150 as max_uploaded_at] +16)----------------------PlaceholderRowExec +17)----------------AggregateExec: mode=Final, gby=[], aggr=[max(distinct_source.max_task_instance)] +18)------------------CoalescePartitionsExec +19)--------------------AggregateExec: mode=Partial, gby=[], aggr=[max(distinct_source.max_task_instance)] +20)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 +21)------------------------UnionExec +22)--------------------------ProjectionExec: expr=[10 as max_task_instance] +23)----------------------------PlaceholderRowExec +24)--------------------------ProjectionExec: expr=[20 as max_task_instance] +25)----------------------------PlaceholderRowExec +26)--------------------------ProjectionExec: expr=[15 as max_task_instance] +27)----------------------------PlaceholderRowExec + query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 ----