-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-50091][SQL] Handle case of aggregates in left-hand operand of IN-subquery #48627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
79b5089
7fe2a08
c96af36
424d803
2b1a376
9c443b0
46d43fd
ca4dba8
e0fc82f
3e52a12
1db5316
f6aa964
cc6384b
93d98e7
cb4066a
b5ee466
0e1c170
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._ | |||||
| import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ | ||||||
| import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||||||
| import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery | ||||||
| import org.apache.spark.sql.catalyst.planning.PhysicalAggregation | ||||||
| import org.apache.spark.sql.catalyst.plans._ | ||||||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||||||
| import org.apache.spark.sql.catalyst.rules._ | ||||||
|
|
@@ -115,6 +116,26 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { | |||||
| } | ||||||
| } | ||||||
|
|
||||||
| /** | ||||||
| * Returns true when at least one of the given expressions contains an IN-subquery whose | ||||||
| * left-hand side values include an aggregate expression. | ||||||
| * | ||||||
| * @param exprs the expressions to inspect | ||||||
| * @return true if `exprContainsAggregateInSubquery` holds for any element of `exprs` | ||||||
| */ | ||||||
| def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = | ||||||
|   exprs.exists(exprContainsAggregateInSubquery) | ||||||
|
|
||||||
| /** | ||||||
| * Returns true when `expr` contains an [[InSubquery]] node in which any of the left-hand | ||||||
| * side values has an [[AggregateExpression]] somewhere in its expression tree. | ||||||
| * | ||||||
| * @param expr the expression tree to search | ||||||
| * @return true if such an IN-subquery is found anywhere in `expr` | ||||||
| */ | ||||||
| def exprContainsAggregateInSubquery(expr: Expression): Boolean = { | ||||||
|   // True if the given left-hand side value contains an aggregate expression. | ||||||
|   def hasAggregate(value: Expression): Boolean = | ||||||
|     value.collectFirst { case agg: AggregateExpression => agg }.isDefined | ||||||
|   expr.exists { | ||||||
|     case InSubquery(lhsValues, _) => lhsValues.exists(hasAggregate) | ||||||
|     case _ => false | ||||||
|   } | ||||||
| } | ||||||
|
|
||||||
|
|
||||||
| def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( | ||||||
| _.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) { | ||||||
| case Filter(condition, child) | ||||||
|
|
@@ -246,46 +267,106 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { | |||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Handle the case where the left-hand side of an IN-subquery contains an aggregate. | ||||||
| // | ||||||
| // If an Aggregate node contains such an IN-subquery, this handler will pull up all | ||||||
| // expressions from the Aggregate node into a new Project node. The new Project node | ||||||
| // will then be handled by the Unary node handler. | ||||||
| // | ||||||
| // The Unary node handler uses the left-hand side of the IN-subquery in a | ||||||
| // join condition. Thus, without this pre-transformation, the join condition | ||||||
| // contains an aggregate, which is illegal. With this pre-transformation, the | ||||||
| // join condition contains an attribute from the left-hand side of the | ||||||
| // IN-subquery contained in the Project node. | ||||||
| // | ||||||
| // For example: | ||||||
| // | ||||||
| // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x | ||||||
| // FROM v2; | ||||||
| // | ||||||
| // The above query has this plan on entry to RewritePredicateSubquery#apply: | ||||||
| // | ||||||
| // Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13] | ||||||
| // : +- LocalRelation [c3#28L] | ||||||
| // +- LocalRelation [col2#18, col3#19] | ||||||
| // | ||||||
| // Note that the Aggregate node contains the IN-subquery and the left-hand | ||||||
| // side of the IN-subquery is an aggregate expression (sum(col2#18)). | ||||||
| // | ||||||
| // This handler transforms the above plan into the following: | ||||||
| // scalastyle:off line.size.limit | ||||||
| // | ||||||
| // Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13] | ||||||
| // : +- LocalRelation [c3#28L] | ||||||
| // +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L] | ||||||
| // +- LocalRelation [col2#18, col3#19] | ||||||
| // | ||||||
| // scalastyle:on | ||||||
| // Note that both the IN-subquery and the greater-than expressions have been | ||||||
| // pulled up into the Project node. These expressions use attributes | ||||||
| // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations | ||||||
| // which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)). | ||||||
| case p @ PhysicalAggregation( | ||||||
| groupingExpressions, aggregateExpressions, resultExpressions, child) | ||||||
| if exprsContainsAggregateInSubquery(p.expressions) => | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This rewrite only pulls out subquery expressions for [inline code not captured]. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Re: [inline code not captured] That won't work with [inline code not captured]. Alternatively, I could do [inline code not captured], which is kind of ugly, but does the trick. Another alternative: I'm the only one calling [inline code not captured]. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah OK, let's keep it as it is |
||||||
| val aggExprs = aggregateExpressions.map( | ||||||
| ae => Alias(ae, "_aggregateexpression")(ae.resultId)) | ||||||
| val aggExprIds = aggExprs.map(_.exprId).toSet | ||||||
| val resExprs = resultExpressions.map(_.transform { | ||||||
| case a: AttributeReference if aggExprIds.contains(a.exprId) => | ||||||
| a.withName("_aggregateexpression") | ||||||
| }.asInstanceOf[NamedExpression]) | ||||||
| // Rewrite the projection and the aggregate separately and then piece them together. | ||||||
| val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child) | ||||||
| val newProj = Project(resExprs, newAgg) | ||||||
| handleUnaryNode(newProj) | ||||||
|
|
||||||
| case u: UnaryNode if u.expressions.exists( | ||||||
| SubqueryExpression.hasInOrCorrelatedExistsSubquery) => | ||||||
| var newChild = u.child | ||||||
| var introducedAttrs = Seq.empty[Attribute] | ||||||
| val updatedNode = u.mapExpressions(expr => { | ||||||
| val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild) | ||||||
| newChild = p | ||||||
| introducedAttrs ++= newAttrs | ||||||
| // The newExpr can not be None | ||||||
| newExpr.get | ||||||
| }).withNewChildren(Seq(newChild)) | ||||||
| updatedNode match { | ||||||
| case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) => | ||||||
| // If we have introduced new `exists`-attributes that are referenced by | ||||||
| // aggregateExpressions within a non-aggregateFunction expression, we wrap them in | ||||||
| // first() aggregate function. first() is Spark's executable version of any_value() | ||||||
| // aggregate function. | ||||||
| // We do this to keep the aggregation valid, i.e avoid references outside of aggregate | ||||||
| // functions that are not in grouping expressions. | ||||||
| // Note that the same `exists` attr will never appear in groupingExpressions due to | ||||||
| // PullOutGroupingExpressions rule. | ||||||
| // Also note: the value of `exists` is functionally determined by grouping expressions, | ||||||
| // so applying any aggregate function is semantically safe. | ||||||
| val aggFunctionReferences = a.aggregateExpressions. | ||||||
| flatMap(extractAggregateExpressions). | ||||||
| flatMap(_.references).toSet | ||||||
| val nonAggFuncReferences = | ||||||
| a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains) | ||||||
| val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains) | ||||||
|
|
||||||
| // Replace all eligible `exists` by `First(exists)` among aggregateExpressions. | ||||||
| val newAggregateExpressions = a.aggregateExpressions.map { aggExpr => | ||||||
| aggExpr.transformUp { | ||||||
| case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) => | ||||||
| new First(attr).toAggregateExpression() | ||||||
| }.asInstanceOf[NamedExpression] | ||||||
| } | ||||||
| a.copy(aggregateExpressions = newAggregateExpressions) | ||||||
| case _ => updatedNode | ||||||
| } | ||||||
| SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u) | ||||||
| } | ||||||
|
|
||||||
| /** | ||||||
| * Handle the unary node case: rewrite every expression of the unary node `u` with | ||||||
| * `rewriteExistentialExprWithAttrs` and re-attach the resulting child plan. | ||||||
| * | ||||||
| * The expressions are processed one at a time. Each call receives the child plan returned | ||||||
| * by the previous call, so the plan changes made for earlier expressions are visible to | ||||||
| * later ones. The attributes reported by each call are accumulated in `introducedAttrs`. | ||||||
| * | ||||||
| * If the rewritten node is an [[Aggregate]] and the WRAP_EXISTS_IN_AGGREGATE_FUNCTION conf | ||||||
| * is enabled, every introduced attribute that is referenced by the aggregate expressions | ||||||
| * but by none of the aggregate functions extracted from them is replaced by | ||||||
| * `First(attr)`. In all other cases the rewritten node is returned unchanged. | ||||||
| * | ||||||
| * @param u the unary node whose expressions are rewritten | ||||||
| * @return the rewritten node, placed on top of the updated child plan | ||||||
| */ | ||||||
| private def handleUnaryNode(u: UnaryNode): LogicalPlan = { | ||||||
| var newChild = u.child | ||||||
| var introducedAttrs = Seq.empty[Attribute] | ||||||
| val updatedNode = u.mapExpressions(expr => { | ||||||
| val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild) | ||||||
| newChild = p | ||||||
| introducedAttrs ++= newAttrs | ||||||
| // The newExpr can not be None | ||||||
| newExpr.get | ||||||
| }).withNewChildren(Seq(newChild)) | ||||||
| updatedNode match { | ||||||
| case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) => | ||||||
| // If we have introduced new `exists`-attributes that are referenced by | ||||||
| // aggregateExpressions within a non-aggregateFunction expression, we wrap them in | ||||||
| // first() aggregate function. first() is Spark's executable version of any_value() | ||||||
| // aggregate function. | ||||||
| // We do this to keep the aggregation valid, i.e avoid references outside of aggregate | ||||||
| // functions that are not in grouping expressions. | ||||||
| // Note that the same `exists` attr will never appear in groupingExpressions due to | ||||||
| // PullOutGroupingExpressions rule. | ||||||
| // Also note: the value of `exists` is functionally determined by grouping expressions, | ||||||
| // so applying any aggregate function is semantically safe. | ||||||
| val aggFunctionReferences = a.aggregateExpressions. | ||||||
| flatMap(extractAggregateExpressions). | ||||||
| flatMap(_.references).toSet | ||||||
| val nonAggFuncReferences = | ||||||
| a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains) | ||||||
| val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains) | ||||||
|
|
||||||||
| // Replace all eligible `exists` by `First(exists)` among aggregateExpressions. | ||||||
| val newAggregateExpressions = a.aggregateExpressions.map { aggExpr => | ||||||
| aggExpr.transformUp { | ||||||
| case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) => | ||||||
| new First(attr).toAggregateExpression() | ||||||
| }.asInstanceOf[NamedExpression] | ||||||
| } | ||||||
| a.copy(aggregateExpressions = newAggregateExpressions) | ||||||
| case _ => updatedNode | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| /** | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer | |
| import org.apache.spark.sql.catalyst.QueryPlanningTracker | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not} | ||
| import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not} | ||
| import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest} | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} | ||
| import org.apache.spark.sql.catalyst.rules.RuleExecutor | ||
| import org.apache.spark.sql.types.LongType | ||
|
|
||
|
|
||
| class RewriteSubquerySuite extends PlanTest { | ||
|
|
@@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest { | |
| Optimize.executeAndTrack(query.analyze, tracker) | ||
| assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0) | ||
| } | ||
|
|
||
| test("SPARK-50091: Don't put aggregate expression in join condition") { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also updated this test to check the whole optimized plan rather than simply testing that the join condition does not have an aggregate expression. |
||
| val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int) | ||
| val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int) | ||
| val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3")))) | ||
| val optimized = Optimize.execute(plan.analyze) | ||
| val aggregate = relation2 | ||
| .select($"col2") | ||
| .groupBy()(sum($"col2").as("_aggregateexpression")) | ||
| val correctAnswer = aggregate | ||
| .join(relation1.select(Cast($"c3", LongType).as("c3")), | ||
| ExistenceJoin($"exists".boolean.withNullability(false)), | ||
| Some($"_aggregateexpression" === $"c3")) | ||
| .select($"exists".as("(sum(col2) IN (listquery()))")).analyze | ||
| comparePlans(optimized, correctAnswer) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.