@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}

/**
* A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning`
* that satisfies output distribution requirements.
*/
trait AliasAwareOutputPartitioning extends UnaryExecNode {
protected def outputExpressions: Seq[NamedExpression]

final override def outputPartitioning: Partitioning = {
if (hasAlias) {
child.outputPartitioning match {
case h: HashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
case other => other
}
} else {
child.outputPartitioning
}
}

private def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined

private def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
exprs.map {
case a: AttributeReference => replaceAlias(a).getOrElse(a)
case other => other
Review comment (Member): How about the other partitioning cases, e.g., range?

Reply (Contributor Author): PartitioningCollection is constructed as PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)), so aliases should already have been removed if the partitioning was HashPartitioning. But we could add a case similar to your solution (https://github.com/apache/spark/pull/17400/files#diff-342789ab9c8c0154b412dd1c719c9397R82-R86) to be future-proof.

For RangePartitioning, your change (https://github.com/apache/spark/pull/17400/files#diff-342789ab9c8c0154b412dd1c719c9397R72-R78) makes sense, but I couldn't come up with an actual example to test against. Do you have one in mind?

}
}

private def replaceAlias(attr: AttributeReference): Option[Attribute] = {
outputExpressions.collectFirst {
case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
a.toAttribute
}
}
}
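
Following up on the review thread above, the sketch below illustrates how the same alias rewriting could be extended to RangePartitioning and PartitioningCollection. It is only an illustration of the idea being discussed, not code from this PR; it assumes the trait's existing hasAlias and replaceAliases helpers plus imports for RangePartitioning and PartitioningCollection from org.apache.spark.sql.catalyst.plans.physical.

// Hypothetical extension, not part of this PR: also rewrite aliased attributes
// inside RangePartitioning and inside each member of a PartitioningCollection.
final override def outputPartitioning: Partitioning = {
  if (hasAlias) rewritePartitioning(child.outputPartitioning) else child.outputPartitioning
}

// Recursively replace aliased attributes in the partitioning expressions.
private def rewritePartitioning(partitioning: Partitioning): Partitioning = partitioning match {
  case h: HashPartitioning =>
    h.copy(expressions = replaceAliases(h.expressions))
  case r: RangePartitioning =>
    r.copy(ordering = r.ordering.map { sortOrder =>
      sortOrder.copy(child = replaceAliases(Seq(sortOrder.child)).head)
    })
  case PartitioningCollection(partitionings) =>
    PartitioningCollection(partitionings.map(rewritePartitioning))
  case other => other
}
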
@@ -53,7 +53,7 @@ case class HashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
-  extends UnaryExecNode with BlockingOperatorWithCodegen {
+  extends UnaryExecNode with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -75,7 +75,7 @@ case class HashAggregateExec(

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
@@ -67,7 +67,7 @@ case class ObjectHashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -97,7 +97,7 @@ case class ObjectHashAggregateExec(
}
}

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numOutputRows = longMetric("numOutputRows")
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetrics

/**
@@ -38,7 +38,7 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -66,7 +66,7 @@ case class SortAggregateExec(
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

override def outputOrdering: Seq[SortOrder] = {
groupingExpressions.map(SortOrder(_, Ascending))
@@ -37,7 +37,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

/** Physical plan for Project. */
case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {

override def output: Seq[Attribute] = projectList.map(_.toAttribute)

@@ -80,7 +80,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = projectList

override def verboseStringWithOperatorId(): String = {
s"""
@@ -91,7 +91,6 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
}
}


/** Physical plan for Filter. */
case class FilterExec(condition: Expression, child: SparkPlan)
extends UnaryExecNode with CodegenSupport with PredicateHelper {
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range, Repartition, Sort, Union}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -937,6 +938,93 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
}
}

test("aliases in the project should not introduce extra shuffle") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("df1", "df2") {
spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
val planned = sql(
"""
|SELECT * FROM
| (SELECT key AS k from df1) t1
|INNER JOIN
| (SELECT key AS k from df2) t2
|ON t1.k = t2.k
""".stripMargin).queryExecution.executedPlan
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
Review comment (Contributor): I was confused about why there was only one shuffle, then realized it's exchange reuse. Can we join different data frames, e.g. spark.range(10) and spark.range(20)?

Reply (Contributor Author): Thanks for pointing that out. I updated the test; it now generates two ShuffleExchangeExec nodes instead of four.

assert(exchanges.size == 2)
}
}
}
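
As an aside on the exchange-reuse point above: with identical inputs on both join sides, the same property could also be asserted by counting reused exchanges. A rough sketch, assuming the reuse rule wraps the duplicated shuffle in a ReusedExchangeExec (already imported in this suite):

// Illustrative only: with the same input on both sides, one shuffle is planned
// and the other side reuses it, so we would expect one ShuffleExchangeExec plus
// one ReusedExchangeExec rather than two independent shuffles.
val shuffles = planned.collect { case s: ShuffleExchangeExec => s }
val reused = planned.collect { case r: ReusedExchangeExec => r }
assert(shuffles.size == 1 && reused.size == 1)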

test("aliases to expressions should not be replaced") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("df1", "df2") {
spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
val planned = sql(
"""
|SELECT * FROM
| (SELECT key + 1 AS k1 from df1) t1
|INNER JOIN
| (SELECT key + 1 AS k2 from df2) t2
|ON t1.k1 = t2.k2
|""".stripMargin).queryExecution.executedPlan
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
Review comment (Contributor): Ditto: use two different data frames here as well, so the check is not satisfied by exchange reuse.

// Make sure aliases to an expression (key + 1) are not replaced.
Seq("k1", "k2").foreach { alias =>
assert(exchanges.exists(_.outputPartitioning match {
case HashPartitioning(Seq(a: AttributeReference), _) => a.name == alias
case _ => false
}))
}
}
}
}

test("aliases in the aggregate expressions should not introduce extra shuffle") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
val t2 = spark.range(20).selectExpr("floor(id/4) as k2")

val agg1 = t1.groupBy("k1").agg(count(lit("1")).as("cnt1"))
val agg2 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumnRenamed("k2", "k3")

val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan

assert(planned.collect { case h: HashAggregateExec => h }.nonEmpty)

val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
assert(exchanges.size == 2)
}
}

test("aliases in the object hash/sort aggregate expressions should not introduce extra shuffle") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Seq(true, false).foreach { useObjectHashAgg =>
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> useObjectHashAgg.toString) {
val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
val t2 = spark.range(20).selectExpr("floor(id/4) as k2")

val agg1 = t1.groupBy("k1").agg(collect_list("k1"))
val agg2 = t2.groupBy("k2").agg(collect_list("k2")).withColumnRenamed("k2", "k3")

val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan

if (useObjectHashAgg) {
assert(planned.collect { case o: ObjectHashAggregateExec => o }.nonEmpty)
} else {
assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
}

val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
assert(exchanges.size == 2)
}
}
}
}
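
For a quick interactive check of what the trait changes, the reported partitioning can be inspected directly. A small sketch for spark-shell (illustrative; the exact plan shape may vary with version and configuration):

// With AliasAwareOutputPartitioning mixed into ProjectExec, a projection that
// merely renames the partitioning column reports HashPartitioning over the new
// output attribute (k) instead of the child's attribute (key).
spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
val plan = sql("SELECT key AS k FROM df1").queryExecution.executedPlan
println(plan.outputPartitioning)
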
}

// Used for unit-testing EnsureRequirements
@@ -604,6 +604,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
}
}

test("bucket join should work with SubqueryAlias plan") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
withTable("t") {
withView("v") {
spark.range(20).selectExpr("id as i").write.bucketBy(8, "i").saveAsTable("t")
sql("CREATE VIEW v AS SELECT * FROM t").collect()

val plan = sql("SELECT * FROM t a JOIN v b ON a.i = b.i").queryExecution.executedPlan
assert(plan.collect { case exchange: ShuffleExchangeExec => exchange }.isEmpty)
}
}
}
}

test("avoid shuffle when grouping keys are a super-set of bucket keys") {
withTable("bucketed_table") {
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")