diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 01b0ae451b2a9..59954f88c36bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -400,6 +400,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  private def findGroupColumn(alias: Alias): Option[AttributeReference] = alias match {
+    case alias @ Alias(attr: AttributeReference, name) if attr.name.startsWith("group_col_") =>
+      Some(AttributeReference(name, attr.dataType)(alias.exprId))
+    case Alias(alias: Alias, _) => findGroupColumn(alias)
+    case _ => None
+  }
+
   private def pushDownLimit(plan: LogicalPlan, limit: Int): (LogicalPlan, Boolean) = plan match {
     case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
       val (isPushed, isPartiallyPushed) = PushDownUtils.pushLimit(sHolder.builder, limit)
@@ -410,12 +417,30 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder))
         // Without building the Scan, we do not know the resulting column names after aggregate
        // push-down, and thus can't push down Top-N which needs to know the ordering column names.
-        // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same
-        // columns, which we know the resulting column names: the original table columns.
-        if sHolder.pushedAggregate.isEmpty && filter.isEmpty &&
+        // In particular, we push down the simple cases like GROUP BY columns directly and ORDER BY
+        // the same columns, where we know the resulting column names: the original table columns.
+        // TODO: support pushing down Aggregate with ORDER BY expressions.
+        if filter.isEmpty &&
           CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
       val aliasMap = getAliasMap(project)
-      val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
+
+      def findGroupColForSortOrder(sortOrder: SortOrder): Option[SortOrder] = sortOrder match {
+        case SortOrder(attr: AttributeReference, direction, nullOrdering, sameOrderExpressions) =>
+          findGroupColumn(aliasMap(attr)).filter { groupCol =>
+            sHolder.relation.output.exists(out => out.semanticEquals(groupCol))
+          }.map(SortOrder(_, direction, nullOrdering, sameOrderExpressions))
+        case _ => None
+      }
+
+      val newOrder = if (sHolder.pushedAggregate.isDefined) {
+        val orderByGroupCols = order.flatMap(findGroupColForSortOrder)
+        if (orderByGroupCols.length != order.length) {
+          return (s, false)
+        }
+        orderByGroupCols
+      } else {
+        order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
+      }
       val normalizedOrders = DataSourceStrategy.normalizeExprs(
         newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
       val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index d64b181500753..80714c7cde517 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -718,56 +718,45 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0, false),
       Row(1, "amy", 10000.00, 1000.0, true)))
 
+    val name = udf { (x: String) => x.matches("cat|dav|amy") }
+    val sub = udf { (x: String) => x.substring(0, 3) }
     val df6 = spark.read
       .table("h2.test.employee")
-      .groupBy("DEPT").sum("SALARY")
-      .orderBy("DEPT")
+      .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
+      .filter(name($"shortName"))
+      .sort($"SALARY".desc)
       .limit(1)
+    // LIMIT is pushed down only if all the filters are pushed down
     checkSortRemoved(df6, false)
     checkLimitRemoved(df6, false)
-    checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," +
-      " PushedFilters: [], PushedGroupByExpressions: [DEPT], ")
-    checkAnswer(df6, Seq(Row(1, 19000.00)))
+    checkPushedInfo(df6, "PushedFilters: [], ")
+    checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy")))
 
-    val name = udf { (x: String) => x.matches("cat|dav|amy") }
-    val sub = udf { (x: String) => x.substring(0, 3) }
     val df7 = spark.read
       .table("h2.test.employee")
-      .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
-      .filter(name($"shortName"))
-      .sort($"SALARY".desc)
+      .sort(sub($"NAME"))
       .limit(1)
-    // LIMIT is pushed down only if all the filters are pushed down
     checkSortRemoved(df7, false)
     checkLimitRemoved(df7, false)
     checkPushedInfo(df7, "PushedFilters: [], ")
