From 5adc74354821cb029cc1833aedd9f6df882e20ba Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 15 Apr 2016 10:56:39 -0700 Subject: [PATCH 1/4] [SPARK-14664][SQL] Fix DecimalAggregates optimizer not to break Window queries --- .../sql/catalyst/optimizer/Optimizer.scala | 37 +++++++--- .../optimizer/DecimalAggregatesSuite.scala | 70 +++++++++++++++++++ .../spark/sql/execution/WindowExec.scala | 2 + .../spark/sql/DataFrameAggregateSuite.scala | 11 ++- 4 files changed, 108 insertions(+), 12 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b26ceba228963..f9b03d0d8b700 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1343,17 +1343,32 @@ object DecimalAggregates extends Rule[LogicalPlan] { /** Maximum number of decimal digits representable precisely in a Double */ private val MAX_DOUBLE_DIGITS = 15 - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _) - if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - - case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _) - if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) - Cast( - Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4)) + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + case we @ WindowExpression(ae @ AggregateExpression(Sum( + e @ DecimalType.Expression(prec, scale)), _, _, _), _) if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), + prec + 10, scale) + + case we @ WindowExpression(ae @ AggregateExpression(Average( + e @ DecimalType.Expression(prec, scale)), _, _, _), _) if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = + we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e)))) + Cast( + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), + DecimalType(prec + 4, scale + 4)) + + case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _) + if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) + + case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _) + if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) + Cast( + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), + DecimalType(prec + 4, scale + 4)) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala new file mode 100644 index 0000000000000..970a50d2473d2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{MakeDecimal, UnscaledValue} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.DecimalType + +class DecimalAggregatesSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("Decimal Optimizations", FixedPoint(100), + DecimalAggregates) :: Nil + } + + val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1)) + + test("Decimal Sum Aggregation Optimize") { + val originalQuery = testRelation.select(sum('a)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(MakeDecimal(sum(UnscaledValue('a)), 12, 1).as("sum(a)")).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Decimal Sum Aggregation: Not Optimized") { + val originalQuery = testRelation.select(sum('b)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } + + test("Decimal Average Aggregation") { + val originalQuery = testRelation.select(avg('a)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select((avg(UnscaledValue('a)) / 10.0).cast(DecimalType(6, 5)).as("avg(a)")).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Decimal Average Aggregation: Not Optimized") { + val originalQuery = testRelation.select(avg('b)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala index 97bbab65af1de..0480d5229025a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala @@ -177,6 +177,8 @@ case class WindowExec( case e @ WindowExpression(function, spec) => val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] function match { + case MakeDecimal(AggregateExpression(f, _, _, _), prec, scale) => + collect("AGGREGATE", frame, e, f) case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 2f685c5f9cb51..9a3ccab848231 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.DecimalData -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{Decimal, DecimalType} case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) @@ -430,4 +430,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("kurtosis(a)")), Row(null, null, null, null, null)) } + + test("SPARK-14664: Decimal sum/avg over window should work.") { + checkAnswer( + sqlContext.sql("select sum(a) over () from (select explode(array(1.0,2.0,3.0)) a) t"), + Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) + checkAnswer( + sqlContext.sql("select avg(a) over () from (select explode(array(1.0,2.0,3.0)) a) t"), + Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) + } } From 03734406e5ba17886a9aaa2c08d32a73334ceb8b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 25 Apr 2016 15:42:13 -0700 Subject: [PATCH 2/4] Add more testcases --- .../optimizer/DecimalAggregatesSuite.scala | 58 ++++++++++++++++++- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index 970a50d2473d2..711294ed61928 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{MakeDecimal, UnscaledValue} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -34,7 +34,7 @@ class DecimalAggregatesSuite extends PlanTest { val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1)) - test("Decimal Sum Aggregation Optimize") { + test("Decimal Sum Aggregation: Optimized") { val originalQuery = testRelation.select(sum('a)) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -51,7 +51,7 @@ class DecimalAggregatesSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Decimal Average Aggregation") { + test("Decimal Average Aggregation: Optimized") { val originalQuery = testRelation.select(avg('a)) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -67,4 +67,56 @@ class DecimalAggregatesSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("Decimal Sum Aggregation over Window: Optimized") { + val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame) + val originalQuery = testRelation.select(windowExpr(sum('a), spec).as('sum_a)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select('a) + .window( + Seq(MakeDecimal(windowExpr(sum(UnscaledValue('a)), spec), 12, 1).as('sum_a)), + Seq('a), + Nil) + .select('a, 'sum_a, 'sum_a) + .select('sum_a) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Decimal Sum Aggregation over Window: Not Optimized") { + val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame) + val originalQuery = testRelation.select(windowExpr(sum('b), spec)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } + + test("Decimal Average Aggregation over Window: Optimized") { + val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame) + val originalQuery = testRelation.select(windowExpr(avg('a), spec).as('avg_a)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select('a) + .window( + Seq((windowExpr(avg(UnscaledValue('a)), spec) / 10.0).cast(DecimalType(6, 5)).as('avg_a)), + Seq('a), + Nil) + .select('a, 'avg_a, 'avg_a) + .select('avg_a) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Decimal Average Aggregation over Window: Not Optimized") { + val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame) + val originalQuery = testRelation.select(windowExpr(avg('b), spec)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } } From 152c6c2c854e2f0628c7f945ef5900f129080e0d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 27 Apr 2016 10:49:55 -0700 Subject: [PATCH 3/4] Address comments. --- .../sql/catalyst/optimizer/Optimizer.scala | 49 ++++++++++--------- .../spark/sql/execution/WindowExec.scala | 2 - 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f9b03d0d8b700..54bf4a52935de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1345,29 +1345,32 @@ object DecimalAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case we @ WindowExpression(ae @ AggregateExpression(Sum( - e @ DecimalType.Expression(prec, scale)), _, _, _), _) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), - prec + 10, scale) - - case we @ WindowExpression(ae @ AggregateExpression(Average( - e @ DecimalType.Expression(prec, scale)), _, _, _), _) if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr = - we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e)))) - Cast( - Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4)) - - case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _) - if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - - case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _) - if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) - Cast( - Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4)) + case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _), _) => af match { + case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), + prec + 10, scale) + + case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = + we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e)))) + Cast( + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), + DecimalType(prec + 4, scale + 4)) + + case _ => we + } + case ae @ AggregateExpression(af, _, _, _) => af match { + case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) + + case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) + Cast( + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), + DecimalType(prec + 4, scale + 4)) + + case _ => ae + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala index 0480d5229025a..97bbab65af1de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala @@ -177,8 +177,6 @@ case class WindowExec( case e @ WindowExpression(function, spec) => val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] function match { - case MakeDecimal(AggregateExpression(f, _, _, _), prec, scale) => - collect("AGGREGATE", frame, e, f) case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) From b2ac335fedac91b7db710f38e2801a9ea756748e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 27 Apr 2016 10:54:51 -0700 Subject: [PATCH 4/4] Simplify test sql statements. --- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9a3ccab848231..63f4b759a00ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -433,10 +433,10 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( - sqlContext.sql("select sum(a) over () from (select explode(array(1.0,2.0,3.0)) a) t"), + sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) checkAnswer( - sqlContext.sql("select avg(a) over () from (select explode(array(1.0,2.0,3.0)) a) t"), + sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } }