Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ case class SortOrder(child: Expression, direction: SortDirection)
override def sql: String = child.sql + " " + direction.sql

def isAscending: Boolean = direction == Ascending

def semanticEquals(other: SortOrder): Boolean =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expression has a default implementation of semanticEquals, doesn't it work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan : If you look at the old version of EnsureRequirements below at L253, it compared raw SortOrder objects which will use equals() generated for it. In scala, equals() for case classes is merely doing equals() over all its fields so that lead to Expression's equals() being used instead of its semanticEquals().

My fix here was to introduce a semanticEquals in SortOrder which compares the underlying Expression semantically.

Copy link
Contributor

@cloud-fan cloud-fan Sep 1, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea I understand in EnsureRequirements we should use semanticEquals instead of == to compare SortOrder, but why we need to implement samanticEquals again in SortOrder? What's wrong with the default implementation?

I mean, there is no need to "introduce" a semanticEquals in SortOrder, it already has, because SortOrder is also Expression

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan : I see what you were trying to say before. I tried that and it worked. I have created a PR to clean it up : #14910 Thanks for pointing this out !!

(direction == other.direction) && child.semanticEquals(other.child)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
if (requiredOrdering.nonEmpty) {
// If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) {
val orderingMatched = if (requiredOrdering.length > child.outputOrdering.length) {
false
} else {
requiredOrdering.zip(child.outputOrdering).forall {
case (requiredOrder, childOutputOrder) =>
requiredOrder.semanticEquals(childOutputOrder)
}
}

if (!orderingMatched) {
SortExec(requiredOrdering, global = false, child = child)
} else {
child
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
Expand Down Expand Up @@ -444,6 +444,44 @@ class PlannerSuite extends SharedSQLContext {
}
}

test("EnsureRequirements skips sort when required ordering is semantically equal to " +
"existing ordering") {
val exprId: ExprId = NamedExpression.newExprId
val attribute1 =
AttributeReference(
name = "col1",
dataType = LongType,
nullable = false
) (exprId = exprId,
qualifier = Some("col1_qualifier")
)

val attribute2 =
AttributeReference(
name = "col1",
dataType = LongType,
nullable = false
) (exprId = exprId)

val orderingA1 = SortOrder(attribute1, Ascending)
val orderingA2 = SortOrder(attribute2, Ascending)

assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2")
assert(orderingA1.semanticEquals(orderingA2),
s"$orderingA1 should be semantically equal to $orderingA2")

val inputPlan = DummySparkPlan(
children = DummySparkPlan(outputOrdering = Seq(orderingA1)) :: Nil,
requiredChildOrdering = Seq(Seq(orderingA2)),
requiredChildDistribution = Seq(UnspecifiedDistribution)
)
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
if (outputPlan.collect { case s: SortExec => true }.nonEmpty) {
fail(s"No sorts should have been added:\n$outputPlan")
}
}

// This is a regression test for SPARK-11135
test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") {
val orderingA = SortOrder(Literal(1), Ascending)
Expand Down