From b1ce4b5fe12d32a867b281d48556f952932afd55 Mon Sep 17 00:00:00 2001 From: Linhong Liu Date: Tue, 8 Sep 2020 15:01:45 +0800 Subject: [PATCH 1/5] [SPARK-32816][SQL] Fix analyzer bug when aggregating multiple distinct DECIMAL columns --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 296fe86e834e..801ceaa629ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -142,7 +142,6 @@ abstract class Optimizer(catalogManager: CatalogManager) RewriteNonCorrelatedExists, ComputeCurrentTime, GetCurrentDatabaseAndCatalog(catalogManager), - RewriteDistinctAggregates, ReplaceDeduplicateWithAggregate) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here @@ -196,6 +195,8 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ + Batch("Distinct Aggregate Rewrite", Once, + RewriteDistinctAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, EliminateMapObjects, CombineTypedFilters, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d95f09a4cc83..edd397cd1860 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2555,6 +2555,19 @@ class DataFrameSuite extends QueryTest val df = Seq(0.0 -> -0.0).toDF("pos", "neg") checkAnswer(df.select($"pos" > $"neg"), Row(false)) } + + test("SPARK-32816: aggregating multiple distinct DECIMAL columns") { + withTempPath { path => + spark.range(0, 100, 1, 1) + .selectExpr("id", "cast(id as decimal(9, 0)) as decimal_col") + .write.mode("overwrite") + .parquet(path.getAbsolutePath) + spark.read.parquet(path.getAbsolutePath).createOrReplaceTempView("test_table") + checkAnswer( + sql("select avg(distinct decimal_col), sum(distinct decimal_col) from test_table"), + Row(49.5, 4950)) + } + } } case class GroupByKey(a: Int, b: Int) From 4df4f7c8ebc80bcf854fb26764338789ea13b319 Mon Sep 17 00:00:00 2001 From: Linhong Liu Date: Wed, 9 Sep 2020 10:37:08 +0800 Subject: [PATCH 2/5] fix comments --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 2 ++ .../org/apache/spark/sql/DataFrameSuite.scala | 16 ++++++---------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 801ceaa629ac..6870d9c66fab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -195,6 +195,8 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ + // This batch must run after "Decimal Optimizations", as that one may change the + // aggregate distinct column Batch("Distinct Aggregate Rewrite", Once, RewriteDistinctAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index edd397cd1860..a26e027857ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2557,16 +2557,12 @@ class DataFrameSuite extends QueryTest } test("SPARK-32816: aggregating multiple distinct DECIMAL columns") { - withTempPath { path => - spark.range(0, 100, 1, 1) - .selectExpr("id", "cast(id as decimal(9, 0)) as decimal_col") - .write.mode("overwrite") - .parquet(path.getAbsolutePath) - spark.read.parquet(path.getAbsolutePath).createOrReplaceTempView("test_table") - checkAnswer( - sql("select avg(distinct decimal_col), sum(distinct decimal_col) from test_table"), - Row(49.5, 4950)) - } + spark.range(0, 100, 1, 1) + .selectExpr("id", "cast(id as decimal(9, 0)) as decimal_col") + .createOrReplaceTempView("test_table") + checkAnswer( + sql("select avg(distinct decimal_col), sum(distinct decimal_col) from test_table"), + Row(49.5, 4950)) } } From 737f9965ca6e87e7c43a87e56da7f47a4e08e43e Mon Sep 17 00:00:00 2001 From: Linhong Liu Date: Thu, 10 Sep 2020 10:55:44 +0800 Subject: [PATCH 3/5] use temp view --- .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a26e027857ae..0c06bf3b551f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2557,12 +2557,14 @@ class DataFrameSuite extends QueryTest } test("SPARK-32816: aggregating multiple distinct DECIMAL columns") { - spark.range(0, 100, 1, 1) - .selectExpr("id", "cast(id as decimal(9, 0)) as decimal_col") - .createOrReplaceTempView("test_table") - checkAnswer( - sql("select avg(distinct decimal_col), sum(distinct decimal_col) from test_table"), - Row(49.5, 4950)) + withTempView("test_table") { + spark.range(0, 100, 1, 1) + .selectExpr("id", "cast(id as decimal(9, 0)) as decimal_col") + .createOrReplaceTempView("test_table") + checkAnswer( + sql("select avg(distinct decimal_col), sum(distinct decimal_col) from test_table"), + Row(49.5, 4950)) + } } } From f2111dfc4e22b11de82e6202e7f53992d492eef2 Mon Sep 17 00:00:00 2001 From: Linhong Liu Date: Thu, 10 Sep 2020 17:24:32 +0800 Subject: [PATCH 4/5] change to SQLQueryTestSuite --- .../test/resources/sql-tests/inputs/group-by.sql | 3 +++ .../resources/sql-tests/results/group-by.sql.out | 10 +++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 15 +++++++-------- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index fedf03d774e4..81e2204358bc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -166,3 +166,6 @@ SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L; SELECT count(*) FROM test_agg WHERE count(*) > 1L; SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L; SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1; + +-- Aggregate with multiple distinct decimal columns +SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 50eb2a9f22f6..5d9553f80405 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 56 +-- Number of queries: 57 -- !query @@ -573,3 +573,11 @@ org.apache.spark.sql.AnalysisException Aggregate/Window/Generate expressions are not valid in where clause of the query. Expression in where clause: [(((test_agg.`k` = 1) OR (test_agg.`k` = 2)) OR (((count(1) + 1L) > 1L) OR (max(test_agg.`k`) > 1)))] Invalid expressions: [count(1), max(test_agg.`k`)]; + + +-- !query +SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col) +-- !query schema +struct +-- !query output +1.0000 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0c06bf3b551f..71bcb6bdaa30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2557,14 +2557,13 @@ class DataFrameSuite extends QueryTest } test("SPARK-32816: aggregating multiple distinct DECIMAL columns") { - withTempView("test_table") { - spark.range(0, 100, 1, 1) - .selectExpr("id", "cast(id as decimal(9, 0)) as decimal_col") - .createOrReplaceTempView("test_table") - checkAnswer( - sql("select avg(distinct decimal_col), sum(distinct decimal_col) from test_table"), - Row(49.5, 4950)) - } + checkAnswer( + sql( + s""" + |SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) + | FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col) + """.stripMargin), + Row(1, 1)) } } From 8510ff92177554d9e87bd2e30e6fa9e43f89b8e3 Mon Sep 17 00:00:00 2001 From: Linhong Liu Date: Fri, 11 Sep 2020 08:37:17 +0800 Subject: [PATCH 5/5] code clean --- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 71bcb6bdaa30..d95f09a4cc83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2555,16 +2555,6 @@ class DataFrameSuite extends QueryTest val df = Seq(0.0 -> -0.0).toDF("pos", "neg") checkAnswer(df.select($"pos" > $"neg"), Row(false)) } - - test("SPARK-32816: aggregating multiple distinct DECIMAL columns") { - checkAnswer( - sql( - s""" - |SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) - | FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col) - """.stripMargin), - Row(1, 1)) - } } case class GroupByKey(a: Int, b: Int)