@@ -734,6 +734,28 @@ object CollapseWindow extends Rule[LogicalPlan] {
}
}

/**
 * Transpose Adjacent Window Expressions.
 * - If the partition spec of the parent Window expression is compatible with the partition spec
 *   of the child window expression, transpose them.
 */
object TransposeWindow extends Rule[LogicalPlan] {
  private def compatiblePartitions(ps1: Seq[Expression], ps2: Seq[Expression]): Boolean = {
    ps1.length < ps2.length && ps2.take(ps1.length).permutations.exists(ps1.zip(_).forall {
      case (l, r) => l.semanticEquals(r)
    })
  }
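Note the shape of this check: it is not subset containment. ps1 has to line up with some permutation of a prefix of ps2 of ps1's length. A standalone sketch of that behaviour (the real method is private, so the predicate is inlined here; the attribute names are illustrative):

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.expressions.Expression

  // Inlined copy of the predicate above, for illustration only.
  def compatible(ps1: Seq[Expression], ps2: Seq[Expression]): Boolean =
    ps1.length < ps2.length && ps2.take(ps1.length).permutations.exists(ps1.zip(_).forall {
      case (l, r) => l.semanticEquals(r)
    })

  val a = 'a.string
  val b = 'b.string
  val d = 'd.string

  compatible(Seq(a, b), Seq(b, a, d))  // true: take(2) = Seq(b, a), whose permutation Seq(a, b) matches
  compatible(Seq(a), Seq(b, a, d))     // false: take(1) = Seq(b), so a subset that is not a prefix fails
  compatible(Seq(a, b), Seq(a, b))     // false: ps1 must be strictly shorter than ps2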

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
        if w1.references.intersect(w2.windowOutputSet).isEmpty &&
           w1.expressions.forall(_.deterministic) &&
           w2.expressions.forall(_.deterministic) &&
           compatiblePartitions(ps1, ps2) =>
      Project(w1.output, Window(we2, ps2, os2, Window(we1, ps1, os1, grandChild)))
  }
}

Review thread on the doc comment:
Contributor: why is this rule useful?

Review thread on the case pattern:
@gatorsmile (Member, Jun 30, 2017): The expressions in both w1.expressions and w2.expressions must be deterministic. If not, we should not flip.
Author (Contributor): Why? This seems overly restrictive to me.
@gatorsmile (Member): Just to ensure the results are still the same with and without the rule.
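On the usefulness question: the payoff is an avoided shuffle. Hash-partitioning by the coarser spec already co-locates every group of the finer spec, so once the windows are transposed the finer window needs no new Exchange; in the original order the finer partitioning cannot satisfy the coarser window's distribution requirement and the plan shuffles twice. A minimal sketch, assuming a live SparkSession named spark (the query mirrors the new DataFrame test further down):

  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions.sum
  import spark.implicits._

  val df = Seq(
    ("S1", "P1", 100), ("S1", "P1", 700),
    ("S2", "P1", 200), ("S2", "P2", 300)
  ).toDF("sno", "pno", "qty")

  // Finer spec evaluated first, coarser second: hash-partitioning by
  // (sno, pno) does not co-locate all rows that share only sno, so
  // without TransposeWindow this plan needs two Exchange operators.
  val q = df
    .select($"sno", $"pno", $"qty",
      sum($"qty").over(Window.partitionBy("sno", "pno")).as("sum_qty_2"))
    .select($"sno", $"pno", $"qty", $"sum_qty_2",
      sum($"qty").over(Window.partitionBy("sno")).as("sum_qty_1"))

  // With the rule, the coarser window runs first; partitioning by sno
  // already groups every (sno, pno) combination, leaving one Exchange.
  q.explain()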

/**
* Generate a list of additional filters from an operator's existing constraint but remove those
* that are either already part of the operator's condition or are part of the operator's child
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class TransposeWindowSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveRedundantProject) ::
Batch("FlipWindow", Once, CollapseWindow, TransposeWindow) :: Nil
}

val testRelation = LocalRelation('a.string, 'b.string, 'c.int, 'd.string)

val a = testRelation.output(0)
val b = testRelation.output(1)
val c = testRelation.output(2)
val d = testRelation.output(3)

val partitionSpec1 = Seq(a)
val partitionSpec2 = Seq(a, b)
val partitionSpec3 = Seq(d)
val partitionSpec4 = Seq(b, a, d)

val orderSpec1 = Seq(d.asc)
val orderSpec2 = Seq(d.desc)

test("transpose two adjacent windows with compatible partitions") {
val query = testRelation
.window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2)
.window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1)

val analyzed = query.analyze
val optimized = Optimize.execute(analyzed)

val correctAnswer = testRelation
.window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1)
.window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2)
.select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1)

comparePlans(optimized, correctAnswer.analyze)
}

test("transpose two adjacent windows with differently ordered compatible partitions") {
val query = testRelation
.window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty)
.window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty)

val analyzed = query.analyze
val optimized = Optimize.execute(analyzed)

val correctAnswer = testRelation
.window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty)
.window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty)
.select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1)

comparePlans(optimized, correctAnswer.analyze)
}

test("don't transpose two adjacent windows with incompatible partitions") {
val query = testRelation
.window(Seq(sum(c).as('sum_a_2)), partitionSpec3, Seq.empty)
.window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty)

val analyzed = query.analyze
val optimized = Optimize.execute(analyzed)

comparePlans(optimized, analyzed)
}

test("don't transpose two adjacent windows with intersection of partition and output set") {
val query = testRelation
.window(Seq(('a + 'b).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty)
.window(Seq(sum(c).as('sum_a_1)), Seq(a, 'e), Seq.empty)

val analyzed = query.analyze
val optimized = Optimize.execute(analyzed)

comparePlans(optimized, analyzed)
}

test("don't transpose two adjacent windows with non-deterministic expressions") {
val query = testRelation
.window(Seq(Rand(0).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty)
.window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty)

val analyzed = query.analyze
val optimized = Optimize.execute(analyzed)

comparePlans(optimized, analyzed)
}

}
@@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
* Window function testing for DataFrame API.
*/
class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {

import testImplicits._

test("reuse window partitionBy") {
@@ -72,9 +73,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
cume_dist().over(Window.partitionBy("value").orderBy("key")),
percent_rank().over(Window.partitionBy("value").orderBy("key"))),
Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) ::
Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) ::
Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) ::
Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil)
}

test("window function should fail if order by clause is not specified") {
@@ -162,12 +163,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(
Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ),
Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ),
Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("c", 0.0, 0.0, 0.0, 0.0, 0.0),
Row("d", 0.0, 0.0, 0.0, 0.0, 0.0),
Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)))
}

@@ -326,7 +327,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
var_samp($"value").over(window),
approx_count_distinct($"value").over(window)),
Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2))
++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3)))
}

test("window function with aggregates") {
@@ -624,7 +625,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {

test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
def checkAnalysisError(df: => DataFrame): Unit = {
val thrownException = the[AnalysisException] thrownBy {
df.queryExecution.analyzed
}
assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses"))
@@ -658,4 +659,26 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
|GROUP BY a
|HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
}

test("window functions in multiple selects") {
val df = Seq(
("S1", "P1", 100),
("S1", "P1", 700),
("S2", "P1", 200),
("S2", "P2", 300)
).toDF("sno", "pno", "qty")

val w1 = Window.partitionBy("sno")
val w2 = Window.partitionBy("sno", "pno")

checkAnswer(
df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2"))
.select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")),
Seq(
Row("S1", "P1", 100, 800, 800),
Row("S1", "P1", 700, 800, 800),
Row("S2", "P1", 200, 200, 500),
Row("S2", "P2", 300, 300, 500)))

}
}