-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-33122][SQL] Remove redundant aggregates in the Optimizer #30018
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
84ba723
14f3033
29701dc
4ce0644
ef64abf
ca974c7
4bf08bb
832ff02
6cdc43a
2d63bb4
a82699e
38d7007
fab0427
67861f9
12d1bf4
e396ce3
6d68718
0d86060
33d6072
3e3168a
57af005
f50048f
9219bca
c21dd52
b415194
797d48f
2b32f4a
e202987
37dc4b1
07e758d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.analysis | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
|
|
||
| /** | ||
 | * Pulls out nondeterministic expressions from a LogicalPlan which is not a Project or Filter, | ||
 | * puts them into an inner Project and finally projects them away at the outer Project. | ||
| */ | ||
| object PullOutNondeterministic extends Rule[LogicalPlan] { | ||
| override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp applyLocally | ||
|
|
||
| val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { | ||
| case p if !p.resolved => p // Skip unresolved nodes. | ||
| case p: Project => p | ||
| case f: Filter => f | ||
|
|
||
| case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) => | ||
| val nondeterToAttr = getNondeterToAttr(a.groupingExpressions) | ||
| val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child) | ||
| a.transformExpressions { case e => | ||
| nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) | ||
| }.copy(child = newChild) | ||
|
|
||
| // Don't touch collect metrics. Top-level metrics are not supported (check analysis will fail) | ||
| // and we want to retain them inside the aggregate functions. | ||
| case m: CollectMetrics => m | ||
|
|
||
| // todo: It's hard to write a general rule to pull out nondeterministic expressions | ||
| // from LogicalPlan, currently we only do it for UnaryNode which has same output | ||
| // schema with its child. | ||
| case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => | ||
| val nondeterToAttr = getNondeterToAttr(p.expressions) | ||
| val newPlan = p.transformExpressions { case e => | ||
| nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e) | ||
| } | ||
| val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child) | ||
| Project(p.output, newPlan.withNewChildren(newChild :: Nil)) | ||
| } | ||
|
|
||
| private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = { | ||
| exprs.filterNot(_.deterministic).flatMap { expr => | ||
| val leafNondeterministic = expr.collect { | ||
| case n: Nondeterministic => n | ||
| case udf: UserDefinedExpression if !udf.deterministic => udf | ||
| } | ||
| leafNondeterministic.distinct.map { e => | ||
| val ne = e match { | ||
| case n: NamedExpression => n | ||
| case _ => Alias(e, "_nondeterministic")() | ||
| } | ||
| e -> ne | ||
| } | ||
| }.toMap | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -111,6 +111,7 @@ abstract class Optimizer(catalogManager: CatalogManager) | |
| RewriteCorrelatedScalarSubquery, | ||
| EliminateSerialization, | ||
| RemoveRedundantAliases, | ||
| RemoveRedundantAggregates, | ||
| UnwrapCastInBinaryComparison, | ||
| RemoveNoopOperators, | ||
| OptimizeUpdateFields, | ||
|
|
@@ -488,6 +489,50 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { | |
| def apply(plan: LogicalPlan): LogicalPlan = removeRedundantAliases(plan, AttributeSet.empty) | ||
| } | ||
|
|
||
| /** | ||
| * Remove redundant aggregates from a query plan. A redundant aggregate is an aggregate whose | ||
| * only goal is to keep distinct values, while its parent aggregate would ignore duplicate values. | ||
| */ | ||
| object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you move this optimizer into a new file please, @tanelk ? |
||
| def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { | ||
tanelk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) => | ||
| val aliasMap = getAliasMap(lower) | ||
|
|
||
| val newAggregate = upper.copy( | ||
| child = lower.child, | ||
| groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)), | ||
| aggregateExpressions = upper.aggregateExpressions.map( | ||
| replaceAliasButKeepName(_, aliasMap)) | ||
| ) | ||
|
|
||
 | // We might have introduced non-deterministic grouping expressions | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| if (newAggregate.groupingExpressions.exists(!_.deterministic)) { | ||
