diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala
new file mode 100644
index 0000000000000..2c7faea019322
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+
+/**
+ * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning`
+ * that satisfies output distribution requirements.
+ */
+trait AliasAwareOutputPartitioning extends UnaryExecNode {
+  protected def outputExpressions: Seq[NamedExpression]
+
+  final override def outputPartitioning: Partitioning = {
+    if (hasAlias) {
+      child.outputPartitioning match {
+        case h: HashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
+        case other => other
+      }
+    } else {
+      child.outputPartitioning
+    }
+  }
+
+  private def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
+
+  private def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
+    exprs.map {
+      case a: AttributeReference => replaceAlias(a).getOrElse(a)
+      case other => other
+    }
+  }
+
+  private def replaceAlias(attr: AttributeReference): Option[Attribute] = {
+    outputExpressions.collectFirst {
+      case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
+        a.toAttribute
+    }
+  }
+}
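For readers skimming the patch, the effect of the new trait is easiest to see from the shell. The sketch below is illustrative only and is not part of the change; it assumes a running `SparkSession` named `spark` (as in `spark-shell`):

```scala
// Minimal sketch: a projection that merely renames a hash-partitioned column
// should keep reporting HashPartitioning, now expressed over the alias.
import org.apache.spark.sql.functions.col

val df = spark.range(10)
  .selectExpr("id AS key")
  .repartition(col("key"))      // child reports HashPartitioning(key)
  .select(col("key").as("k"))   // ProjectExec whose projectList contains Alias(key, "k")

// With AliasAwareOutputPartitioning mixed into ProjectExec, this reports a
// HashPartitioning over `k` instead of the child's partitioning over `key`,
// an attribute that is no longer part of the project's output.
df.queryExecution.executedPlan.outputPartitioning
```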
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 7f19d2754673d..f73e214a6b41f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -53,7 +53,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with BlockingOperatorWithCodegen {
+  extends UnaryExecNode with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning {

   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -75,7 +75,7 @@ case class HashAggregateExec(

   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

   override def producedAttributes: AttributeSet =
     AttributeSet(aggregateAttributes) ++
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 953622afebf89..4376f6b6edd57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -67,7 +67,7 @@ case class ObjectHashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {

   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -97,7 +97,7 @@ case class ObjectHashAggregateExec(
     }
   }

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 0ddf95771d5b2..b6e684e62ea5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics

 /**
@@ -38,7 +38,7 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {

   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -66,7 +66,7 @@ case class SortAggregateExec(
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

   override def outputOrdering: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
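The three aggregate operators above only need to point the trait at their `resultExpressions`. The payoff shows up in plans like those exercised by the new PlannerSuite tests further down: an aggregate whose grouping column is renamed before a join. The following is a hedged, illustrative sketch, not part of the patch; it assumes a running `SparkSession` named `spark`, and the column names are made up:

```scala
import org.apache.spark.sql.functions.{col, count, lit}

// Each aggregate is hash-partitioned by its grouping key after its own shuffle.
val agg1 = spark.range(10).selectExpr("floor(id / 4) AS k1")
  .groupBy("k1").agg(count(lit(1)).as("cnt1"))
val agg2 = spark.range(20).selectExpr("floor(id / 4) AS k2")
  .groupBy("k2").agg(count(lit(1)).as("cnt2"))
  .withColumnRenamed("k2", "k3")   // the rename typically collapses into the aggregate's result expressions

// Because the aggregate now rewrites its HashPartitioning through the alias,
// the join on k1 = k3 can reuse the two existing group-by exchanges; before
// this change the renamed side still reported partitioning over k2 and was
// shuffled again to satisfy the join's requirement on k3.
val plan = agg1.join(agg2, col("k1") === col("k3")).queryExecution.executedPlan
```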
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index e128d59dca6ba..02c5571e60ea8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -37,7 +37,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

 /** Physical plan for Project. */
 case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {

   override def output: Seq[Attribute] = projectList.map(_.toAttribute)

@@ -80,7 +80,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)

   override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = projectList

   override def verboseStringWithOperatorId(): String = {
     s"""
@@ -91,7 +91,6 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 }

-
 /** Physical plan for Filter. */
 case class FilterExec(condition: Expression, child: SparkPlan)
   extends UnaryExecNode with CodegenSupport with PredicateHelper {
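A caveat that the first two new PlannerSuite tests below pin down: `replaceAlias` only matches `Alias(AttributeReference, _)`, so a plain rename such as `key AS k` is rewritten, while an alias over a computed expression such as `key + 1 AS k1` is not. A hedged shell sketch of the two shapes, assuming the `df1` temp view created exactly as in the tests (not part of the patch):

```scala
// Plain rename: ProjectExec reports HashPartitioning over the alias `k`,
// so a downstream requirement on `k` is already satisfied.
spark.sql("SELECT key AS k FROM df1")
  .queryExecution.executedPlan.outputPartitioning

// Alias over an expression: the partitioning is still expressed over `key`,
// so a downstream requirement on `k1` forces EnsureRequirements to add a
// ShuffleExchangeExec hash-partitioned by `k1`.
spark.sql("SELECT key + 1 AS k1 FROM df1")
  .queryExecution.executedPlan.outputPartitioning
```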
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index a10db54855c8a..94ce3559bb44b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -937,6 +938,93 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("aliases in the project should not introduce extra shuffle") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("df1", "df2") {
+        spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
+        spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
+        val planned = sql(
+          """
+            |SELECT * FROM
+            |  (SELECT key AS k from df1) t1
+            |INNER JOIN
+            |  (SELECT key AS k from df2) t2
+            |ON t1.k = t2.k
+          """.stripMargin).queryExecution.executedPlan
+        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+        assert(exchanges.size == 2)
+      }
+    }
+  }
+
+  test("aliases to expressions should not be replaced") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("df1", "df2") {
+        spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
+        spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
+        val planned = sql(
+          """
+            |SELECT * FROM
+            |  (SELECT key + 1 AS k1 from df1) t1
+            |INNER JOIN
+            |  (SELECT key + 1 AS k2 from df2) t2
+            |ON t1.k1 = t2.k2
+            |""".stripMargin).queryExecution.executedPlan
+        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+
+        // Make sure aliases to an expression (key + 1) are not replaced.
+        Seq("k1", "k2").foreach { alias =>
+          assert(exchanges.exists(_.outputPartitioning match {
+            case HashPartitioning(Seq(a: AttributeReference), _) => a.name == alias
+            case _ => false
+          }))
+        }
+      }
+    }
+  }
+
+  test("aliases in the aggregate expressions should not introduce extra shuffle") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
+      val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
+
+      val agg1 = t1.groupBy("k1").agg(count(lit("1")).as("cnt1"))
+      val agg2 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumnRenamed("k2", "k3")
+
+      val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
+
+      assert(planned.collect { case h: HashAggregateExec => h }.nonEmpty)
+
+      val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+      assert(exchanges.size == 2)
+    }
+  }
+
+  test("aliases in the object hash/sort aggregate expressions should not introduce extra shuffle") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      Seq(true, false).foreach { useObjectHashAgg =>
+        withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> useObjectHashAgg.toString) {
+          val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
+          val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
+
+          val agg1 = t1.groupBy("k1").agg(collect_list("k1"))
+          val agg2 = t2.groupBy("k2").agg(collect_list("k2")).withColumnRenamed("k2", "k3")
+
+          val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
+
+          if (useObjectHashAgg) {
+            assert(planned.collect { case o: ObjectHashAggregateExec => o }.nonEmpty)
+          } else {
+            assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
+          }
+
+          val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+          assert(exchanges.size == 2)
+        }
+      }
+    }
+  }
 }

 // Used for unit-testing EnsureRequirements
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index a585f215ad681..c7266c886128c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -604,6 +604,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     }
   }

+  test("bucket join should work with SubqueryAlias plan") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      withTable("t") {
+        withView("v") {
+          spark.range(20).selectExpr("id as i").write.bucketBy(8, "i").saveAsTable("t")
+          sql("CREATE VIEW v AS SELECT * FROM t").collect()
+
+          val plan = sql("SELECT * FROM t a JOIN v b ON a.i = b.i").queryExecution.executedPlan
+          assert(plan.collect { case exchange: ShuffleExchangeExec => exchange }.isEmpty)
+        }
+      }
+    }
+  }
+
   test("avoid shuffle when grouping keys are a super-set of bucket keys") {
     withTable("bucketed_table") {
       df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
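The BucketedReadSuite addition covers the same mechanism through a view: the projection that view resolution can place over the bucketed scan used to hide the scan's `HashPartitioning`, so joining the table with the view re-shuffled one side. A hedged shell sketch mirroring that test (it assumes a `SparkSession` named `spark` with a warehouse configured; not part of the patch):

```scala
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")  // force a sort-merge join
spark.range(20).selectExpr("id AS i").write.bucketBy(8, "i").saveAsTable("t")
spark.sql("CREATE VIEW v AS SELECT * FROM t")

val plan = spark.sql("SELECT * FROM t a JOIN v b ON a.i = b.i")
  .queryExecution.executedPlan

// With ProjectExec made alias-aware, both join children report the bucketed
// HashPartitioning on `i`, so no exchange is expected in the plan.
plan.collect { case e: ShuffleExchangeExec => e }.isEmpty
```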