diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 296fe86e834e..6870d9c66fab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -142,7 +142,6 @@ abstract class Optimizer(catalogManager: CatalogManager) RewriteNonCorrelatedExists, ComputeCurrentTime, GetCurrentDatabaseAndCatalog(catalogManager), - RewriteDistinctAggregates, ReplaceDeduplicateWithAggregate) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here @@ -196,6 +195,10 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ + // This batch must run after "Decimal Optimizations", as that one may change the + // aggregate distinct column + Batch("Distinct Aggregate Rewrite", Once, + RewriteDistinctAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, EliminateMapObjects, CombineTypedFilters, diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index fedf03d774e4..81e2204358bc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -166,3 +166,6 @@ SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L; SELECT count(*) FROM test_agg WHERE count(*) > 1L; SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L; SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1; + +-- Aggregate with multiple distinct decimal columns +SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 50eb2a9f22f6..5d9553f80405 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 56 +-- Number of queries: 57 -- !query @@ -573,3 +573,11 @@ org.apache.spark.sql.AnalysisException Aggregate/Window/Generate expressions are not valid in where clause of the query. Expression in where clause: [(((test_agg.`k` = 1) OR (test_agg.`k` = 2)) OR (((count(1) + 1L) > 1L) OR (max(test_agg.`k`) > 1)))] Invalid expressions: [count(1), max(test_agg.`k`)]; + + +-- !query +SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col) +-- !query schema +struct +-- !query output +1.0000 1