
Commit 78981ef

ptkool authored and gatorsmile committed
[SPARK-20636] Add new optimization rule to transpose adjacent Window expressions.
## What changes were proposed in this pull request?

Add a new optimization rule to eliminate unnecessary shuffling by flipping adjacent Window expressions.

## How was this patch tested?

Tested with unit tests, integration tests, and manual tests.

Closes #17899 from ptkool/adjacent_window_optimization.

Authored-by: ptkool <[email protected]>
Signed-off-by: gatorsmile <[email protected]>
1 parent 26f74b7 commit 78981ef
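To illustrate the kind of query the rule targets, here is a hedged DataFrame sketch (not part of the commit; `df`, `dept`, `job`, and `salary` are hypothetical names). Two adjacent window specs partition by nested column sets, so evaluating the coarser window first lets both windows share one exchange:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

// Hypothetical example: two adjacent windows whose partition specs nest.
val byDept       = Window.partitionBy("dept")          // coarser spec
val byDeptAndJob = Window.partitionBy("dept", "job")   // finer spec

// Without this rule, computing the finer window first forces a second
// shuffle: hash(dept, job) does not satisfy clustering by (dept) alone.
val result = df
  .withColumn("dept_job_total", sum("salary").over(byDeptAndJob))
  .withColumn("dept_total", sum("salary").over(byDept))

// With TransposeWindow, the (dept) window is evaluated first; rows hashed
// by dept are already clustered by (dept, job), so one exchange suffices.
```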

File tree

3 files changed: +170 −11 lines changed

- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala
- sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 22 additions & 0 deletions
```diff
@@ -734,6 +734,28 @@ object CollapseWindow extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Transpose Adjacent Window Expressions.
+ * - If the partition spec of the parent Window expression is compatible with the partition spec
+ *   of the child window expression, transpose them.
+ */
+object TransposeWindow extends Rule[LogicalPlan] {
+  private def compatibleParititions(ps1 : Seq[Expression], ps2: Seq[Expression]): Boolean = {
+    ps1.length < ps2.length && ps2.take(ps1.length).permutations.exists(ps1.zip(_).forall {
+      case (l, r) => l.semanticEquals(r)
+    })
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
+        if w1.references.intersect(w2.windowOutputSet).isEmpty &&
+           w1.expressions.forall(_.deterministic) &&
+           w2.expressions.forall(_.deterministic) &&
+           compatibleParititions(ps1, ps2) =>
+      Project(w1.output, Window(we2, ps2, os2, Window(we1, ps1, os1, grandChild)))
+  }
+}
+
 /**
  * Generate a list of additional filters from an operator's existing constraint but remove those
  * that are either already part of the operator's condition or are part of the operator's child
```
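A standalone sketch of the `compatibleParititions` check may help: with plain strings standing in for Catalyst expressions (so `semanticEquals` becomes `==`), the predicate holds exactly when the parent spec is shorter than the child spec and matches some permutation of the child spec's leading expressions:

```scala
// Hedged sketch: the same logic over strings instead of Expressions.
def compatible(ps1: Seq[String], ps2: Seq[String]): Boolean =
  ps1.length < ps2.length &&
    ps2.take(ps1.length).permutations.exists(ps1.zip(_).forall { case (l, r) => l == r })

compatible(Seq("a"), Seq("a", "b"))           // true: parent is a prefix of child
compatible(Seq("a", "b"), Seq("b", "a", "d")) // true: matches a reordering of the first two
compatible(Seq("d"), Seq("a"))                // false: equal lengths, unrelated specs
```

The strict `<` keeps the rule from firing on identical partition specs; that case is handled by CollapseWindow, which runs in the same batch.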
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@ (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class TransposeWindowSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveRedundantProject) ::
      Batch("FlipWindow", Once, CollapseWindow, TransposeWindow) :: Nil
  }

  val testRelation = LocalRelation('a.string, 'b.string, 'c.int, 'd.string)

  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val d = testRelation.output(3)

  val partitionSpec1 = Seq(a)
  val partitionSpec2 = Seq(a, b)
  val partitionSpec3 = Seq(d)
  val partitionSpec4 = Seq(b, a, d)

  val orderSpec1 = Seq(d.asc)
  val orderSpec2 = Seq(d.desc)

  test("transpose two adjacent windows with compatible partitions") {
    val query = testRelation
      .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2)
      .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)

    val correctAnswer = testRelation
      .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1)
      .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2)
      .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1)

    comparePlans(optimized, correctAnswer.analyze)
  }

  test("transpose two adjacent windows with differently ordered compatible partitions") {
    val query = testRelation
      .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty)
      .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)

    val correctAnswer = testRelation
      .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty)
      .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty)
      .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1)

    comparePlans(optimized, correctAnswer.analyze)
  }

  test("don't transpose two adjacent windows with incompatible partitions") {
    val query = testRelation
      .window(Seq(sum(c).as('sum_a_2)), partitionSpec3, Seq.empty)
      .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)

    comparePlans(optimized, analyzed)
  }

  test("don't transpose two adjacent windows with intersection of partition and output set") {
    val query = testRelation
      .window(Seq(('a + 'b).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty)
      .window(Seq(sum(c).as('sum_a_1)), Seq(a, 'e), Seq.empty)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)

    comparePlans(optimized, analyzed)
  }

  test("don't transpose two adjacent windows with non-deterministic expressions") {
    val query = testRelation
      .window(Seq(Rand(0).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty)
      .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)

    comparePlans(optimized, analyzed)
  }
}
```
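For intuition, here is roughly what the first test's rewrite looks like in plan terms (a sketch, not actual optimizer output): the coarser `[a]` window moves below the finer `[a, b]` window, and a `Project` restores the original column order.

```
Project [a, b, c, d, sum_a_2, sum_a_1]
+- Window [sum(c) AS sum_a_2] partitioned by [a, b] ordered by [d DESC]
   +- Window [sum(c) AS sum_a_1] partitioned by [a] ordered by [d ASC]
      +- LocalRelation [a, b, c, d]
```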

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala

Lines changed: 34 additions & 11 deletions
```diff
@@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
  * Window function testing for DataFrame API.
  */
 class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
+
   import testImplicits._
 
   test("reuse window partitionBy") {
@@ -72,9 +73,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
       cume_dist().over(Window.partitionBy("value").orderBy("key")),
       percent_rank().over(Window.partitionBy("value").orderBy("key"))),
     Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) ::
-    Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) ::
-    Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) ::
-    Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil)
+      Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) ::
+      Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) ::
+      Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil)
   }
 
   test("window function should fail if order by clause is not specified") {
@@ -162,12 +163,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
     Seq(
       Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
       Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
-      Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ),
-      Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ),
-      Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
-      Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
-      Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
-      Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
+      Row("c", 0.0, 0.0, 0.0, 0.0, 0.0),
+      Row("d", 0.0, 0.0, 0.0, 0.0, 0.0),
+      Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+      Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+      Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+      Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
       Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)))
   }
 
@@ -326,7 +327,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
       var_samp($"value").over(window),
       approx_count_distinct($"value").over(window)),
     Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2))
-    ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3)))
+      ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3)))
   }
 
   test("window function with aggregates") {
@@ -624,7 +625,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
 
   test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
     def checkAnalysisError(df: => DataFrame): Unit = {
-      val thrownException = the [AnalysisException] thrownBy {
+      val thrownException = the[AnalysisException] thrownBy {
         df.queryExecution.analyzed
       }
       assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses"))
@@ -658,4 +659,26 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
         |GROUP BY a
         |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
   }
+
+  test("window functions in multiple selects") {
+    val df = Seq(
+      ("S1", "P1", 100),
+      ("S1", "P1", 700),
+      ("S2", "P1", 200),
+      ("S2", "P2", 300)
+    ).toDF("sno", "pno", "qty")
+
+    val w1 = Window.partitionBy("sno")
+    val w2 = Window.partitionBy("sno", "pno")
+
+    checkAnswer(
+      df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2"))
+        .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")),
+      Seq(
+        Row("S1", "P1", 100, 800, 800),
+        Row("S1", "P1", 700, 800, 800),
+        Row("S2", "P1", 200, 200, 500),
+        Row("S2", "P2", 300, 300, 500)))
+
+  }
 }
```
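One way to watch the rule in action outside the test harness (a hedged sketch; assumes the `df`, `w1`, and `w2` from the test above and an active SparkSession):

```scala
// Print the plans for the two-select query. With TransposeWindow active,
// the optimized logical plan should show the ("sno", "pno") window above
// the ("sno") window, and the physical plan a single Exchange.
df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2"))
  .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1"))
  .explain(true)
```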
