From 79b50891a02cc153ee483f9d70b7fea8d0b2fe85 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sat, 19 Oct 2024 13:13:27 -0700 Subject: [PATCH 01/17] Some testing --- .../sql/catalyst/optimizer/subquery.scala | 81 +++++++++++++++++++ .../org/apache/spark/sql/SubquerySuite.scala | 10 +++ 2 files changed, 91 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 5a4e9f37c3951..bc76fac2e5fba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -115,6 +115,26 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } } + def aggregateExprsNestedInSubquery(exprs: Seq[Expression]): Boolean = { + exprs.exists { expr => + aggregateExprContainsNestedInSubquery(expr) + } + } + + def aggregateExprContainsNestedInSubquery(expr: Expression): Boolean = { + expr.exists { + case InSubquery(values, _) => + values.exists { v => + v.exists { + case a: AggregateExpression => true + case _ => false + } + } + case _ => false; + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) { case Filter(condition, child) @@ -245,6 +265,67 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { condition = Some(newCondition))) } } + case a: Aggregate if aggregateExprsNestedInSubquery(a.aggregateExpressions) => + print("Matched on aggregate!!\n") + // extract out the aggregate expressions + + print("Named expressions are ") + print(s"${a.aggregateExpressions.map(_.getClass.getName).mkString(",")}\n") + + // split expressions depending on whether they contain an InSubquery + // with an aggregate expression + val (withInsubquery, withoutInsubquery) = + a.aggregateExpressions.partition(aggregateExprContainsNestedInSubquery(_)) + + // extract the aggregate expressions from withInsubquery + val inSubqueryMapping = withInsubquery.map { e => + val aggregateExpressions = e.collect { + case a: AggregateExpression => a + } + (e, aggregateExpressions) + } + val inSubqueryMap = inSubqueryMapping.toMap + val aggregateExprs = inSubqueryMapping.flatMap(_._2) + val aggregateExprAliases = aggregateExprs.zipWithIndex + .map(a => Alias(a._1, s"__aggregate_alias_${a._2}")()) + val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap + val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute) + val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap + + val newAggregateExpressions = a.aggregateExpressions.flatMap { ae => + // if this is an expression contain insubquery and aggregates, patch + // replace with just the aggregate + if (inSubqueryMap.contains(ae)) { + // replace the expression with the aliased aggregate exprs + inSubqueryMap(ae).map(aggregateExprAliasMap(_)) + } else { + Seq(ae) + } + } + + val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions) + val projList = a.aggregateExpressions.map { ae => + if (inSubqueryMap.contains(ae)) { + ae.transform { + // patch the aggregate expression with an attribute + case a: AggregateExpression if aggregateExprAttrMap.contains(a) => + aggregateExprAttrMap(a) + }.asInstanceOf[NamedExpression] + } else { + ae.toAttribute + } + } + // want to create Aggregate with withoutInsubquery expressions, plus + // aliased aggregate expressions + // Put on top of that a projection with attributes for withoutInsubquery + // expressions plus InSubquery expressions patched with attributes for + // aggregate expressions. Might also need to patch group-by expressions as well, + // not sure. + + print(s"withoutInsubquery ${withoutInsubquery}\n") + print(s"withInsubquery ${withInsubquery}\n") + print(s"inSubqueryMapping is ${inSubqueryMapping}\n") + apply(Project(projList, newAggregate)) case u: UnaryNode if u.expressions.exists( SubqueryExpression.hasInOrCorrelatedExistsSubquery) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 9e97c224736d8..4efb1dcdbe91c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2755,6 +2755,16 @@ class SubquerySuite extends QueryTest } } + test("stuffing") { + withTable("v1", "v2") { + sql("create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2)") + sql("create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2, 2)") + val df = sql("select col1, sum(col2) in (select c2 from v1) from v2 group by col1") + checkAnswer(df, + Row(1, false) :: Row(2, true) :: Nil) + } + } + test("SPARK-45580: Handle case where a nested subquery becomes an existence join") { withTempView("t1", "t2", "t3") { Seq((1), (2), (3), (7)).toDF("a").persist().createOrReplaceTempView("t1") From 7fe2a087642c59e0996370ab2589b038d5851f6d Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 20 Oct 2024 12:38:40 -0700 Subject: [PATCH 02/17] update --- .../sql/catalyst/optimizer/subquery.scala | 38 ++++++++----------- .../org/apache/spark/sql/SubquerySuite.scala | 28 +++++++++++--- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index bc76fac2e5fba..f6df9d0669bd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION, OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION, @@ -267,15 +268,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } case a: Aggregate if aggregateExprsNestedInSubquery(a.aggregateExpressions) => print("Matched on aggregate!!\n") - // extract out the aggregate expressions - print("Named expressions are ") - print(s"${a.aggregateExpressions.map(_.getClass.getName).mkString(",")}\n") - - // split expressions depending on whether they contain an InSubquery - // with an aggregate expression - val (withInsubquery, withoutInsubquery) = - a.aggregateExpressions.partition(aggregateExprContainsNestedInSubquery(_)) + // find expressions with an in-subquery whose values contain aggregates + val withInsubquery = a.aggregateExpressions.filter(aggregateExprContainsNestedInSubquery(_)) + print(s"withInsubquery is ${withInsubquery}\n") // extract the aggregate expressions from withInsubquery val inSubqueryMapping = withInsubquery.map { e => @@ -286,15 +282,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } val inSubqueryMap = inSubqueryMapping.toMap val aggregateExprs = inSubqueryMapping.flatMap(_._2) - val aggregateExprAliases = aggregateExprs.zipWithIndex - .map(a => Alias(a._1, s"__aggregate_alias_${a._2}")()) + val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))()) val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute) val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap + // create Aggregate operator without the in-subqueries that contain aggregates, + // just the aggregates themselves and the other aggregate expressions. val newAggregateExpressions = a.aggregateExpressions.flatMap { ae => - // if this is an expression contain insubquery and aggregates, patch - // replace with just the aggregate + // if this is expression contains an in-subquery with aggregates in values, + // replace with just the aggregates if (inSubqueryMap.contains(ae)) { // replace the expression with the aliased aggregate exprs inSubqueryMap(ae).map(aggregateExprAliasMap(_)) @@ -302,8 +299,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { Seq(ae) } } - val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions) + + // Create a projection with the in-subquery expressions that contain aggregates, replacing + // the aggregates with an attribute references to the output of the Aggregate operator. + // Also include the other output of the Aggregate operator. val projList = a.aggregateExpressions.map { ae => if (inSubqueryMap.contains(ae)) { ae.transform { @@ -315,16 +315,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { ae.toAttribute } } - // want to create Aggregate with withoutInsubquery expressions, plus - // aliased aggregate expressions - // Put on top of that a projection with attributes for withoutInsubquery - // expressions plus InSubquery expressions patched with attributes for - // aggregate expressions. Might also need to patch group-by expressions as well, - // not sure. - - print(s"withoutInsubquery ${withoutInsubquery}\n") - print(s"withInsubquery ${withInsubquery}\n") - print(s"inSubqueryMapping is ${inSubqueryMapping}\n") + + // reapply this rule, now with a Project as parent to the Aggregate apply(Project(projList, newAggregate)) case u: UnaryNode if u.expressions.exists( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 4efb1dcdbe91c..1f46d61de9adb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2757,11 +2757,29 @@ class SubquerySuite extends QueryTest test("stuffing") { withTable("v1", "v2") { - sql("create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2)") - sql("create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2, 2)") - val df = sql("select col1, sum(col2) in (select c2 from v1) from v2 group by col1") - checkAnswer(df, - Row(1, false) :: Row(2, true) :: Nil) + sql("""create or replace temp view v1 (c1, c2, c3) as values + |(1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8)""".stripMargin) + sql("""create or replace temp view v2 (col1, col2, col3) as values + |(1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1)""".stripMargin) + + val df1 = sql("select col1, sum(col2) in (select c3 from v1) from v2 group by col1") + checkAnswer(df1, + Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil) + + val df2 = sql("""select + | col1, + | sum(col2) in (select c3 from v1) and sum(col3) in (select c2 from v1) as x + |from v2 group by col1 + |order by col1""".stripMargin) + checkAnswer(df2, + Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil) + + val df3 = sql("""select col1, (sum(col2), sum(col3)) in (select c3, c2 from v1) as x + |from v2 + |group by col1 + |order by col1;""".stripMargin) + checkAnswer(df3, + Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil) } } From c96af366377a28256ccf6c425527d49464c15700 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 22 Oct 2024 09:39:19 -0700 Subject: [PATCH 03/17] Small cleanup --- .../apache/spark/sql/catalyst/optimizer/subquery.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index f6df9d0669bd1..f0c84d3f1713e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -267,19 +267,14 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } } case a: Aggregate if aggregateExprsNestedInSubquery(a.aggregateExpressions) => - print("Matched on aggregate!!\n") - // find expressions with an in-subquery whose values contain aggregates val withInsubquery = a.aggregateExpressions.filter(aggregateExprContainsNestedInSubquery(_)) - print(s"withInsubquery is ${withInsubquery}\n") // extract the aggregate expressions from withInsubquery val inSubqueryMapping = withInsubquery.map { e => - val aggregateExpressions = e.collect { - case a: AggregateExpression => a - } - (e, aggregateExpressions) + (e, extractAggregateExpressions(e)) } + val inSubqueryMap = inSubqueryMapping.toMap val aggregateExprs = inSubqueryMapping.flatMap(_._2) val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))()) From 424d80392855d7c15522db21aa2a0813f1031a8a Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 22 Oct 2024 16:26:25 -0700 Subject: [PATCH 04/17] Update --- .../apache/spark/sql/catalyst/optimizer/subquery.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index f0c84d3f1713e..8bd57324f3491 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -297,13 +297,15 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions) // Create a projection with the in-subquery expressions that contain aggregates, replacing - // the aggregates with an attribute references to the output of the Aggregate operator. - // Also include the other output of the Aggregate operator. + // the aggregate expressions with attribute references to the output of the Aggregate + // operator. Also include the other output of the Aggregate operator. val projList = a.aggregateExpressions.map { ae => + // if this expression contains an in-subquery that uses an aggregate, we + // need to do something special if (inSubqueryMap.contains(ae)) { ae.transform { - // patch the aggregate expression with an attribute - case a: AggregateExpression if aggregateExprAttrMap.contains(a) => + // patch any aggregate expression with its corresponding attribute + case a: AggregateExpression => aggregateExprAttrMap(a) }.asInstanceOf[NamedExpression] } else { From 2b1a3762d4f27d5dc0b6ec60acb7200459eefa90 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 22 Oct 2024 17:38:33 -0700 Subject: [PATCH 05/17] Add catalyst test --- .../catalyst/optimizer/RewriteSubquerySuite.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 17547bbcb9402..5e10ed7f4ea73 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -21,8 +21,9 @@ import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -79,4 +80,14 @@ class RewriteSubquerySuite extends PlanTest { Optimize.executeAndTrack(query.analyze, tracker) assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0) } + + test("stuffing") { + val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int) + val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int) + val query = relation2.select(sum($"col2").in(ListQuery(relation1.select($"c3")))) + + val optimized = Optimize.execute(query.analyze) + val join = optimized.find(_.isInstanceOf[Join]).get.asInstanceOf[Join] + assert(!join.condition.get.exists(_.isInstanceOf[AggregateExpression])) + } } From 9c443b0eaf40a415396f2acbc16dd3c790a90dce Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 23 Oct 2024 07:55:04 -0700 Subject: [PATCH 06/17] Fix names --- .../spark/sql/catalyst/optimizer/subquery.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 8bd57324f3491..57e7fb01bbc89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -116,18 +116,18 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } } - def aggregateExprsNestedInSubquery(exprs: Seq[Expression]): Boolean = { + def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = { exprs.exists { expr => - aggregateExprContainsNestedInSubquery(expr) + exprContainsAggregateInSubquery(expr) } } - def aggregateExprContainsNestedInSubquery(expr: Expression): Boolean = { + def exprContainsAggregateInSubquery(expr: Expression): Boolean = { expr.exists { case InSubquery(values, _) => values.exists { v => v.exists { - case a: AggregateExpression => true + case _: AggregateExpression => true case _ => false } } @@ -266,9 +266,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { condition = Some(newCondition))) } } - case a: Aggregate if aggregateExprsNestedInSubquery(a.aggregateExpressions) => + case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) => // find expressions with an in-subquery whose values contain aggregates - val withInsubquery = a.aggregateExpressions.filter(aggregateExprContainsNestedInSubquery(_)) + val withInsubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_)) // extract the aggregate expressions from withInsubquery val inSubqueryMapping = withInsubquery.map { e => From 46d43fdd74649fe114d8d9d29487388249a7677c Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 23 Oct 2024 08:35:56 -0700 Subject: [PATCH 07/17] Clean up some comments --- .../sql/catalyst/optimizer/subquery.scala | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 57e7fb01bbc89..c79b99f477813 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -267,28 +267,33 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } } case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) => - // find expressions with an in-subquery whose values contain aggregates - val withInsubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_)) + // find expressions with an IN-subquery whose left-hand operand contains aggregates + val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_)) - // extract the aggregate expressions from withInsubquery - val inSubqueryMapping = withInsubquery.map { e => + // extract the aggregate expressions from withInSubquery + val inSubqueryMapping = withInSubquery.map { e => (e, extractAggregateExpressions(e)) } val inSubqueryMap = inSubqueryMapping.toMap + // get all aggregate expressions found in left-hand operands of IN-subqueries val aggregateExprs = inSubqueryMapping.flatMap(_._2) + // create aliases for each above aggregate expression val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))()) + // create a mapping from each aggregate expression to its alias val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap + // create attributes from those aliases of aggregate expressions val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute) + // create a mapping from aggregate expressions to attributes val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap - // create Aggregate operator without the in-subqueries that contain aggregates, - // just the aggregates themselves and the other aggregate expressions. + // create Aggregate operator without the offending IN-subqueries, just + // the aggregates themselves and all the other aggregate expressions. val newAggregateExpressions = a.aggregateExpressions.flatMap { ae => - // if this is expression contains an in-subquery with aggregates in values, - // replace with just the aggregates + // if this expression contains IN-subqueries with aggregates in the left-hand + // operand, replace with just the aggregates if (inSubqueryMap.contains(ae)) { - // replace the expression with the aliased aggregate exprs + // replace the expression with an aliased aggregate expression inSubqueryMap(ae).map(aggregateExprAliasMap(_)) } else { Seq(ae) @@ -296,11 +301,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions) - // Create a projection with the in-subquery expressions that contain aggregates, replacing + // Create a projection with the IN-subquery expressions that contain aggregates, replacing // the aggregate expressions with attribute references to the output of the Aggregate // operator. Also include the other output of the Aggregate operator. val projList = a.aggregateExpressions.map { ae => - // if this expression contains an in-subquery that uses an aggregate, we + // if this expression contains an IN-subquery that uses an aggregate, we // need to do something special if (inSubqueryMap.contains(ae)) { ae.transform { From ca4dba8723eb26bfaf1a5f36c15da9322c93f640 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 23 Oct 2024 09:42:45 -0700 Subject: [PATCH 08/17] Cleanup --- .../sql/catalyst/optimizer/subquery.scala | 2 +- .../org/apache/spark/sql/SubquerySuite.scala | 56 +++++++++---------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index c79b99f477813..7956c322027dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -287,7 +287,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // create a mapping from aggregate expressions to attributes val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap - // create Aggregate operator without the offending IN-subqueries, just + // create an Aggregate node without the offending IN-subqueries, just // the aggregates themselves and all the other aggregate expressions. val newAggregateExpressions = a.aggregateExpressions.flatMap { ae => // if this expression contains IN-subqueries with aggregates in the left-hand diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 1f46d61de9adb..1f0c06003d1e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2755,34 +2755,6 @@ class SubquerySuite extends QueryTest } } - test("stuffing") { - withTable("v1", "v2") { - sql("""create or replace temp view v1 (c1, c2, c3) as values - |(1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8)""".stripMargin) - sql("""create or replace temp view v2 (col1, col2, col3) as values - |(1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1)""".stripMargin) - - val df1 = sql("select col1, sum(col2) in (select c3 from v1) from v2 group by col1") - checkAnswer(df1, - Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil) - - val df2 = sql("""select - | col1, - | sum(col2) in (select c3 from v1) and sum(col3) in (select c2 from v1) as x - |from v2 group by col1 - |order by col1""".stripMargin) - checkAnswer(df2, - Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil) - - val df3 = sql("""select col1, (sum(col2), sum(col3)) in (select c3, c2 from v1) as x - |from v2 - |group by col1 - |order by col1;""".stripMargin) - checkAnswer(df3, - Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil) - } - } - test("SPARK-45580: Handle case where a nested subquery becomes an existence join") { withTempView("t1", "t2", "t3") { Seq((1), (2), (3), (7)).toDF("a").persist().createOrReplaceTempView("t1") @@ -2828,4 +2800,32 @@ class SubquerySuite extends QueryTest checkAnswer(df3, Row(7)) } } + + test("stuffing") { + withTable("v1", "v2") { + sql("""CREATE OR REPLACE TEMP VIEW v1 (c1, c2, c3) AS VALUES + |(1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8)""".stripMargin) + sql("""CREATE OR REPLACE TEMP VIEW v2 (col1, col2, col3) AS VALUES + |(1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1)""".stripMargin) + + val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1") + checkAnswer(df1, + Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil) + + val df2 = sql("""SELECT + | col1, + | SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN (SELECT c2 FROM v1) AS x + |FROM v2 GROUP BY col1 + |ORDER BY col1""".stripMargin) + checkAnswer(df2, + Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil) + + val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1) AS x + |FROM v2 + |GROUP BY col1 + |ORDER BY col1""".stripMargin) + checkAnswer(df3, + Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil) + } + } } From e0fc82f8baa7a9560417504b6a90e4918266e02e Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 23 Oct 2024 10:01:43 -0700 Subject: [PATCH 09/17] Rename tests --- .../spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala | 2 +- .../src/test/scala/org/apache/spark/sql/SubquerySuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 5e10ed7f4ea73..67e2312f7670e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -81,7 +81,7 @@ class RewriteSubquerySuite extends PlanTest { assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0) } - test("stuffing") { + test("SPARK-50091: Don't put aggregate expression in join condition") { val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int) val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int) val query = relation2.select(sum($"col2").in(ListQuery(relation1.select($"c3")))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 1f0c06003d1e0..6a7276438987a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2801,7 +2801,7 @@ class SubquerySuite extends QueryTest } } - test("stuffing") { + test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") { withTable("v1", "v2") { sql("""CREATE OR REPLACE TEMP VIEW v1 (c1, c2, c3) AS VALUES |(1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8)""".stripMargin) From 3e52a12f97ee49186492c37db1dee1b54e1b776e Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 14 Nov 2024 18:31:10 -0800 Subject: [PATCH 10/17] Update --- .../sql/catalyst/optimizer/subquery.scala | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 7956c322027dd..77a70a5ea2c7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -289,33 +289,28 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // create an Aggregate node without the offending IN-subqueries, just // the aggregates themselves and all the other aggregate expressions. - val newAggregateExpressions = a.aggregateExpressions.flatMap { ae => + val newAggregateExpressions = a.aggregateExpressions.flatMap { // if this expression contains IN-subqueries with aggregates in the left-hand // operand, replace with just the aggregates - if (inSubqueryMap.contains(ae)) { + case ae: Expression if inSubqueryMap.contains(ae) => // replace the expression with an aliased aggregate expression inSubqueryMap(ae).map(aggregateExprAliasMap(_)) - } else { - Seq(ae) - } + case ae @ _ => Seq(ae) } val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions) // Create a projection with the IN-subquery expressions that contain aggregates, replacing // the aggregate expressions with attribute references to the output of the Aggregate // operator. Also include the other output of the Aggregate operator. - val projList = a.aggregateExpressions.map { ae => + val projList = a.aggregateExpressions.map { // if this expression contains an IN-subquery that uses an aggregate, we // need to do something special - if (inSubqueryMap.contains(ae)) { + case ae: Expression if inSubqueryMap.contains(ae) => ae.transform { // patch any aggregate expression with its corresponding attribute - case a: AggregateExpression => - aggregateExprAttrMap(a) + case a: AggregateExpression => aggregateExprAttrMap(a) }.asInstanceOf[NamedExpression] - } else { - ae.toAttribute - } + case ae @ _ => ae.toAttribute } // reapply this rule, now with a Project as parent to the Aggregate From 1db531606ba23732efa69747881487ca3ab39d25 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 24 Nov 2024 13:03:42 -0800 Subject: [PATCH 11/17] Review updates --- .../sql/catalyst/optimizer/subquery.scala | 174 ++++++++++++++++-- 1 file changed, 155 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 77a70a5ea2c7d..6c2da4067f63b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -266,55 +266,191 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { condition = Some(newCondition))) } } + + // Handle the case where the left-hand side of an IN-subquery contains an aggregate. + // + // This handler pulls up any expression containing such an IN-subquery into a new Project + // node and then re-enters RewritePredicateSubquery#apply, where the new Project node + // will be handled by the Unary node handler. The Unary node handler will transform the + // plan into a join. Without this pre-transformation, the Unary node handler would + // create a join with an aggregate expression in the join condition, which is illegal + // (see SPARK-50091). + // + // For example: + // + // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) as x + // FROM v2 GROUP BY col1; + // + // The above query has this plan on entry to RewritePredicateSubquery#apply: + // + // Aggregate [col1#28], [col1#28, sum(col2#29) IN (list#24 []) AS x#25] + // : +- LocalRelation [c2#35L] + // +- LocalRelation [col1#28, col2#29] + // + // Note that the Aggregate node contains the IN-subquery and the left-hand + // side of the IN-subquery is an aggregate expression (sum(col2#28)). + // + // This handler transforms the above plan into the following: + // + // Project [col1#28, sum(col2)#36L IN (list#24 []) AS x#25] + // : +- LocalRelation [c2#35L] + // +- Aggregate [col1#28], [col1#28, sum(col2#29) AS sum(col2)#36L] + // +- LocalRelation [col1#28, col2#29] + // + // The transformation pulled the IN-subquery up into a Project. The left-hand side of the + // In-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation + // which is still performed in the Aggregate node (sum(col2#28) AS sum(col2)#36L). + // + // Note that if the IN-subquery is nested in a larger expression, that entire larger + // expression is pulled up into the Project. For example: + // + // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x + // FROM v2; + // + // The input to RewritePredicateSubquery#apply is the following plan: + // + // Aggregate [(sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29] + // : +- LocalRelation [c3#44L] + // +- LocalRelation [col2#34, col3#35] + // + // This handler transforms the plan into: + // + // Project [(sum(col2)#45L IN (list#28 []) AND (sum(col3)#46L > -1)) AS x#29] + // : +- LocalRelation [c3#44L] + // +- Aggregate [sum(col2#34) AS sum(col2)#45L, sum(col3#35) AS sum(col3)#46L] + // +- LocalRelation [col2#34, col3#35] + // + // Note that the entire AND expression was pulled up into the Project, but the Aggregate + // node continues to perform the aggregations (but without the IN-subquery expression). case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) => - // find expressions with an IN-subquery whose left-hand operand contains aggregates + // Find any interesting expressions from Aggregate.aggregateExpressions. + // + // An interesting expression is one that contains an IN-subquery whose left-hand + // operand contains aggregates. For example: + // + // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) + // FROM v2 GROUP BY col1; + // + // withInSubquery will be a List containing a single Alias expression: + // + // List(sum(col2#12) IN (list#8 []) AS (...)#19) val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_)) - // extract the aggregate expressions from withInSubquery + // Extract the aggregate expressions from each interesting expression. This will include + // any aggregate expressions that were not part of the IN-subquery but were part + // of the larger containing expression. val inSubqueryMapping = withInSubquery.map { e => (e, extractAggregateExpressions(e)) } + // Map each interesting expression to its contained aggregate expressions. + // + // Example #1: + // + // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) + // FROM v2 GROUP BY col1; + // + // inSubqueryMap will have a single entry mapping an Alias expression to a Vector + // with a single aggregate expression: + // + // Map( + // sum(col2#100) IN (list []) AS (...)#107 -> Vector(sum(col2#100)) + // ) + // + // Example #2: + // + // SELECT (SUM(col1), SUM(col2)) IN (SELECT c1, c2 FROM v1) + // FROM v2; + // + // inSubqueryMap will have a single entry mapping an Alias expression to a Vector + // with two aggregate expressions: + // + // Map( + // named_struct(_0, sum(col1#169), _1, sum(col2#170)) IN (list#166 []) AS (...)#179 + // -> Vector(sum(col1#169), sum(col2#170)) + // ) + // + // Example #3: + // + // select SUM(col1) IN (SELECT c1 FROM v1), SUM(col2) IN (SELECT c2 FROM v1) + // FROM v2; + // + // inSubqueryMap will have two entries, each mapping an Alias expression to a Vector + // with a single aggregate expression: + // + // Map( + // sum(col1#193) IN (list#189 []) AS (...)#207 -> Vector(sum(col1#193)), + // sum(col2#194) IN (list#190 []) AS (...)#208 -> Vector(sum(col2#194)) + // ) + // + // Example #5: + // + // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x + // FROM v2; + // + // inSubqueryMap will contain a single AND expression that maps to two aggregate + // expressions, even though only one of those aggregate expressions is used as + // the left-hand operand of the IN-subquery expression. + // + // Map( + // (sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29 + // -> Vector(sum(col2#34), sum(col3#35)) + // ) + // + // The keys of inSubqueryMap will be used to determine which expressions in + // the old Aggregate node are interesting. The values of inSubqueryMap, after + // being wrapped in Alias expressions, will replace their associated interesting + // expressions in a new Aggregate node. val inSubqueryMap = inSubqueryMapping.toMap - // get all aggregate expressions found in left-hand operands of IN-subqueries + + // Get all aggregate expressions associated with interesting expressions. val aggregateExprs = inSubqueryMapping.flatMap(_._2) - // create aliases for each above aggregate expression + // Create aliases for each above aggregate expression. We can't use the aggregate + // expressions directly in the new Aggregate node because Aggregate.aggregateExpressions + // has the type Seq[NamedExpression]. val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))()) - // create a mapping from each aggregate expression to its alias + // Create a mapping from each aggregate expression to its alias. val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap - // create attributes from those aliases of aggregate expressions + // Create attributes from those aliases of aggregate expressions. These attributes + // will be used in the new Project node to refer to the aliased aggregate expressions + // in the new Aggregate node. val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute) - // create a mapping from aggregate expressions to attributes + // Create a mapping from aggregate expressions to attributes. This will be + // used when patching the interesting expressions after they are pulled up + // into the new Project node: aggregate expressions will be replaced by attributes. val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap - // create an Aggregate node without the offending IN-subqueries, just - // the aggregates themselves and all the other aggregate expressions. + // Create an Aggregate node without the interesting expressions, just + // the associated aggregate expressions plus any other group-by or aggregate expressions + // that were not involved in the interesting expressions. val newAggregateExpressions = a.aggregateExpressions.flatMap { - // if this expression contains IN-subqueries with aggregates in the left-hand - // operand, replace with just the aggregates + // If this expression contains IN-subqueries with aggregates in the left-hand + // operand, replace with just the aggregates. case ae: Expression if inSubqueryMap.contains(ae) => - // replace the expression with an aliased aggregate expression + // Replace the expression with an aliased aggregate expression. inSubqueryMap(ae).map(aggregateExprAliasMap(_)) - case ae @ _ => Seq(ae) + case ae => Seq(ae) } val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions) // Create a projection with the IN-subquery expressions that contain aggregates, replacing - // the aggregate expressions with attribute references to the output of the Aggregate + // the aggregate expressions with attribute references to the output of the new Aggregate // operator. Also include the other output of the Aggregate operator. val projList = a.aggregateExpressions.map { - // if this expression contains an IN-subquery that uses an aggregate, we + // If this expression contains an IN-subquery that uses an aggregate, we // need to do something special case ae: Expression if inSubqueryMap.contains(ae) => ae.transform { - // patch any aggregate expression with its corresponding attribute + // Patch any aggregate expression with its corresponding attribute. case a: AggregateExpression => aggregateExprAttrMap(a) }.asInstanceOf[NamedExpression] - case ae @ _ => ae.toAttribute + case ae => ae.toAttribute } + val newProj = Project(projList, newAggregate) - // reapply this rule, now with a Project as parent to the Aggregate - apply(Project(projList, newAggregate)) + // Reapply this rule, but now with all interesting expressions + // from Aggregate.aggregateExpressions pulled up into a Project node. + apply(newProj) case u: UnaryNode if u.expressions.exists( SubqueryExpression.hasInOrCorrelatedExistsSubquery) => From f6aa964fad426fabcb0904607b8b6a4db071302e Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 24 Nov 2024 15:38:35 -0800 Subject: [PATCH 12/17] Comment update --- .../sql/catalyst/optimizer/subquery.scala | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 6c2da4067f63b..9429419fdcfce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -270,11 +270,15 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Handle the case where the left-hand side of an IN-subquery contains an aggregate. // // This handler pulls up any expression containing such an IN-subquery into a new Project - // node and then re-enters RewritePredicateSubquery#apply, where the new Project node - // will be handled by the Unary node handler. The Unary node handler will transform the - // plan into a join. Without this pre-transformation, the Unary node handler would - // create a join with an aggregate expression in the join condition, which is illegal - // (see SPARK-50091). + // node, replacing aggregate expressions with attributes, and then re-enters + // RewritePredicateSubquery#apply, where the new Project node will be handled + // by the Unary node handler. + // + // The Unary node handler uses the left-hand side of the IN-subquery in a + // join condition. Thus, without this pre-transformation, the join condition + // contains an aggregate, which is illegal. With this pre-transformation, the + // join condition contains an attribute from the left-hand side of the + // IN-subquery contained in the Project node. // // For example: // @@ -298,10 +302,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // +- LocalRelation [col1#28, col2#29] // // The transformation pulled the IN-subquery up into a Project. The left-hand side of the - // In-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation - // which is still performed in the Aggregate node (sum(col2#28) AS sum(col2)#36L). + // IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation + // which is still performed in the Aggregate node (sum(col2#28) AS sum(col2)#36L). The Unary + // node handler will use that attribute in the join condition (rather than the aggregate + // expression). // - // Note that if the IN-subquery is nested in a larger expression, that entire larger + // If the IN-subquery is nested in a larger expression, that entire larger // expression is pulled up into the Project. For example: // // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x From cc6384bb6cb7c663f307a984cf5ebc0ce51becf4 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 26 Nov 2024 17:44:30 -0800 Subject: [PATCH 13/17] Address review comments --- .../org/apache/spark/sql/catalyst/optimizer/subquery.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 9429419fdcfce..9c5317084b2b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -292,7 +292,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // +- LocalRelation [col1#28, col2#29] // // Note that the Aggregate node contains the IN-subquery and the left-hand - // side of the IN-subquery is an aggregate expression (sum(col2#28)). + // side of the IN-subquery is an aggregate expression (sum(col2#29)). // // This handler transforms the above plan into the following: // @@ -303,7 +303,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // // The transformation pulled the IN-subquery up into a Project. The left-hand side of the // IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation - // which is still performed in the Aggregate node (sum(col2#28) AS sum(col2)#36L). The Unary + // which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L). The Unary // node handler will use that attribute in the join condition (rather than the aggregate // expression). // From 93d98e777cee56d57638c51359187147e08e48bc Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 3 Jan 2025 17:53:14 -0800 Subject: [PATCH 14/17] Move unary node handler to its own utility method --- .../sql/catalyst/optimizer/subquery.scala | 93 ++++++++++--------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 9c5317084b2b6..6ee11009ba943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -270,9 +270,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Handle the case where the left-hand side of an IN-subquery contains an aggregate. // // This handler pulls up any expression containing such an IN-subquery into a new Project - // node, replacing aggregate expressions with attributes, and then re-enters - // RewritePredicateSubquery#apply, where the new Project node will be handled - // by the Unary node handler. + // node, replacing aggregate expressions with attributes. The new Project node will be + // handled by the Unary node handler. // // The Unary node handler uses the left-hand side of the IN-subquery in a // join condition. Thus, without this pre-transformation, the join condition @@ -454,50 +453,56 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } val newProj = Project(projList, newAggregate) - // Reapply this rule, but now with all interesting expressions + // Call the unary node handler, but now with all interesting expressions // from Aggregate.aggregateExpressions pulled up into a Project node. - apply(newProj) + handleUnaryNode(newProj) case u: UnaryNode if u.expressions.exists( - SubqueryExpression.hasInOrCorrelatedExistsSubquery) => - var newChild = u.child - var introducedAttrs = Seq.empty[Attribute] - val updatedNode = u.mapExpressions(expr => { - val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild) - newChild = p - introducedAttrs ++= newAttrs - // The newExpr can not be None - newExpr.get - }).withNewChildren(Seq(newChild)) - updatedNode match { - case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) => - // If we have introduced new `exists`-attributes that are referenced by - // aggregateExpressions within a non-aggregateFunction expression, we wrap them in - // first() aggregate function. first() is Spark's executable version of any_value() - // aggregate function. - // We do this to keep the aggregation valid, i.e avoid references outside of aggregate - // functions that are not in grouping expressions. - // Note that the same `exists` attr will never appear in groupingExpressions due to - // PullOutGroupingExpressions rule. - // Also note: the value of `exists` is functionally determined by grouping expressions, - // so applying any aggregate function is semantically safe. - val aggFunctionReferences = a.aggregateExpressions. - flatMap(extractAggregateExpressions). - flatMap(_.references).toSet - val nonAggFuncReferences = - a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains) - val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains) - - // Replace all eligible `exists` by `First(exists)` among aggregateExpressions. - val newAggregateExpressions = a.aggregateExpressions.map { aggExpr => - aggExpr.transformUp { - case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) => - new First(attr).toAggregateExpression() - }.asInstanceOf[NamedExpression] - } - a.copy(aggregateExpressions = newAggregateExpressions) - case _ => updatedNode - } + SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u) + } + + /** + * Handle the unary node case + */ + private def handleUnaryNode(u: UnaryNode): LogicalPlan = { + var newChild = u.child + var introducedAttrs = Seq.empty[Attribute] + val updatedNode = u.mapExpressions(expr => { + val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild) + newChild = p + introducedAttrs ++= newAttrs + // The newExpr can not be None + newExpr.get + }).withNewChildren(Seq(newChild)) + updatedNode match { + case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) => + // If we have introduced new `exists`-attributes that are referenced by + // aggregateExpressions within a non-aggregateFunction expression, we wrap them in + // first() aggregate function. first() is Spark's executable version of any_value() + // aggregate function. + // We do this to keep the aggregation valid, i.e avoid references outside of aggregate + // functions that are not in grouping expressions. + // Note that the same `exists` attr will never appear in groupingExpressions due to + // PullOutGroupingExpressions rule. + // Also note: the value of `exists` is functionally determined by grouping expressions, + // so applying any aggregate function is semantically safe. + val aggFunctionReferences = a.aggregateExpressions. + flatMap(extractAggregateExpressions). + flatMap(_.references).toSet + val nonAggFuncReferences = + a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains) + val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains) + + // Replace all eligible `exists` by `First(exists)` among aggregateExpressions. + val newAggregateExpressions = a.aggregateExpressions.map { aggExpr => + aggExpr.transformUp { + case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) => + new First(attr).toAggregateExpression() + }.asInstanceOf[NamedExpression] + } + a.copy(aggregateExpressions = newAggregateExpressions) + case _ => updatedNode + } } /** From cb4066a057655b3a8546ffb5c5f0b98de4685c8a Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 16 Jan 2025 16:03:09 -0800 Subject: [PATCH 15/17] Respond to review comments --- .../sql/catalyst/optimizer/subquery.scala | 202 +++--------------- 1 file changed, 33 insertions(+), 169 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 6ee11009ba943..378081221c8c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -27,11 +27,11 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery +import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY} -import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION, OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION, @@ -269,9 +269,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Handle the case where the left-hand side of an IN-subquery contains an aggregate. // - // This handler pulls up any expression containing such an IN-subquery into a new Project - // node, replacing aggregate expressions with attributes. The new Project node will be - // handled by the Unary node handler. + // If an Aggregate node contains such an IN-subquery, this handler will pull up all + // expressions from the Aggregate node into a new Project node. The new Project node + // will then be handled by the Unary node handler. // // The Unary node handler uses the left-hand side of the IN-subquery in a // join condition. Thus, without this pre-transformation, the join condition @@ -281,180 +281,44 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // // For example: // - // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) as x - // FROM v2 GROUP BY col1; + // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x + // FROM v2; // // The above query has this plan on entry to RewritePredicateSubquery#apply: // - // Aggregate [col1#28], [col1#28, sum(col2#29) IN (list#24 []) AS x#25] - // : +- LocalRelation [c2#35L] - // +- LocalRelation [col1#28, col2#29] + // Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13] + // : +- LocalRelation [c3#28L] + // +- LocalRelation [col2#18, col3#19] // // Note that the Aggregate node contains the IN-subquery and the left-hand - // side of the IN-subquery is an aggregate expression (sum(col2#29)). + // side of the IN-subquery is an aggregate expression sum(col2#18)). // // This handler transforms the above plan into the following: + // scalastyle:off line.size.limit // - // Project [col1#28, sum(col2)#36L IN (list#24 []) AS x#25] - // : +- LocalRelation [c2#35L] - // +- Aggregate [col1#28], [col1#28, sum(col2#29) AS sum(col2)#36L] - // +- LocalRelation [col1#28, col2#29] - // - // The transformation pulled the IN-subquery up into a Project. The left-hand side of the - // IN-subquery is now an attribute (sum(col2)#36L) that refers to the actual aggregation - // which is still performed in the Aggregate node (sum(col2#29) AS sum(col2)#36L). The Unary - // node handler will use that attribute in the join condition (rather than the aggregate - // expression). - // - // If the IN-subquery is nested in a larger expression, that entire larger - // expression is pulled up into the Project. For example: - // - // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x - // FROM v2; - // - // The input to RewritePredicateSubquery#apply is the following plan: - // - // Aggregate [(sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29] - // : +- LocalRelation [c3#44L] - // +- LocalRelation [col2#34, col3#35] + // Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13] + // : +- LocalRelation [c3#28L] + // +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L] + // +- LocalRelation [col2#18, col3#19] // - // This handler transforms the plan into: - // - // Project [(sum(col2)#45L IN (list#28 []) AND (sum(col3)#46L > -1)) AS x#29] - // : +- LocalRelation [c3#44L] - // +- Aggregate [sum(col2#34) AS sum(col2)#45L, sum(col3#35) AS sum(col3)#46L] - // +- LocalRelation [col2#34, col3#35] - // - // Note that the entire AND expression was pulled up into the Project, but the Aggregate - // node continues to perform the aggregations (but without the IN-subquery expression). - case a: Aggregate if exprsContainsAggregateInSubquery(a.aggregateExpressions) => - // Find any interesting expressions from Aggregate.aggregateExpressions. - // - // An interesting expression is one that contains an IN-subquery whose left-hand - // operand contains aggregates. For example: - // - // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) - // FROM v2 GROUP BY col1; - // - // withInSubquery will be a List containing a single Alias expression: - // - // List(sum(col2#12) IN (list#8 []) AS (...)#19) - val withInSubquery = a.aggregateExpressions.filter(exprContainsAggregateInSubquery(_)) - - // Extract the aggregate expressions from each interesting expression. This will include - // any aggregate expressions that were not part of the IN-subquery but were part - // of the larger containing expression. - val inSubqueryMapping = withInSubquery.map { e => - (e, extractAggregateExpressions(e)) - } - - // Map each interesting expression to its contained aggregate expressions. - // - // Example #1: - // - // SELECT col1, SUM(col2) IN (SELECT c2 FROM v1) - // FROM v2 GROUP BY col1; - // - // inSubqueryMap will have a single entry mapping an Alias expression to a Vector - // with a single aggregate expression: - // - // Map( - // sum(col2#100) IN (list []) AS (...)#107 -> Vector(sum(col2#100)) - // ) - // - // Example #2: - // - // SELECT (SUM(col1), SUM(col2)) IN (SELECT c1, c2 FROM v1) - // FROM v2; - // - // inSubqueryMap will have a single entry mapping an Alias expression to a Vector - // with two aggregate expressions: - // - // Map( - // named_struct(_0, sum(col1#169), _1, sum(col2#170)) IN (list#166 []) AS (...)#179 - // -> Vector(sum(col1#169), sum(col2#170)) - // ) - // - // Example #3: - // - // select SUM(col1) IN (SELECT c1 FROM v1), SUM(col2) IN (SELECT c2 FROM v1) - // FROM v2; - // - // inSubqueryMap will have two entries, each mapping an Alias expression to a Vector - // with a single aggregate expression: - // - // Map( - // sum(col1#193) IN (list#189 []) AS (...)#207 -> Vector(sum(col1#193)), - // sum(col2#194) IN (list#190 []) AS (...)#208 -> Vector(sum(col2#194)) - // ) - // - // Example #5: - // - // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x - // FROM v2; - // - // inSubqueryMap will contain a single AND expression that maps to two aggregate - // expressions, even though only one of those aggregate expressions is used as - // the left-hand operand of the IN-subquery expression. - // - // Map( - // (sum(col2#34) IN (list#28 []) AND (sum(col3#35) > -1)) AS x#29 - // -> Vector(sum(col2#34), sum(col3#35)) - // ) - // - // The keys of inSubqueryMap will be used to determine which expressions in - // the old Aggregate node are interesting. The values of inSubqueryMap, after - // being wrapped in Alias expressions, will replace their associated interesting - // expressions in a new Aggregate node. - val inSubqueryMap = inSubqueryMapping.toMap - - // Get all aggregate expressions associated with interesting expressions. - val aggregateExprs = inSubqueryMapping.flatMap(_._2) - // Create aliases for each above aggregate expression. We can't use the aggregate - // expressions directly in the new Aggregate node because Aggregate.aggregateExpressions - // has the type Seq[NamedExpression]. - val aggregateExprAliases = aggregateExprs.map(a => Alias(a, toPrettySQL(a))()) - // Create a mapping from each aggregate expression to its alias. - val aggregateExprAliasMap = aggregateExprs.zip(aggregateExprAliases).toMap - // Create attributes from those aliases of aggregate expressions. These attributes - // will be used in the new Project node to refer to the aliased aggregate expressions - // in the new Aggregate node. - val aggregateExprAttrs = aggregateExprAliases.map(_.toAttribute) - // Create a mapping from aggregate expressions to attributes. This will be - // used when patching the interesting expressions after they are pulled up - // into the new Project node: aggregate expressions will be replaced by attributes. - val aggregateExprAttrMap = aggregateExprs.zip(aggregateExprAttrs).toMap - - // Create an Aggregate node without the interesting expressions, just - // the associated aggregate expressions plus any other group-by or aggregate expressions - // that were not involved in the interesting expressions. - val newAggregateExpressions = a.aggregateExpressions.flatMap { - // If this expression contains IN-subqueries with aggregates in the left-hand - // operand, replace with just the aggregates. - case ae: Expression if inSubqueryMap.contains(ae) => - // Replace the expression with an aliased aggregate expression. - inSubqueryMap(ae).map(aggregateExprAliasMap(_)) - case ae => Seq(ae) - } - val newAggregate = a.copy(aggregateExpressions = newAggregateExpressions) - - // Create a projection with the IN-subquery expressions that contain aggregates, replacing - // the aggregate expressions with attribute references to the output of the new Aggregate - // operator. Also include the other output of the Aggregate operator. - val projList = a.aggregateExpressions.map { - // If this expression contains an IN-subquery that uses an aggregate, we - // need to do something special - case ae: Expression if inSubqueryMap.contains(ae) => - ae.transform { - // Patch any aggregate expression with its corresponding attribute. - case a: AggregateExpression => aggregateExprAttrMap(a) - }.asInstanceOf[NamedExpression] - case ae => ae.toAttribute - } - val newProj = Project(projList, newAggregate) - - // Call the unary node handler, but now with all interesting expressions - // from Aggregate.aggregateExpressions pulled up into a Project node. + // scalastyle:on + // Note that both the IN-subquery and the greater-than expressions have been + // pulled up into the Project node. These expressions use attributes + // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations + // which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)). + case p @ PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child) + if exprsContainsAggregateInSubquery(p.expressions) => + val aggExprs = aggregateExpressions.map( + ae => Alias(ae, "_aggregateexpression")(ae.resultId)) + val aggExprIds = aggExprs.map(_.exprId).toSet + val resExprs = resultExpressions.map(_.transform { + case a: AttributeReference if aggExprIds.contains(a.exprId) => + a.withName("_aggregateexpression") + }.asInstanceOf[NamedExpression]) + // Rewrite the projection and the aggregate separately and then piece them together. + val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child) + val newProj = Project(resExprs, newAgg) handleUnaryNode(newProj) case u: UnaryNode if u.expressions.exists( From b5ee4669754c54faaba1d9639799d59dc92f90ea Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 20 Jan 2025 15:57:09 -0800 Subject: [PATCH 16/17] Make test more explicit --- .../optimizer/RewriteSubquerySuite.scala | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 67e2312f7670e..c45a761353c85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.LongType class RewriteSubquerySuite extends PlanTest { @@ -84,10 +84,16 @@ class RewriteSubquerySuite extends PlanTest { test("SPARK-50091: Don't put aggregate expression in join condition") { val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int) val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int) - val query = relation2.select(sum($"col2").in(ListQuery(relation1.select($"c3")))) - - val optimized = Optimize.execute(query.analyze) - val join = optimized.find(_.isInstanceOf[Join]).get.asInstanceOf[Join] - assert(!join.condition.get.exists(_.isInstanceOf[AggregateExpression])) + val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3")))) + val optimized = Optimize.execute(plan.analyze) + val aggregate = relation2 + .select($"col2") + .groupBy()(sum($"col2").as("_aggregateexpression")) + val correctAnswer = aggregate + .join(relation1.select(Cast($"c3", LongType).as("c3")), + ExistenceJoin($"exists".boolean.withNullability(false)), + Some($"_aggregateexpression" === $"c3")) + .select($"exists".as("(sum(col2) IN (listquery()))")).analyze + comparePlans(optimized, correctAnswer) } } From 0e1c1706cb694368ac213616ab31cc5df68f6948 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 21 Jan 2025 16:55:40 -0800 Subject: [PATCH 17/17] Test updates --- .../scala/org/apache/spark/sql/SubquerySuite.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 6a7276438987a..e7e41f6570d3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2802,11 +2802,13 @@ class SubquerySuite extends QueryTest } test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") { - withTable("v1", "v2") { - sql("""CREATE OR REPLACE TEMP VIEW v1 (c1, c2, c3) AS VALUES - |(1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8)""".stripMargin) - sql("""CREATE OR REPLACE TEMP VIEW v2 (col1, col2, col3) AS VALUES - |(1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1)""".stripMargin) + withView("v1", "v2") { + Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8)) + .toDF("c1", "c2", "c3") + .createOrReplaceTempView("v1") + Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1)) + .toDF("col1", "col2", "col3") + .createOrReplaceTempView("v2") val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1") checkAnswer(df1,