@@ -16,16 +16,37 @@
  */
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
 
 /**
- * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning`
- * that satisfies output distribution requirements.
+ * A trait that provides functionality to handle aliases in the `outputExpressions`.
  */
-trait AliasAwareOutputPartitioning extends UnaryExecNode {
+trait AliasAwareOutputExpression extends UnaryExecNode {
   protected def outputExpressions: Seq[NamedExpression]
 
+  protected def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
+
+  protected def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
+    exprs.map {
+      case a: AttributeReference => replaceAlias(a).getOrElse(a)
+      case other => other
+    }
+  }
+
+  protected def replaceAlias(attr: AttributeReference): Option[Attribute] = {
+    outputExpressions.collectFirst {
+      case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
+        a.toAttribute
+    }
+  }
+}
+
+/**
+ * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that
+ * satisfies distribution requirements.
+ */
+trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
   final override def outputPartitioning: Partitioning = {
     if (hasAlias) {
       child.outputPartitioning match {
@@ -36,20 +57,25 @@ trait AliasAwareOutputPartitioning extends UnaryExecNode {
       child.outputPartitioning
     }
   }
+}
 
-  private def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
-
-  private def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
-    exprs.map {
-      case a: AttributeReference => replaceAlias(a).getOrElse(a)
-      case other => other
-    }
-  }
+/**
+ * A trait that handles aliases in the `orderingExpressions` to produce `outputOrdering` that
+ * satisfies ordering requirements.
+ */
+trait AliasAwareOutputOrdering extends AliasAwareOutputExpression {
+  protected def orderingExpressions: Seq[SortOrder]
 
-  private def replaceAlias(attr: AttributeReference): Option[Attribute] = {
-    outputExpressions.collectFirst {
-      case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
-        a.toAttribute
+  final override def outputOrdering: Seq[SortOrder] = {
+    if (hasAlias) {
+      orderingExpressions.map { s =>
+        s.child match {
+          case a: AttributeReference => s.copy(child = replaceAlias(a).getOrElse(a))
+          case _ => s
+        }
+      }
+    } else {
+      orderingExpressions
    }
  }
}
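
For readers skimming the diff: the rewrite these traits perform can be sketched with simplified stand-in types (an illustrative reduction, not the real Catalyst classes and not code from this patch). An attribute that appears as the child of an Alias in `outputExpressions` is replaced by the alias's output attribute, both in partitioning expressions and in `SortOrder` children:

    // Stand-ins for Catalyst's AttributeReference / SortOrder, just to show
    // the shape of replaceAlias and the outputOrdering rewrite above.
    case class Attr(name: String)
    case class SortOrder(child: Attr)

    // aliases maps a child attribute to its output attribute (e.g. k2 -> k3),
    // playing the role of the collectFirst over Alias(child, _) above.
    def replaceAlias(attr: Attr, aliases: Map[Attr, Attr]): Option[Attr] =
      aliases.get(attr)

    def rewriteOrdering(ordering: Seq[SortOrder], aliases: Map[Attr, Attr]): Seq[SortOrder] =
      ordering.map { s =>
        replaceAlias(s.child, aliases).map(a => s.copy(child = a)).getOrElse(s)
      }

    // An ordering on k2 is reported as an ordering on its alias k3, so a parent
    // requiring an ordering on k3 is satisfied without an extra Sort.
    assert(rewriteOrdering(Seq(SortOrder(Attr("k2"))), Map(Attr("k2") -> Attr("k3")))
      == Seq(SortOrder(Attr("k3"))))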
@@ -53,7 +53,9 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning {
+  extends BaseAggregateExec
+  with BlockingOperatorWithCodegen
+  with AliasAwareOutputPartitioning {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan}
+import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, AliasAwareOutputPartitioning, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -38,7 +38,9 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends BaseAggregateExec with AliasAwareOutputPartitioning {
+  extends BaseAggregateExec
+  with AliasAwareOutputPartitioning
+  with AliasAwareOutputOrdering {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -68,7 +70,7 @@ case class SortAggregateExec(
 
   override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
 
-  override def outputOrdering: Seq[SortOrder] = {
+  override protected def orderingExpressions: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
   }
 
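To see the new `AliasAwareOutputOrdering` mixin pay off, here is a rough spark-shell sketch mirroring the PlannerSuite test added below (the confs force SortAggregateExec and a sort-merge join; the column names come from that test — this is an illustration, not part of the patch): the final aggregate's ordering on k1 now survives the rename to k3, so no extra Sort is planned above it.

    import org.apache.spark.sql.functions.collect_list

    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
    spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false")

    val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
    val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
    // collect_list with object hash aggregation disabled falls back to
    // SortAggregateExec, whose output ordering is its grouping keys.
    val agg1 = t1.groupBy("k1").agg(collect_list("k1")).withColumnRenamed("k1", "k3")
    val agg2 = t2.groupBy("k2").agg(collect_list("k2"))
    // Per the review thread below: five SortExec nodes before the fix
    // (an extra one on the renamed side), four after.
    agg1.join(agg2, agg1("k3") === agg2("k2")).explain()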
@@ -39,7 +39,10 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
 /** Physical plan for Project. */
 case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {
+  extends UnaryExecNode
+  with CodegenSupport
+  with AliasAwareOutputPartitioning
+  with AliasAwareOutputOrdering {
 
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
 
@@ -80,10 +83,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
     }
   }
 
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-
   override protected def outputExpressions: Seq[NamedExpression] = projectList
 
+  override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering
+
   override def verboseStringWithOperatorId(): String = {
     s"""
        |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}
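The `ProjectExec` half of the change matters when the child's ordering comes for free, as with a bucketed, sorted table — essentially the BucketedReadSuite case added below, sketched here for spark-shell (illustrative only: the table name demo_t is made up, and repartition(1) keeps one file per bucket so the sorted-scan guarantee holds):

    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "0")
    spark.range(100).selectExpr("id as i").repartition(1)
      .write.format("parquet").bucketBy(8, "i").sortBy("i").saveAsTable("demo_t")

    val t1 = spark.table("demo_t")
    val t2 = t1.selectExpr("i as ii")  // planned as ProjectExec with Alias(i, ii)
    // The scan is bucketed and sorted on i; with the alias-aware traits the
    // renamed side reports its partitioning and ordering on ii, so the
    // sort-merge join should need neither an extra Exchange nor a Sort.
    t1.join(t2, t1("i") === t2("ii")).explain()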
@@ -975,6 +975,25 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("aliases in the sort aggregate expressions should not introduce extra sort") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+        val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
+        val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
+
+        val agg1 = t1.groupBy("k1").agg(collect_list("k1")).withColumnRenamed("k1", "k3")
+        val agg2 = t2.groupBy("k2").agg(collect_list("k2"))
+
+        val planned = agg1.join(agg2, $"k3" === $"k2").queryExecution.executedPlan
+        assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
+
+        // We expect two SortExec nodes on each side of join.
+        val sorts = planned.collect { case s: SortExec => s }
+        assert(sorts.size == 4)

[Inline review thread on `assert(sorts.size == 4)`]

Contributor Author: FYI, this is 5 without the fix.

Contributor: why do we have sorts between SortAggregateExec and SortMergeJoinExec?

Contributor Author: Oh, nothing between them. Two sorts under SortAggregateExec:

    +- SortAggregate(key=[k2#224L], functions=[collect_list(k2#224L, 0, 0)], output=[k2#224L, collect_list(k2)#236])
       +- *(5) Sort [k2#224L ASC NULLS FIRST], false, 0
          +- Exchange hashpartitioning(k2#224L, 5), true, [id=#151]
             +- SortAggregate(key=[k2#224L], functions=[partial_collect_list(k2#224L, 0, 0)], output=[k2#224L, buf#254])
                +- *(4) Sort [k2#224L ASC NULLS FIRST], false, 0
                   +- *(4) Project [FLOOR((cast(id#222L as double) / 4.0)) AS k2#224L]
                      +- *(4) Range (0, 20, step=1, splits=2)

Contributor: ah i see

+      }
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
@@ -604,6 +604,18 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     }
   }
+
+  test("sort should not be introduced when aliases are used") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      withTable("t") {
+        df1.repartition(1).write.format("parquet").bucketBy(8, "i").sortBy("i").saveAsTable("t")

[Inline review thread on the new test]

Member: Could you add more tests, e.g., orderBy cases?

Contributor Author (imback82, Mar 8, 2020): Do you mean something like the following?

    val t1 = spark.range(10).selectExpr("id as k1").orderBy("k1")
      .selectExpr("k1 as k11").orderBy("k11")
    val t2 = spark.range(10).selectExpr("id as k2").orderBy("k2")
    t1.join(t2, t1("k11") === t2("k2")).explain(true)

The extra Sort is optimized away, so it doesn't affect the physical plan (no extra sort).

Member: Ah, I see. How about the aggregate case?

Contributor Author: I added a test for SortAggregateExec. HashAggregateExec and ObjectHashAggregateExec are not affected since ordering is not involved with them.

+        val t1 = spark.table("t")
+        val t2 = t1.selectExpr("i as ii")
+        val plan = t1.join(t2, t1("i") === t2("ii")).queryExecution.executedPlan
+        assert(plan.collect { case sort: SortExec => sort }.isEmpty)
+      }
+    }
+  }
 
   test("bucket join should work with SubqueryAlias plan") {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
       withTable("t") {