Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
Expand Down Expand Up @@ -115,6 +116,26 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Returns true if at least one of the given expressions satisfies
 * [[exprContainsAggregateInSubquery]], i.e. holds an IN-subquery whose left-hand side
 * values contain an aggregate expression.
 *
 * @param exprs the expressions to inspect
 * @return true as soon as one expression matches, false if none do (or `exprs` is empty)
 */
def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean =
  exprs.exists(exprContainsAggregateInSubquery)

/**
 * Returns true if `expr` contains an [[InSubquery]] whose left-hand side values contain an
 * [[AggregateExpression]].
 *
 * Only the `values` of the IN-subquery are inspected; its second field (the subquery side)
 * is ignored.
 *
 * @param expr the expression to search
 * @return true if such an IN-subquery is found, false otherwise
 */
def exprContainsAggregateInSubquery(expr: Expression): Boolean = {
  // True when an AggregateExpression is found while searching `e` with `exists`.
  def hasAggregate(e: Expression): Boolean = e.exists(_.isInstanceOf[AggregateExpression])

  expr.exists {
    case InSubquery(lhsValues, _) => lhsValues.exists(hasAggregate)
    case _ => false
  }
}


def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) {
case Filter(condition, child)
Expand Down Expand Up @@ -246,46 +267,106 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}

// Handle the case where the left-hand side of an IN-subquery contains an aggregate.
//
// If an Aggregate node contains such an IN-subquery, this handler will pull up all
// expressions from the Aggregate node into a new Project node. The new Project node
// will then be handled by the Unary node handler.
//
// The Unary node handler uses the left-hand side of the IN-subquery in a
// join condition. Thus, without this pre-transformation, the join condition
// contains an aggregate, which is illegal. With this pre-transformation, the
// join condition contains an attribute from the left-hand side of the
// IN-subquery contained in the Project node.
//
// For example:
//
// SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
// FROM v2;
//
// The above query has this plan on entry to RewritePredicateSubquery#apply:
//
// Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13]
// : +- LocalRelation [c3#28L]
// +- LocalRelation [col2#18, col3#19]
//
// Note that the Aggregate node contains the IN-subquery and the left-hand
// side of the IN-subquery is an aggregate expression (sum(col2#18)).
//
// This handler transforms the above plan into the following:
// scalastyle:off line.size.limit
//
// Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13]
// : +- LocalRelation [c3#28L]
// +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L]
// +- LocalRelation [col2#18, col3#19]
//
// scalastyle:on
// Note that both the IN-subquery and the greater-than expressions have been
// pulled up into the Project node. These expressions use attributes
// (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations
// which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)).
case p @ PhysicalAggregation(
groupingExpressions, aggregateExpressions, resultExpressions, child)
if exprsContainsAggregateInSubquery(p.expressions) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if exprsContainsAggregateInSubquery(p.expressions) =>
if exprsContainsAggregateInSubquery(resultExpressions) =>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rewrite only pulls out subquery expressions for Aggregate#aggregateExpressions, not grouping expressions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re: if exprsContainsAggregateInSubquery(resultExpressions) =>.

That won't work with exprsContainsAggregateInSubquery as it currently stands, since that function looks for in-subqueries with aggregate expressions in the left-hand operand. resultExpressions has the aggregate expressions replaced with attributes, so exprsContainsAggregateInSubquery would never trigger.

Alternatively, I could do

if exprsContainsAggregateInSubquery(p.asInstanceOf[Aggregate].aggregateExpressions) =>

which is kind of ugly, but does the trick.

Another alternative: I'm the only one calling exprsContainsAggregateInSubquery, so I could change it to return true if there are any in-subqueries at all with no regard to characteristics of the left-hand operand. We would end up rewriting some cases that wouldn't otherwise cause trouble.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah OK, let's keep it as it is

val aggExprs = aggregateExpressions.map(
ae => Alias(ae, "_aggregateexpression")(ae.resultId))
val aggExprIds = aggExprs.map(_.exprId).toSet
val resExprs = resultExpressions.map(_.transform {
case a: AttributeReference if aggExprIds.contains(a.exprId) =>
a.withName("_aggregateexpression")
}.asInstanceOf[NamedExpression])
// Rewrite the projection and the aggregate separately and then piece them together.
val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
val newProj = Project(resExprs, newAgg)
handleUnaryNode(newProj)

case u: UnaryNode if u.expressions.exists(
SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
var newChild = u.child
var introducedAttrs = Seq.empty[Attribute]
val updatedNode = u.mapExpressions(expr => {
val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
newChild = p
introducedAttrs ++= newAttrs
// The newExpr can not be None
newExpr.get
}).withNewChildren(Seq(newChild))
updatedNode match {
case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
// If we have introduced new `exists`-attributes that are referenced by
// aggregateExpressions within a non-aggregateFunction expression, we wrap them in
// first() aggregate function. first() is Spark's executable version of any_value()
// aggregate function.
// We do this to keep the aggregation valid, i.e avoid references outside of aggregate
// functions that are not in grouping expressions.
// Note that the same `exists` attr will never appear in groupingExpressions due to
// PullOutGroupingExpressions rule.
// Also note: the value of `exists` is functionally determined by grouping expressions,
// so applying any aggregate function is semantically safe.
val aggFunctionReferences = a.aggregateExpressions.
flatMap(extractAggregateExpressions).
flatMap(_.references).toSet
val nonAggFuncReferences =
a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)

// Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
aggExpr.transformUp {
case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
new First(attr).toAggregateExpression()
}.asInstanceOf[NamedExpression]
}
a.copy(aggregateExpressions = newAggregateExpressions)
case _ => updatedNode
}
SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u)
}

