From 2e0fbe1965682b847d93efe69d7235ed8d3013da Mon Sep 17 00:00:00 2001
From: caoxuewen
Date: Mon, 5 Nov 2018 17:27:53 +0800
Subject: [PATCH] Add new optimization rule to eliminate unnecessary sort by exchanged adjacent Window expressions

---
 .../sql/catalyst/optimizer/Optimizer.scala    | 37 ++++++++++++++++++-
 .../optimizer/CollapseWindowSuite.scala       | 34 ++++++++++++++++-
 .../sql/DataFrameWindowFunctionsSuite.scala   | 18 +++++++++
 3 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a330a84a3a24f..28b9d3f1edb47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -120,7 +120,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       ComputeCurrentTime,
       GetCurrentDatabase(sessionCatalog),
       RewriteDistinctAggregates,
-      ReplaceDeduplicateWithAggregate) ::
+      ReplaceDeduplicateWithAggregate,
+      ExchangeWindowWithOrderField) ::
     //////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -740,6 +741,40 @@ object CollapseWindow extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Exchanges two adjacent Window operators so that the operator with the longer order spec is
+ * evaluated first. The operator with the shorter order spec is then placed on top of it, which
+ * is intended to let it reuse the existing ordering instead of requiring another sort.
+ *
+ * The exchange is applied only when all of the following hold:
+ * - the parent Window does not reference any output of the child Window,
+ * - both operators have the same partition spec,
+ * - the order spec of the child Window is a strict prefix of the order spec of the parent,
+ * - the first window expressions of both operators have the same window function type.
+ */
+object ExchangeWindowWithOrderField extends Rule[LogicalPlan] {
+  /**
+   * Returns true if `os1` is a strict prefix of `os2`: `os1` is shorter than `os2` and every
+   * sort order of `os1` is semantically equal to the one at the same position in `os2`.
+   */
+  private def compatibleOrder(os1: Seq[SortOrder], os2: Seq[SortOrder]): Boolean = {
+    os1.length < os2.length && os1.zip(os2).forall {
+      case (l, r) => l.semanticEquals(r)
+    }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
+      if w1.references.intersect(w2.windowOutputSet).isEmpty &&
+        ps1 == ps2 && compatibleOrder(os2, os1) &&
+        WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) =>
+      // NOTE(review): the exchange presumably changes the order of the output attributes;
+      // confirm whether a Project restoring `w1.output` is required on top of the result.
+      val newWindow = w1.copy(child = grandChild)
+      w2.copy(child = newWindow)
+  }
+}
+
 /**
  * Transpose Adjacent Window Expressions.
  * - If the partition spec of the parent Window expression is compatible with the partition spec
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala
index 52054c2f8bd8d..d6079bd3815c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala
@@ -27,7 +27,8 @@ class CollapseWindowSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("CollapseWindow", FixedPoint(10),
-        CollapseWindow) :: Nil
+        CollapseWindow,
+        ExchangeWindowWithOrderField) :: Nil
   }
 
   val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
@@ -38,6 +39,8 @@ class CollapseWindowSuite extends PlanTest {
   val partitionSpec2 = Seq(c + 1)
   val orderSpec1 = Seq(c.asc)
   val orderSpec2 = Seq(c.desc)
+  val orderSpecAB = Seq(a.asc, b.asc)
+  val orderSpecA = Seq(a.asc)
 
   test("collapse two adjacent windows with the same partition/order") {
     val query = testRelation
@@ -89,4 +92,33 @@ class CollapseWindowSuite extends PlanTest {
     val optimized = Optimize.execute(query.analyze)
     comparePlans(optimized, expected)
   }
+
+  test("exchange two adjacent windows when the child order spec is a prefix") {
+    val query = testRelation
+      .window(Seq(max(b).as('max_b)), partitionSpec1, orderSpecA)
+      .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpecAB)
+
+    val optimized = Optimize.execute(query.analyze)
+
+    val expected = testRelation
+      .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpecAB)
+      .window(Seq(max(b).as('max_b)), partitionSpec1, orderSpecA)
+      .analyze
+
+    comparePlans(optimized, expected)
+  }
+
+  test("do not exchange two adjacent windows when the child order spec is not a prefix") {
+    // (a, b DESC) equals (a, b, c) only at the first position, so it is not a prefix of it.
+    val orderSpecABDesc = Seq(a.asc, b.desc)
+    val orderSpecABC = Seq(a.asc, b.asc, c.asc)
+    val query = testRelation
+      .window(Seq(max(b).as('max_b)), partitionSpec1, orderSpecABDesc)
+      .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpecABC)
+
+    val expected = query.analyze
+    val optimized = Optimize.execute(query.analyze)
+
+    comparePlans(optimized, expected)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 78277d7dcf757..d6bebcb81a3ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.Matchers.the
 
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
+import org.apache.spark.sql.execution.SortExec
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -681,4 +682,21 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
         Row("S2", "P2", 300, 300, 500)))
 
   }
+
+  test("Eliminate unnecessary sort by exchanged adjacent Window expressions.") {
+    val df = Seq(
+      ("a", "p1", 10.0, 20.0, 30.0),
+      ("a", "p2", 20.0, 10.0, 40.0)).toDF("key", "value", "value1", "value2", "value3")
+    val df1 =
+      df.select(
+        $"key",
+        sum("value1").over(Window.partitionBy("key").orderBy("value")),
+        max("value2").over(Window.partitionBy("key").orderBy("value", "value1")),
+        avg("value3").over(Window.partitionBy("key").orderBy("value", "value1", "value2"))
+      )
+    val result = Seq(Row("a", 10.0, 20.0, 30.0), Row("a", 30.0, 20.0, 35.0))
+    checkAnswer(df1, result)
+    assert(
+      df1.queryExecution.executedPlan.collect { case _: SortExec => true }.size == 1)
+  }
 }