Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 93 additions & 40 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{internal_err, plan_err, Column, Result};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Subquery};
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
use datafusion_expr::{
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
Expand Down Expand Up @@ -66,54 +66,82 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
})?
.data;

let LogicalPlan::Filter(filter) = plan else {
return Ok(Transformed::no(plan));
};

if !has_subquery(&filter.predicate) {
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
match plan {
LogicalPlan::Filter(filter) => {
if !has_subquery(&filter.predicate) {
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}

let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
split_conjunction_owned(filter.predicate)
.into_iter()
.partition(has_subquery);
let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
split_conjunction_owned(filter.predicate)
.into_iter()
.partition(has_subquery);

if with_subqueries.is_empty() {
return internal_err!(
"can not find expected subqueries in DecorrelatePredicateSubquery"
);
}
if with_subqueries.is_empty() {
return internal_err!(
"can not find expected subqueries in DecorrelatePredicateSubquery"
);
}

// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = Arc::unwrap_or_clone(filter.input);
for subquery_expr in with_subqueries {
match extract_subquery_info(subquery_expr) {
// The subquery expression is at the top level of the filter
SubqueryPredicate::Top(subquery) => {
match build_join_top(&subquery, &cur_input, config.alias_generator())?
{
Some(plan) => cur_input = plan,
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
None => other_exprs.push(subquery.expr()),
// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = Arc::unwrap_or_clone(filter.input);
for subquery_expr in with_subqueries {
match extract_subquery_info(subquery_expr) {
// The subquery expression is at the top level of the filter
SubqueryPredicate::Top(subquery) => {
match build_join_top(
&subquery,
&cur_input,
config.alias_generator(),
)? {
Some(plan) => cur_input = plan,
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
None => other_exprs.push(subquery.expr()),
}
}
// The subquery expression is embedded within another expression
SubqueryPredicate::Embedded(expr) => {
let (plan, expr_without_subqueries) =
rewrite_inner_subqueries(cur_input, expr, config)?;
cur_input = plan;
other_exprs.push(expr_without_subqueries);
}
}
}
// The subquery expression is embedded within another expression
SubqueryPredicate::Embedded(expr) => {
let (plan, expr_without_subqueries) =
rewrite_inner_subqueries(cur_input, expr, config)?;
cur_input = plan;
other_exprs.push(expr_without_subqueries);

let expr = conjunction(other_exprs);
let mut new_plan = cur_input;
if let Some(expr) = expr {
let new_filter = Filter::try_new(expr, Arc::new(new_plan))?;
new_plan = LogicalPlan::Filter(new_filter);
}
Ok(Transformed::yes(new_plan))
}
}
LogicalPlan::Projection(proj) => {
// Only proceed if any projection expression contains a subquery
if !proj.expr.iter().any(has_subquery) {
return Ok(Transformed::no(LogicalPlan::Projection(proj)));
}

let expr = conjunction(other_exprs);
if let Some(expr) = expr {
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
cur_input = LogicalPlan::Filter(new_filter);
let mut cur_input = Arc::unwrap_or_clone(proj.input);
let mut new_exprs = Vec::with_capacity(proj.expr.len());
for e in proj.expr {
let old_name = e.schema_name().to_string();
let (plan_after, rewritten) =
rewrite_inner_subqueries(cur_input, e, config)?;
cur_input = plan_after;
let new_name = rewritten.schema_name().to_string();
if new_name != old_name {
new_exprs.push(rewritten.alias(old_name));
} else {
new_exprs.push(rewritten);
}
}
let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?;
Ok(Transformed::yes(LogicalPlan::Projection(new_proj)))
}
other => Ok(Transformed::no(other)),
}
Ok(Transformed::yes(cur_input))
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -529,6 +557,31 @@ mod tests {
assert_optimized_plan_equal(plan, expected)
}

/// An `IN (subquery)` predicate used as a projection expression is rewritten
/// into a `LeftMark` join, with the join's mark column standing in for the
/// original predicate under its original alias.
#[test]
fn projection_in_subquery_simple() -> Result<()> {
    // Outer relation: two single-column rows (1) and (2), exposed as `a`.
    let outer_rows = vec![vec![lit(1_i32)], vec![lit(2_i32)]];
    let outer_plan = LogicalPlanBuilder::values(outer_rows)?
        .project(vec![col("column1").alias("a")])?
        .build()?;

    // Subquery relation: one single-column row (2), exposed as `ua`.
    let subquery_rows = vec![vec![lit(2_i32)]];
    let subquery_plan = LogicalPlanBuilder::values(subquery_rows)?
        .project(vec![col("column1").alias("ua")])?
        .build()?;

    // Project `a` alongside `a IN (<subquery>) AS flag`.
    let flag = in_subquery(col("a"), Arc::new(subquery_plan)).alias("flag");
    let plan = LogicalPlanBuilder::from(outer_plan)
        .project(vec![col("a"), flag])?
        .build()?;

    // The optimized plan joins the outer input to the aliased subquery with a
    // LeftMark join and projects the mark column as `flag`.
    let expected = "Projection: a, __correlated_sq_1.mark AS flag [a:Int32;N, flag:Boolean]\n LeftMark Join: Filter: a = __correlated_sq_1.ua [a:Int32;N, mark:Boolean]\n Projection: column1 AS a [a:Int32;N]\n Values: (Int32(1)), (Int32(2)) [column1:Int32;N]\n SubqueryAlias: __correlated_sq_1 [ua:Int32;N]\n Projection: column1 AS ua [ua:Int32;N]\n Values: (Int32(2)) [column1:Int32;N]";

    assert_optimized_plan_equal(plan, expected)
}

/// Test multiple correlated subqueries
/// See subqueries.rs where_in_multiple()
#[test]
Expand Down
29 changes: 29 additions & 0 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1482,3 +1482,32 @@ logical_plan

statement count 0
drop table person;


# IN (subquery) used as a projection expression is decorrelated into a LeftMark join
query IB rowsort
WITH t(a) AS (VALUES (1),(2),(3),(4),(5)),
u(a) AS (VALUES (2),(4),(6))
SELECT t.a, (t.a IN (SELECT u.a FROM u)) AS flag FROM t;
----
1 false
2 true
3 false
4 true
5 false

query TT
EXPLAIN WITH t(a) AS (VALUES (1),(2),(3),(4),(5)),
u(a) AS (VALUES (2),(4),(6))
SELECT t.a, (t.a IN (SELECT u.a FROM u)) AS flag FROM t;
----
logical_plan
01)Projection: t.a, __correlated_sq_1.mark AS flag
02)--LeftMark Join: t.a = __correlated_sq_1.a
03)----SubqueryAlias: t
04)------Projection: column1 AS a
05)--------Values: (Int64(1)), (Int64(2)), (Int64(3)), (Int64(4)), (Int64(5))
06)----SubqueryAlias: __correlated_sq_1
07)------SubqueryAlias: u
08)--------Projection: column1 AS a
09)----------Values: (Int64(2)), (Int64(4)), (Int64(6))
Loading