/**
 * Handle the unary node case.
 *
 * Every expression of `u` is passed, one at a time, to `rewriteExistentialExprWithAttrs`
 * together with the plan produced so far. The plan returned by each call becomes the input of
 * the next call and, at the end, the new child of the node; the attributes returned by each
 * call are accumulated.
 *
 * If the rewritten node is an [[Aggregate]] and `WRAP_EXISTS_IN_AGGREGATE_FUNCTION` is enabled,
 * the accumulated attributes that are referenced by the aggregate expressions but by none of
 * their aggregate functions are wrapped in `First(...)`.
 *
 * @param u the unary node whose expressions are rewritten
 * @return the rewritten node, placed on top of the rewritten child plan
 */
private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
  // Both variables are updated from inside the mapExpressions callback: each rewrite starts
  // from the plan left behind by the previous one, so this state is threaded through the calls.
  var rewrittenChild = u.child
  var existsAttrs = Seq.empty[Attribute]

  val rewrittenNode = u.mapExpressions { expr =>
    val (rewrittenExpr, nextChild, addedAttrs) =
      rewriteExistentialExprWithAttrs(Seq(expr), rewrittenChild)
    rewrittenChild = nextChild
    existsAttrs ++= addedAttrs
    // The rewritten expression can not be None
    rewrittenExpr.get
  }.withNewChildren(Seq(rewrittenChild))

  rewrittenNode match {
    case agg: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
      // Newly introduced `exists`-attributes that aggregateExpressions reference outside of any
      // aggregate function get wrapped in first(), Spark's executable version of any_value().
      // This keeps the aggregation valid, i.e. it avoids references outside of aggregate
      // functions that are not in grouping expressions.
      // The same `exists` attr will never appear in groupingExpressions due to the
      // PullOutGroupingExpressions rule, and because the value of `exists` is functionally
      // determined by the grouping expressions, applying any aggregate function is
      // semantically safe.
      val outputExprs = agg.aggregateExpressions

      // Attributes referenced from within some aggregate function.
      val refsInsideAggFuncs =
        outputExprs.flatMap(extractAggregateExpressions).flatMap(_.references).toSet
      // Attributes referenced by the aggregate expressions but by no aggregate function.
      val refsOutsideAggFuncs =
        outputExprs.flatMap(_.references).filterNot(refsInsideAggFuncs.contains)
      val attrsToWrap = existsAttrs.filter(refsOutsideAggFuncs.contains)

      // Replace every eligible `exists` by `First(exists)` among the aggregate expressions.
      val wrappedExprs = outputExprs.map { outputExpr =>
        outputExpr.transformUp {
          case attr: Attribute if attrsToWrap.contains(attr) =>
            new First(attr).toAggregateExpression()
        }.asInstanceOf[NamedExpression]
      }
      agg.copy(aggregateExpressions = wrappedExprs)

    case other => other
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.LongType


class RewriteSubquerySuite extends PlanTest {
Expand Down Expand Up @@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
Optimize.executeAndTrack(query.analyze, tracker)
assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0)
}

test("SPARK-50091: Don't put aggregate expression in join condition") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also updated this test to check the whole optimized plan rather than simply testing that the join condition does not have an aggregate expression.

val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
val optimized = Optimize.execute(plan.analyze)
val aggregate = relation2
.select($"col2")
.groupBy()(sum($"col2").as("_aggregateexpression"))
val correctAnswer = aggregate
.join(relation1.select(Cast($"c3", LongType).as("c3")),
ExistenceJoin($"exists".boolean.withNullability(false)),
Some($"_aggregateexpression" === $"c3"))
.select($"exists".as("(sum(col2) IN (listquery()))")).analyze
comparePlans(optimized, correctAnswer)
}
}
30 changes: 30 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest
checkAnswer(df3, Row(7))
}
}

test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
  withView("v1", "v2") {
    // v1 supplies the values the IN-subqueries select from.
    val v1Rows = Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
    v1Rows.toDF("c1", "c2", "c3").createOrReplaceTempView("v1")

    // v2 is the aggregated side. Per col1 group: SUM(col2) = 5, 2, 8 and SUM(col3) = 5, 4, 8.
    val v2Rows = Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
    v2Rows.toDF("col1", "col2", "col3").createOrReplaceTempView("v2")

    // A single aggregate on the left-hand side: 5 is not among v1.c3, while 2 and 8 are.
    val singleAggregate =
      sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1")
    checkAnswer(singleAggregate, Seq(Row(1, false), Row(2, true), Row(3, true)))

    // Two IN-subqueries combined in one expression: only group 3 satisfies both
    // (8 is among v1.c3 and 8 is among v1.c2).
    val twoSubqueries = sql("""SELECT
        | col1,
        | SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN (SELECT c2 FROM v1) AS x
        |FROM v2 GROUP BY col1
        |ORDER BY col1""".stripMargin)
    checkAnswer(twoSubqueries, Seq(Row(1, false), Row(2, false), Row(3, true)))

    // Multi-column left-hand side: only (8, 8) appears among the (c3, c2) pairs of v1.
    val multiColumn =
      sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1) AS x
        |FROM v2
        |GROUP BY col1
        |ORDER BY col1""".stripMargin)
    checkAnswer(multiColumn, Seq(Row(1, false), Row(2, false), Row(3, true)))
  }
}
}
Loading