@@ -400,6 +400,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}

private def findGroupColumn(alias: Alias): Option[AttributeReference] = alias match {
Member

Per SortOrder(attr: AttributeReference, ...), the sort expression is always an AttributeReference. Does this method also need to handle Alias?

Contributor Author

Yes. If users specify an alias for the group column (e.g. SELECT dept AS my_dept ... GROUP BY dept ORDER BY my_dept), it will be Alias(alias: Alias, _).

case alias @ Alias(attr: AttributeReference, name) if attr.name.startsWith("group_col_") =>
Some(AttributeReference(name, attr.dataType)(alias.exprId))
case Alias(alias: Alias, _) => findGroupColumn(alias)
Contributor

I feel it's a bit hacky to assume the Alias contains the actual grouping columns. How about we generate the name mapping (grouping attribute to actual group column name) during aggregate push-down, put that mapping in ScanBuilderHolder, and use it to rewrite the ORDER BY expressions during limit push-down?

Contributor Author

OK

case _ => None
}
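A minimal sketch of the approach suggested in the review comment above, assuming a hypothetical groupByColumnNames map recorded on ScanBuilderHolder during aggregate push-down; the field name and helper are illustrative only, not the actual implementation in this PR:

// Sketch: groupByColumnNames is a hypothetical (grouping attribute -> original column name)
// mapping that would be populated when the aggregate is pushed down.
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SortOrder}

def rewriteOrderForPushedAggregate(
    order: Seq[SortOrder],
    groupByColumnNames: Map[Attribute, String]): Option[Seq[SortOrder]] = {
  val rewritten = order.flatMap {
    case s @ SortOrder(attr: AttributeReference, _, _, _) =>
      // Rewrite the ORDER BY attribute back to the original table column name.
      groupByColumnNames.get(attr).map(name => s.copy(child = attr.withName(name)))
    case _ => None
  }
  // Only rewrite (and thus push Top-N) if every ORDER BY expression maps to a grouping column.
  if (rewritten.length == order.length) Some(rewritten) else None
}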

private def pushDownLimit(plan: LogicalPlan, limit: Int): (LogicalPlan, Boolean) = plan match {
case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
val (isPushed, isPartiallyPushed) = PushDownUtils.pushLimit(sHolder.builder, limit)
@@ -410,12 +417,30 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder))
// Without building the Scan, we do not know the resulting column names after aggregate
// push-down, and thus can't push down Top-N which needs to know the ordering column names.
// TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same
// columns, which we know the resulting column names: the original table columns.
if sHolder.pushedAggregate.isEmpty && filter.isEmpty &&
// In particular, we push down the simple case of GROUP BY columns followed by ORDER BY the same
// columns, for which we know the resulting column names: the original table columns.
// TODO: support pushing down Aggregate with ORDER BY expressions.
if filter.isEmpty &&
CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
val aliasMap = getAliasMap(project)
val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]

def findGroupColForSortOrder(sortOrder: SortOrder): Option[SortOrder] = sortOrder match {
case SortOrder(attr: AttributeReference, direction, nullOrdering, sameOrderExpressions) =>
findGroupColumn(aliasMap(attr)).filter { groupCol =>
sHolder.relation.output.exists(out => out.semanticEquals(groupCol))
}.map(SortOrder(_, direction, nullOrdering, sameOrderExpressions))
case _ => None
}

val newOrder = if (sHolder.pushedAggregate.isDefined) {
val orderByGroupCols = order.flatMap(findGroupColForSortOrder)
if (orderByGroupCols.length != order.length) {
return (s, false)
}
orderByGroupCols
} else {
order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
}
val normalizedOrders = DataSourceStrategy.normalizeExprs(
newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
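For illustration, the simple case the new branch targets — ORDER BY on the same columns that were grouped on — now lets the Sort and Limit be pushed down together with the aggregate; the first case of the new JDBCV2Suite test below exercises exactly this:

// Drawn from the test added below: aggregate, Top-N and LIMIT are all pushed to the JDBC source.
val df = spark.read
  .table("h2.test.employee")
  .groupBy("DEPT").sum("SALARY")
  .orderBy("DEPT")
  .limit(1)
// Expected pushed info (per checkPushedInfo in the test):
//   PushedAggregates: [SUM(SALARY)], PushedGroupByExpressions: [DEPT],
//   PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1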
212 changes: 182 additions & 30 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -718,56 +718,45 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df5,
Seq(Row(1, "cathy", 9000.00, 1200.0, false), Row(1, "amy", 10000.00, 1000.0, true)))

val name = udf { (x: String) => x.matches("cat|dav|amy") }
val sub = udf { (x: String) => x.substring(0, 3) }
val df6 = spark.read
.table("h2.test.employee")
.groupBy("DEPT").sum("SALARY")
.orderBy("DEPT")
.select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
.filter(name($"shortName"))
.sort($"SALARY".desc)
.limit(1)
// LIMIT is pushed down only if all the filters are pushed down
checkSortRemoved(df6, false)
checkLimitRemoved(df6, false)
checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," +
" PushedFilters: [], PushedGroupByExpressions: [DEPT], ")
checkAnswer(df6, Seq(Row(1, 19000.00)))
checkPushedInfo(df6, "PushedFilters: [], ")
checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy")))

val name = udf { (x: String) => x.matches("cat|dav|amy") }
val sub = udf { (x: String) => x.substring(0, 3) }
val df7 = spark.read
.table("h2.test.employee")
.select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
.filter(name($"shortName"))
.sort($"SALARY".desc)
.sort(sub($"NAME"))
.limit(1)
// LIMIT is pushed down only if all the filters are pushed down
checkSortRemoved(df7, false)
checkLimitRemoved(df7, false)
checkPushedInfo(df7, "PushedFilters: [], ")
checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))
checkAnswer(df7, Seq(Row(2, "alex", 12000.00, 1200.0, false)))

val df8 = spark.read
.table("h2.test.employee")
.sort(sub($"NAME"))
.limit(1)
checkSortRemoved(df8, false)
checkLimitRemoved(df8, false)
checkPushedInfo(df8, "PushedFilters: [], ")
checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))

val df9 = spark.read
.table("h2.test.employee")
.select($"DEPT", $"name", $"SALARY",
when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
.sort("key", "dept", "SALARY")
.limit(3)
checkSortRemoved(df9)
checkLimitRemoved(df9)
checkPushedInfo(df9, "PushedFilters: [], " +
checkSortRemoved(df8)
checkLimitRemoved(df8)
checkPushedInfo(df8, "PushedFilters: [], " +
"PushedTopN: " +
"ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
"ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,")
checkAnswer(df9,
checkAnswer(df8,
Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))

val df10 = spark.read
val df9 = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
@@ -777,13 +766,13 @@
when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
.orderBy($"key", $"dept", $"SALARY")
.limit(3)
checkSortRemoved(df10, false)
checkLimitRemoved(df10, false)
checkPushedInfo(df10, "PushedFilters: [], " +
checkSortRemoved(df9, false)
checkLimitRemoved(df9, false)
checkPushedInfo(df9, "PushedFilters: [], " +
"PushedTopN: " +
"ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
"ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,")
checkAnswer(df10,
checkAnswer(df9,
Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
}

@@ -811,6 +800,169 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df2, Seq(Row(2, "david", 10000.00)))
}

test("scan with aggregate push-down and top N push-down") {
val df1 = spark.read
.table("h2.test.employee")
.groupBy("DEPT").sum("SALARY")
.orderBy("DEPT")
.limit(1)
checkSortRemoved(df1)
checkLimitRemoved(df1)
checkPushedInfo(df1,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1")
checkAnswer(df1, Seq(Row(1, 19000.00)))

val df2 = sql(
"""
|SELECT dept AS my_dept, SUM(SALARY) FROM h2.test.employee
|GROUP BY dept
|ORDER BY my_dept
|LIMIT 1
|""".stripMargin)
checkSortRemoved(df2)
checkLimitRemoved(df2)
checkPushedInfo(df2,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1")
checkAnswer(df2, Seq(Row(1, 19000.00)))

val df3 = spark.read
.table("h2.test.employee")
.select($"SALARY",
when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
.groupBy("key").sum("SALARY")
.orderBy("key")
.limit(1)
checkSortRemoved(df3, false)
checkLimitRemoved(df3, false)
checkPushedInfo(df3,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: " +
"[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END]",
"PushedFilters: []")
checkAnswer(df3, Seq(Row(0, 44000.00)))

val df4 = sql(
"""
|SELECT dept, SUM(SALARY) FROM h2.test.employee
|GROUP BY dept
|ORDER BY SUM(SALARY)
|LIMIT 1
|""".stripMargin)
checkSortRemoved(df4, false)
checkLimitRemoved(df4, false)
checkPushedInfo(df4,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []")
checkAnswer(df4, Seq(Row(6, 12000.00)))

val df5 = sql(
"""
|SELECT dept, SUM(SALARY) AS total FROM h2.test.employee
|GROUP BY dept
|ORDER BY total
|LIMIT 1
|""".stripMargin)
checkSortRemoved(df5, false)
checkLimitRemoved(df5, false)
checkPushedInfo(df5,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []")
checkAnswer(df5, Seq(Row(6, 12000.00)))
}

test("scan with aggregate push-down and paging push-down") {
val df1 = spark.read
.table("h2.test.employee")
.groupBy("DEPT").sum("SALARY")
.orderBy("DEPT")
.offset(1)
.limit(1)
checkSortRemoved(df1)
checkLimitRemoved(df1)
checkPushedInfo(df1,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedOffset: OFFSET 1",
"PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 2")
checkAnswer(df1, Seq(Row(2, 22000.00)))

val df2 = sql(
"""
|SELECT dept AS my_dept, SUM(SALARY) FROM h2.test.employee
|GROUP BY dept
|ORDER BY my_dept
|LIMIT 1
|OFFSET 1
|""".stripMargin)
checkSortRemoved(df2)
checkLimitRemoved(df2)
checkPushedInfo(df2,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedOffset: OFFSET 1",
"PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 2")
checkAnswer(df2, Seq(Row(2, 22000.00)))

val df3 = spark.read
.table("h2.test.employee")
.select($"SALARY",
when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
.groupBy("key").sum("SALARY")
.orderBy("key")
.offset(1)
.limit(1)
checkSortRemoved(df3, false)
checkLimitRemoved(df3, false)
checkPushedInfo(df3,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: " +
"[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END]",
"PushedFilters: []")
checkAnswer(df3, Seq(Row(9000, 9000.00)))

val df4 = sql(
"""
|SELECT dept, SUM(SALARY) FROM h2.test.employee
|GROUP BY dept
|ORDER BY SUM(SALARY)
|LIMIT 1
|OFFSET 1
|""".stripMargin)
checkSortRemoved(df4, false)
checkLimitRemoved(df4, false)
checkPushedInfo(df4,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []")
checkAnswer(df4, Seq(Row(1, 19000.00)))

val df5 = sql(
"""
|SELECT dept, SUM(SALARY) AS total FROM h2.test.employee
|GROUP BY dept
|ORDER BY total
|LIMIT 1
|OFFSET 1
|""".stripMargin)
checkSortRemoved(df5, false)
checkLimitRemoved(df5, false)
checkPushedInfo(df5,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []")
checkAnswer(df5, Seq(Row(1, 19000.00)))
}

test("scan with filter push-down") {
val df = spark.table("h2.test.people").filter($"id" > 1)
checkFiltersRemoved(df)