[SPARK-30298][SQL] Respect aliases in output partitioning of projects and aggregates #26943
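For context, a minimal sketch of the problem this PR addresses (the snippet below is illustrative only, not taken from the PR description; names and values are made up). A project or aggregate that merely renames a column does not change how its rows are physically partitioned, but before this change the operator kept reporting the child's `outputPartitioning` in terms of the old attribute, so `EnsureRequirements` could not see that a downstream join was already co-partitioned and inserted extra `ShuffleExchangeExec` nodes:

```scala
// Illustrative spark-shell sketch (assumed setup, not from the PR).
import spark.implicits._

// Disable broadcast joins so the join actually requires hash partitioning.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

// Both inputs are already hash-partitioned by their join keys.
val df1 = spark.range(10).selectExpr("id AS key", "0 AS v").repartition($"key")
val df2 = spark.range(20).selectExpr("id AS key2", "0 AS v").repartition($"key2")

// The projects only rename the keys; the physical layout is unchanged ...
val t1 = df1.select($"key".as("k1"))
val t2 = df2.select($"key2".as("k2"))

// ... yet without alias-aware output partitioning the planner added two more
// shuffles for the join. With this change the plan keeps just the two
// repartition exchanges.
t1.join(t2, $"k1" === $"k2").explain()
```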
New file: AliasAwareOutputPartitioning.scala
@@ -0,0 +1,55 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}

/**
 * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning`
 * that satisfies output distribution requirements.
 */
trait AliasAwareOutputPartitioning extends UnaryExecNode {
  protected def outputExpressions: Seq[NamedExpression]

  final override def outputPartitioning: Partitioning = {
    if (hasAlias) {
      child.outputPartitioning match {
        case h: HashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
        case other => other
      }
    } else {
      child.outputPartitioning
    }
  }

  private def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined

  private def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
    exprs.map {
      case a: AttributeReference => replaceAlias(a).getOrElse(a)
      case other => other
Member: How about the other partitioning cases, e.g., range?

Contributor (Author): For …
    }
  }

  private def replaceAlias(attr: AttributeReference): Option[Attribute] = {
    outputExpressions.collectFirst {
      case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
        a.toAttribute
    }
  }
}
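The trait only does something once it is mixed into the physical operators that introduce aliases. Those changes (to `ProjectExec` and the aggregate exec nodes, per the PR title and the tests below) are not part of this excerpt, so the following is just a sketch of the wiring on a made-up pass-through operator, not code from the PR:

```scala
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression}

// Hypothetical rename-only operator, used purely to illustrate the mix-in.
case class RenameOnlyProjectExec(
    projectList: Seq[NamedExpression],
    child: SparkPlan)
  extends UnaryExecNode with AliasAwareOutputPartitioning {

  // The aliases in projectList (e.g. `key AS k`) are what outputPartitioning
  // uses to rewrite the child's HashPartitioning into the new attribute names.
  override protected def outputExpressions: Seq[NamedExpression] = projectList

  override def output: Seq[Attribute] = projectList.map(_.toAttribute)

  // Renaming attributes does not change row contents, so rows pass through
  // untouched; a real project would evaluate the expressions.
  override protected def doExecute(): RDD[InternalRow] = child.execute()
}
```

The real operators would presumably just point `outputExpressions` at the expression lists they already carry (a project list, or aggregate result expressions).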
PlannerSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range, Repartition, Sort, Union}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -937,6 +938,93 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
      }
    }
  }

  test("aliases in the project should not introduce extra shuffle") {
    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
      withTempView("df1", "df2") {
        spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
        spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
        val planned = sql(
          """
            |SELECT * FROM
            |  (SELECT key AS k from df1) t1
            |INNER JOIN
            |  (SELECT key AS k from df2) t2
            |ON t1.k = t2.k
          """.stripMargin).queryExecution.executedPlan
        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
Contributor: I was confused about why only one shuffle, then realized it's exchange reuse. Can we join different data frames? e.g.

Contributor (Author): Thanks for pointing that out. I updated it and it now generates two `ShuffleExchangeExec`s.
        assert(exchanges.size == 2)
      }
    }
  }

  test("aliases to expressions should not be replaced") {
    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
      withTempView("df1", "df2") {
        spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
        spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
        val planned = sql(
          """
            |SELECT * FROM
            |  (SELECT key + 1 AS k1 from df1) t1
            |INNER JOIN
            |  (SELECT key + 1 AS k2 from df2) t2
            |ON t1.k1 = t2.k2
            |""".stripMargin).queryExecution.executedPlan
        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
Contributor: ditto
        // Make sure aliases to an expression (key + 1) are not replaced.
        Seq("k1", "k2").foreach { alias =>
          assert(exchanges.exists(_.outputPartitioning match {
            case HashPartitioning(Seq(a: AttributeReference), _) => a.name == alias
            case _ => false
          }))
        }
      }
    }
  }

  test("aliases in the aggregate expressions should not introduce extra shuffle") {
    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
      val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
      val t2 = spark.range(20).selectExpr("floor(id/4) as k2")

      val agg1 = t1.groupBy("k1").agg(count(lit("1")).as("cnt1"))
      val agg2 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumnRenamed("k2", "k3")

      val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan

      assert(planned.collect { case h: HashAggregateExec => h }.nonEmpty)

      val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
      assert(exchanges.size == 2)
    }
  }

  test("aliases in the object hash/sort aggregate expressions should not introduce extra shuffle") {
    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
      Seq(true, false).foreach { useObjectHashAgg =>
        withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> useObjectHashAgg.toString) {
          val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
          val t2 = spark.range(20).selectExpr("floor(id/4) as k2")

          val agg1 = t1.groupBy("k1").agg(collect_list("k1"))
          val agg2 = t2.groupBy("k2").agg(collect_list("k2")).withColumnRenamed("k2", "k3")

          val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan

          if (useObjectHashAgg) {
            assert(planned.collect { case o: ObjectHashAggregateExec => o }.nonEmpty)
          } else {
            assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
          }

          val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
          assert(exchanges.size == 2)
        }
      }
    }
  }
  }

  // Used for unit-testing EnsureRequirements
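Not part of the PR, but as a quick manual sanity check mirroring the aggregate test above, something like the following could be run in a spark-shell built with this patch (the session setup is assumed; the expected count is just what the test asserts):

```scala
// Assumed spark-shell session; mirrors the "aliases in the aggregate expressions" test.
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.functions._
import spark.implicits._

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

val agg1 = spark.range(10).selectExpr("floor(id/4) as k1")
  .groupBy("k1").agg(count(lit("1")).as("cnt1"))
val agg2 = spark.range(20).selectExpr("floor(id/4) as k2")
  .groupBy("k2").agg(count(lit("1")).as("cnt2"))
  .withColumnRenamed("k2", "k3")

val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
val numShuffles = planned.collect { case s: ShuffleExchangeExec => s }.size
// With this change the test expects 2 (one per aggregation, none for the join);
// without it, the rename of k2 to k3 hides the existing partitioning and the
// planner adds at least one more shuffle.
println(s"ShuffleExchangeExec count: $numShuffles")
```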