| PullOutNondeterministic.applyLocally.applyOrElse(newAggregate, identity[LogicalPlan]) | ||
| } else { | ||
| newAggregate | ||
| } | ||
| } | ||
|
|
||
| private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = { | ||
tanelk marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit. Usually, |
||
| val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate) | ||
|
|
||
| lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet( | ||
| lower | ||
| .aggregateExpressions | ||
| .filter(_.deterministic) | ||
| .filter(!isAggregate(_)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. - .filter(!isAggregate(_))
+ .filterNot(isAggregate) |
||
| .map(_.toAttribute) | ||
| )) | ||
|
|
||
| upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg | ||
| } | ||
|
|
||
| private def isAggregate(expr: Expression): Boolean = { | ||
| expr.find(e => e.isInstanceOf[AggregateExpression] || | ||
| PythonUDF.isGroupedAggPandasUDF(e)).isDefined | ||
tanelk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Remove no-op operators from the query plan that do not make any modifications. | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -621,7 +621,7 @@ case class Range( | |
| * | ||
| * @param groupingExpressions expressions for grouping keys | ||
| * @param aggregateExpressions expressions for a project list, which could contain | ||
| * [[AggregateFunction]]s. | ||
| * [[AggregateExpression]]s. | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This caused some confusion while making this PR |
||
| * | ||
| * Note: Currently, aggregateExpressions is the project list of this Group by operator. Before | ||
| * separating projection from grouping and aggregate, we should avoid expression-level optimization | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.optimizer | ||
|
|
||
| import org.apache.spark.api.python.PythonEvalType | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF} | ||
| import org.apache.spark.sql.catalyst.plans.PlanTest | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} | ||
| import org.apache.spark.sql.catalyst.rules.RuleExecutor | ||
| import org.apache.spark.sql.types.IntegerType | ||
|
|
||
| class RemoveRedundantAggregatesSuite extends PlanTest { | ||
|
|
||
| object Optimize extends RuleExecutor[LogicalPlan] { | ||
| val batches = Batch("RemoveRedundantAggregates", FixedPoint(10), | ||
| RemoveRedundantAggregates) :: Nil | ||
| } | ||
|
|
||
| private def aggregates(e: Expression): Seq[Expression] = { | ||
| Seq( | ||
| count(e), | ||
| PythonUDF("pyUDF", null, IntegerType, Seq(e), | ||
| PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, udfDeterministic = true) | ||
| ) | ||
| } | ||
|
|
||
| test("Remove redundant aggregate") { | ||
| val relation = LocalRelation('a.int, 'b.int) | ||
| for (agg <- aggregates('b)) { | ||
| val query = relation | ||
| .groupBy('a)('a, agg) | ||
| .groupBy('a)('a) | ||
| .analyze | ||
| val expected = relation | ||
| .groupBy('a)('a) | ||
| .analyze | ||
| val optimized = Optimize.execute(query) | ||
| comparePlans(optimized, expected) | ||
| } | ||
| } | ||
|
|
||
| test("Remove 2 redundant aggregates") { | ||
| val relation = LocalRelation('a.int, 'b.int) | ||
| for (agg <- aggregates('b)) { | ||
| val query = relation | ||
| .groupBy('a)('a, agg) | ||
| .groupBy('a)('a) | ||
| .groupBy('a)('a) | ||
| .analyze | ||
| val expected = relation | ||
| .groupBy('a)('a) | ||
| .analyze | ||
| val optimized = Optimize.execute(query) | ||
| comparePlans(optimized, expected) | ||
| } | ||
| } | ||
|
|
||
| test("Remove redundant aggregate with different grouping") { | ||
| val relation = LocalRelation('a.int, 'b.int) | ||
| val query = relation | ||
| .groupBy('a, 'b)('a) | ||
| .groupBy('a)('a) | ||
| .analyze | ||
| val expected = relation | ||
| .groupBy('a)('a) | ||
| .analyze | ||
| val optimized = Optimize.execute(query) | ||
| comparePlans(optimized, expected) | ||
| } | ||
|
|
||
| test("Remove redundant aggregate with aliases") { | ||
| val relation = LocalRelation('a.int, 'b.int) | ||
| for (agg <- aggregates('b)) { | ||
| val query = relation | ||
| .groupBy('a + 'b)(('a + 'b) as 'c, agg) | ||
| .groupBy('c)('c) | ||
| .analyze | ||
| val expected = relation | ||
| .groupBy('a + 'b)(('a + 'b) as 'c) | ||
| .analyze | ||
| val optimized = Optimize.execute(query) | ||
| comparePlans(optimized, expected) | ||
| } | ||
| } | ||
|
|
||
| test("Remove redundant aggregate with non-deterministic upper") { | ||
| val relation = LocalRelation('a.int, 'b.int) | ||
| val query = relation | ||
| .groupBy('a)('a) | ||
| .groupBy('a)('a, rand(0) as 'c) | ||
| .analyze | ||
| val expected = relation | ||
| .groupBy('a)('a, rand(0) as 'c) | ||
| .analyze | ||
| val optimized = Optimize.execute(query) | ||
| comparePlans(optimized, expected) | ||
| } | ||
|
|
||
| test("Remove redundant aggregate with non-deterministic lower") { | ||
| val relation = LocalRelation('a.int, 'b.int) | ||
| val query = relation | ||
| .groupBy('a, 'c)('a, rand(0) as 'c) | ||
| .groupBy('a, 'c)('a, 'c) | ||
| .analyze | ||
| val expected = relation | ||
| .groupBy('a, 'c)('a, rand(0) as 'c) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, shouldn't this
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you, good catch. I fixed this test, added another to check this case and added an extra condition to the optimizer.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM now, conflicts need to be resolved though. |
||
| .analyze | ||
| val optimized = Optimize.execute(query) | ||
| comparePlans(optimized, expected) | ||
| } | ||
|
|
||
| test("Keep non-redundant aggregate - upper has agg expression") { | ||
| val relation = LocalRelation('a.int, 'b.int) | ||
| for (agg <- aggregates('b)) { | ||
| val query = relation | ||
| .groupBy('a, 'b)('a, 'b) | ||
| // The count would change if we remove the first aggregate | ||
| .groupBy('a)('a, agg) | ||
| .analyze | ||
| val optimized = Optimize.execute(query) | ||
| comparePlans(optimized, query) | ||
| } | ||
| } | ||
|
|
||
| test("Keep non-redundant aggregate - upper references agg expression") { | ||
| val relation = LocalRelation('a.int, 'b.int) | ||
| for (agg <- aggregates('b)) { | ||
| val query = relation | ||
| .groupBy('a)('a, agg as 'c) | ||
| .groupBy('c)('c) | ||
| .analyze | ||
| val optimized = Optimize.execute(query) | ||
| comparePlans(optimized, query) | ||
| } | ||
| } | ||
|
|
||
| test("Keep non-redundant aggregate - upper references non-deterministic non-grouping") { | ||
| val relation = LocalRelation('a.int, 'b.int) | ||
| val query = relation | ||
| .groupBy('a)('a, ('a + rand(0)) as 'c) | ||
| .groupBy('a, 'c)('a, 'c) | ||
| .analyze | ||
| val optimized = Optimize.execute(query) | ||
| comparePlans(optimized, query) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just moved it outside of Analyzer, so it would be accessible from outside