diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index c18c48251daa..cfb6b206ad87 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -30,7 +30,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, plan_err, Column, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; -use datafusion_expr::logical_plan::{JoinType, Subquery}; +use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, @@ -66,54 +66,82 @@ impl OptimizerRule for DecorrelatePredicateSubquery { })? .data; - let LogicalPlan::Filter(filter) = plan else { - return Ok(Transformed::no(plan)); - }; - - if !has_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); - } + match plan { + LogicalPlan::Filter(filter) => { + if !has_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } - let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = - split_conjunction_owned(filter.predicate) - .into_iter() - .partition(has_subquery); + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); - if with_subqueries.is_empty() { - return internal_err!( - "can not find expected subqueries in DecorrelatePredicateSubquery" - ); - } + if with_subqueries.is_empty() { + return internal_err!( + "can not find expected subqueries in DecorrelatePredicateSubquery" + ); + } - // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(filter.input); - for subquery_expr in with_subqueries { - match extract_subquery_info(subquery_expr) { - // The subquery expression is at the top level of the filter - SubqueryPredicate::Top(subquery) => { - match build_join_top(&subquery, &cur_input, config.alias_generator())? - { - Some(plan) => cur_input = plan, - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - None => other_exprs.push(subquery.expr()), + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top( + &subquery, + &cur_input, + config.alias_generator(), + )? { + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), + } + } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } } } - // The subquery expression is embedded within another expression - SubqueryPredicate::Embedded(expr) => { - let (plan, expr_without_subqueries) = - rewrite_inner_subqueries(cur_input, expr, config)?; - cur_input = plan; - other_exprs.push(expr_without_subqueries); + + let expr = conjunction(other_exprs); + let mut new_plan = cur_input; + if let Some(expr) = expr { + let new_filter = Filter::try_new(expr, Arc::new(new_plan))?; + new_plan = LogicalPlan::Filter(new_filter); } + Ok(Transformed::yes(new_plan)) } - } + LogicalPlan::Projection(proj) => { + // Only proceed if any projection expression contains a subquery + if !proj.expr.iter().any(has_subquery) { + return Ok(Transformed::no(LogicalPlan::Projection(proj))); + } - let expr = conjunction(other_exprs); - if let Some(expr) = expr { - let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; - cur_input = LogicalPlan::Filter(new_filter); + let mut cur_input = Arc::unwrap_or_clone(proj.input); + let mut new_exprs = Vec::with_capacity(proj.expr.len()); + for e in proj.expr { + let old_name = e.schema_name().to_string(); + let (plan_after, rewritten) = + rewrite_inner_subqueries(cur_input, e, config)?; + cur_input = plan_after; + let new_name = rewritten.schema_name().to_string(); + if new_name != old_name { + new_exprs.push(rewritten.alias(old_name)); + } else { + new_exprs.push(rewritten); + } + } + let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?; + Ok(Transformed::yes(LogicalPlan::Projection(new_proj))) + } + other => Ok(Transformed::no(other)), } - Ok(Transformed::yes(cur_input)) } fn name(&self) -> &str { @@ -529,6 +557,31 @@ mod tests { assert_optimized_plan_equal(plan, expected) } + /// Projection IN (subquery) should be decorrelated via LeftMark join in Projection + #[test] + fn projection_in_subquery_simple() -> Result<()> { + // Build outer values t(a) = (1),(2) + let outer = LogicalPlanBuilder::values(vec![vec![lit(1_i32)], vec![lit(2_i32)]])? + .project(vec![col("column1").alias("a")])? + .build()?; + + // Build subquery u(a) = (2) + let sub = Arc::new( + LogicalPlanBuilder::values(vec![vec![lit(2_i32)]])? + .project(vec![col("column1").alias("ua")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer) + .project(vec![col("a"), in_subquery(col("a"), sub).alias("flag")])? + .build()?; + + // We expect a LeftMark join inserted and the projection keeps columns + let expected = "Projection: a, __correlated_sq_1.mark AS flag [a:Int32;N, flag:Boolean]\n LeftMark Join: Filter: a = __correlated_sq_1.ua [a:Int32;N, mark:Boolean]\n Projection: column1 AS a [a:Int32;N]\n Values: (Int32(1)), (Int32(2)) [column1:Int32;N]\n SubqueryAlias: __correlated_sq_1 [ua:Int32;N]\n Projection: column1 AS ua [ua:Int32;N]\n Values: (Int32(2)) [column1:Int32;N]"; + + assert_optimized_plan_equal(plan, expected) + } + /// Test multiple correlated subqueries /// See subqueries.rs where_in_multiple() #[test] diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index a0ac15b740d7..f32a23e26ff5 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1482,3 +1482,32 @@ logical_plan statement count 0 drop table person; + + +# Projection IN (subquery) decorrelation +query IB rowsort +WITH t(a) AS (VALUES (1),(2),(3),(4),(5)), + u(a) AS (VALUES (2),(4),(6)) +SELECT t.a, (t.a IN (SELECT u.a FROM u)) AS flag FROM t; +---- +1 false +2 true +3 false +4 true +5 false + +query TT +EXPLAIN WITH t(a) AS (VALUES (1),(2),(3),(4),(5)), + u(a) AS (VALUES (2),(4),(6)) +SELECT t.a, (t.a IN (SELECT u.a FROM u)) AS flag FROM t; +---- +logical_plan +01)Projection: t.a, __correlated_sq_1.mark AS flag +02)--LeftMark Join: t.a = __correlated_sq_1.a +03)----SubqueryAlias: t +04)------Projection: column1 AS a +05)--------Values: (Int64(1)), (Int64(2)), (Int64(3)), (Int64(4)), (Int64(5)) +06)----SubqueryAlias: __correlated_sq_1 +07)------SubqueryAlias: u +08)--------Projection: column1 AS a +09)----------Values: (Int64(2)), (Int64(4)), (Int64(6))