Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ package object dsl {

/** Builds a `Limit` from `limitExpr` and the enclosing `logicalPlan`. */
def limit(limitExpr: Expression): LogicalPlan = {
  Limit(limitExpr, logicalPlan)
}

/** Builds a `LocalLimit` from `limitExpr` and the enclosing `logicalPlan`. */
def localLimit(limitExpr: Expression): LogicalPlan = {
  LocalLimit(limitExpr, logicalPlan)
}

/** Builds an `Offset` from `offsetExpr` and the enclosing `logicalPlan`. */
def offset(offsetExpr: Expression): LogicalPlan = {
  Offset(offsetExpr, logicalPlan)
}

def join(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,9 @@ object LimitPushDown extends Rule[LogicalPlan] {
LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join)))
// Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only.
case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly =>
Limit(le, Project(a.aggregateExpressions, LocalLimit(le, a.child)))
val project = Project(a.aggregateExpressions, LocalLimit(le, a.child))
project.setTagValue(Project.dataOrderIrrelevantTag, ())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to EliminateSorts, it's data-order relevant if there are only group expressions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's irrelevant, see

private def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = {
def isOrderIrrelevantAggFunction(func: AggregateFunction): Boolean = func match {
case _: Min | _: Max | _: Count | _: BitAggregate => true
// Arithmetic operations for floating-point values are order-sensitive
// (they are not associative).
case _: Sum | _: Average | _: CentralMomentAgg =>
!Seq(FloatType, DoubleType)
.exists(e => DataTypeUtils.sameType(e, func.children.head.dataType))
case _ => false
}
def checkValidAggregateExpression(expr: Expression): Boolean = expr match {
case _: AttributeReference => true
case ae: AggregateExpression => isOrderIrrelevantAggFunction(ae.aggregateFunction)
case _: UserDefinedExpression => false
case e => e.children.forall(checkValidAggregateExpression)
}
aggs.forall(checkValidAggregateExpression)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed the checkValidAggregateExpression.

Limit(le, project)
// Same rewrite as the group-only Aggregate case, but with an extra Project sitting on top of
// the Aggregate. The Project that replaces the Aggregate is tagged as data-order irrelevant
// here too, so that EliminateSorts can still remove a Sort underneath it.
case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly =>
  val project = Project(a.aggregateExpressions, LocalLimit(le, a.child))
  project.setTagValue(Project.dataOrderIrrelevantTag, ())
  Limit(le, p.copy(child = project))
// Merge offset value and limit value into LocalLimit and pushes down LocalLimit through Offset.
Expand Down Expand Up @@ -1583,6 +1585,8 @@ object EliminateSorts extends Rule[LogicalPlan] {
right = recursiveRemoveSort(originRight, true))
case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
g.copy(child = recursiveRemoveSort(originChild, true))
case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined =>
p.copy(child = recursiveRemoveSort(p.child, true))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)

object Project {
val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = TreeNodeTag[Seq[Attribute]]("hidden_output")
// A Project carrying this tag does not care about the order of its input rows. The tag is
// only set when the Project was converted from a grouping-only Aggregate.
val dataOrderIrrelevantTag: TreeNodeTag[Unit] = TreeNodeTag[Unit]("data_order_irrelevant")

def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): Project = {
assert(plan.resolved)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,16 @@ class EliminateSortsSuite extends AnalysisTest {

comparePlans(Optimize.execute(originalPlan.analyze), correctAnswer.analyze)
}

// A Sort below a grouping-only Aggregate under `limit(1)`: the expected plan has the
// Aggregate replaced by a Project over a local limit, and no Sort left.
test("SPARK-46378: Still remove Sort after converting Aggregate to Project") {
  val sortedGroupOnlyAgg = testRelation
    .orderBy($"a".asc)
    .groupBy($"a")($"a")
    .limit(1)
  val expected = testRelation
    .localLimit(1)
    .select($"a")
    .limit(1)

  val optimized = Optimize.execute(sortedGroupOnlyAgg.analyze)
  comparePlans(optimized, expected.analyze)
}
}