-    checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))
+    checkAnswer(df7, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
 
     val df8 = spark.read
-      .table("h2.test.employee")
-      .sort(sub($"NAME"))
-      .limit(1)
-    checkSortRemoved(df8, false)
-    checkLimitRemoved(df8, false)
-    checkPushedInfo(df8, "PushedFilters: [], ")
-    checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
-
-    val df9 = spark.read
       .table("h2.test.employee")
       .select($"DEPT", $"name", $"SALARY",
         when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
       .sort("key", "dept", "SALARY")
       .limit(3)
-    checkSortRemoved(df9)
-    checkLimitRemoved(df9)
-    checkPushedInfo(df9, "PushedFilters: [], " +
+    checkSortRemoved(df8)
+    checkLimitRemoved(df8)
+    checkPushedInfo(df8, "PushedFilters: [], " +
       "PushedTopN: " +
       "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
       "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,")
-    checkAnswer(df9,
+    checkAnswer(df8,
       Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
 
-    val df10 = spark.read
+    val df9 = spark.read
       .option("partitionColumn", "dept")
       .option("lowerBound", "0")
       .option("upperBound", "2")
@@ -777,13 +766,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
       .orderBy($"key", $"dept", $"SALARY")
       .limit(3)
-    checkSortRemoved(df10, false)
-    checkLimitRemoved(df10, false)
-    checkPushedInfo(df10, "PushedFilters: [], " +
+    checkSortRemoved(df9, false)
+    checkLimitRemoved(df9, false)
+    checkPushedInfo(df9, "PushedFilters: [], " +
       "PushedTopN: " +
       "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
       "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,")
-    checkAnswer(df10,
+    checkAnswer(df9,
       Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
   }
 
@@ -811,6 +800,169 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df2, Seq(Row(2, "david", 10000.00)))
   }
 
+  test("scan with aggregate push-down and top N push-down") {
+    val df1 = spark.read
+      .table("h2.test.employee")
+      .groupBy("DEPT").sum("SALARY")
+      .orderBy("DEPT")
+      .limit(1)
+    checkSortRemoved(df1)
+    checkLimitRemoved(df1)
+    checkPushedInfo(df1,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(df1, Seq(Row(1, 19000.00)))
+
+    val df2 = sql(
+      """
+        |SELECT dept AS my_dept, SUM(SALARY) FROM h2.test.employee
+        |GROUP BY dept
+        |ORDER BY my_dept
+        |LIMIT 1
+        |""".stripMargin)
+    checkSortRemoved(df2)
+    checkLimitRemoved(df2)
+    checkPushedInfo(df2,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(df2, Seq(Row(1, 19000.00)))
+
+    val df3 = spark.read
+      .table("h2.test.employee")
+      .select($"SALARY",
+        when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
+      .groupBy("key").sum("SALARY")
+      .orderBy("key")
+      .limit(1)
+    checkSortRemoved(df3, false)
+    checkLimitRemoved(df3, false)
+    checkPushedInfo(df3,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: " +
+        "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END]",
+      "PushedFilters: []")
+    checkAnswer(df3, Seq(Row(0, 44000.00)))
+
+    val df4 = sql(
+      """
+        |SELECT dept, SUM(SALARY) FROM h2.test.employee
+        |GROUP BY dept
+        |ORDER BY SUM(SALARY)
+        |LIMIT 1
+        |""".stripMargin)
+    checkSortRemoved(df4, false)
+    checkLimitRemoved(df4, false)
+    checkPushedInfo(df4,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []")
+    checkAnswer(df4, Seq(Row(6, 12000.00)))
+
+    val df5 = sql(
+      """
+        |SELECT dept, SUM(SALARY) AS total FROM h2.test.employee
+        |GROUP BY dept
+        |ORDER BY total
+        |LIMIT 1
+        |""".stripMargin)
+    checkSortRemoved(df5, false)
+    checkLimitRemoved(df5, false)
+    checkPushedInfo(df5,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
"PushedFilters: []") + checkAnswer(df5, Seq(Row(6, 12000.00))) + } + + test("scan with aggregate push-down and paging push-down") { + val df1 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .orderBy("DEPT") + .offset(1) + .limit(1) + checkSortRemoved(df1) + checkLimitRemoved(df1) + checkPushedInfo(df1, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 2") + checkAnswer(df1, Seq(Row(2, 22000.00))) + + val df2 = sql( + """ + |SELECT dept AS my_dept, SUM(SALARY) FROM h2.test.employee + |GROUP BY dept + |ORDER BY my_dept + |LIMIT 1 + |OFFSET 1 + |""".stripMargin) + checkSortRemoved(df2) + checkLimitRemoved(df2) + checkPushedInfo(df2, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 2") + checkAnswer(df2, Seq(Row(2, 22000.00))) + + val df3 = spark.read + .table("h2.test.employee") + .select($"SALARY", + when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key")) + .groupBy("key").sum("SALARY") + .orderBy("key") + .offset(1) + .limit(1) + checkSortRemoved(df3, false) + checkLimitRemoved(df3, false) + checkPushedInfo(df3, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: " + + "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END]", + "PushedFilters: []") + checkAnswer(df3, Seq(Row(9000, 9000.00))) + + val df4 = sql( + """ + |SELECT dept, SUM(SALARY) FROM h2.test.employee + |GROUP BY dept + |ORDER BY SUM(SALARY) + |LIMIT 1 + |OFFSET 1 + |""".stripMargin) + checkSortRemoved(df4, false) + checkLimitRemoved(df4, false) + checkPushedInfo(df4, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []") + checkAnswer(df4, Seq(Row(1, 19000.00))) + + val df5 = sql( + """ + |SELECT dept, SUM(SALARY) AS total FROM h2.test.employee + |GROUP BY dept + |ORDER BY total + |LIMIT 1 + |OFFSET 1 + |""".stripMargin) + checkSortRemoved(df5, false) + checkLimitRemoved(df5, false) + checkPushedInfo(df5, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []") + checkAnswer(df5, Seq(Row(1, 19000.00))) + } + test("scan with filter push-down") { val df = spark.table("h2.test.people").filter($"id" > 1) checkFiltersRemoved(df)