From d945e31cc69334aaedaa8abffecf0d07aa5b0ce8 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 13 Oct 2022 10:08:44 -0700 Subject: [PATCH 1/3] try to fix --- .../catalyst/optimizer/RewriteDistinctAggregates.scala | 7 ++++++- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 3a35c08d594a..d4fa058fdc6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -248,7 +248,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def patchAggregateFunctionChildren( af: AggregateFunction)( attrs: Expression => Option[Expression]): AggregateFunction = { - val newChildren = af.children.map(c => attrs(c).getOrElse(c)) + val newChildren = af.children.zipWithIndex.map { case (x, i) => + x match { + case l: Literal if i > 0 => l // some literal function arguments must stay literal + case c@_ => attrs (c).getOrElse (c) + } + } af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 58c8e3abaa44..ee1d7aa0903e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1527,6 +1527,16 @@ class DataFrameAggregateSuite extends QueryTest |""".stripMargin) checkAnswer(res3, Row(1, 7, 4.5, 1) :: Row(2, 7, 4.5, 2) :: Nil) } + + test("Reuse of literal in distinct aggregations should work") { + val res = sql( + """select a, count(distinct 100), count(distinct b, 100) + |from values (1, 2), (4, 5), (4, 6) as data(a, b) + |group by a; + |""".stripMargin + ) + checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil) + } } case class B(c: Option[Double]) From 581983e7610724a7b4da3316f1047fa8948b52c9 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 7 Nov 2022 18:28:33 -0800 Subject: [PATCH 2/3] Update --- .../catalyst/optimizer/RewriteDistinctAggregates.scala | 9 +++++---- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index d4fa058fdc6c..0e24cb096f35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -248,10 +248,11 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def patchAggregateFunctionChildren( af: AggregateFunction)( attrs: Expression => Option[Expression]): AggregateFunction = { - val newChildren = af.children.zipWithIndex.map { case (x, i) => - x match { - case l: Literal if i > 0 => l // some literal function arguments must stay literal - case c@_ => attrs (c).getOrElse (c) + val newChildren = af.children.zipWithIndex.map { case (c, i) => + if (c.foldable && i > 0) { + c + } else { + attrs(c).getOrElse(c) } } af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index ee1d7aa0903e..235c5b011e3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1528,7 +1528,7 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(res3, Row(1, 7, 4.5, 1) :: Row(2, 7, 4.5, 2) :: Nil) } - test("Reuse of literal in distinct aggregations should work") { + test("SPARK-41035: Reuse of literal in distinct aggregations should work") { val res = sql( """select a, count(distinct 100), count(distinct b, 100) |from values (1, 2), (4, 5), (4, 6) as data(a, b) From 9d4008d94f2f918206a13be4359d12c9e95bccca Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 8 Nov 2022 08:50:15 -0800 Subject: [PATCH 3/3] Update --- .../catalyst/optimizer/RewriteDistinctAggregates.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 0e24cb096f35..da3cf782f668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -248,13 +248,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def patchAggregateFunctionChildren( af: AggregateFunction)( attrs: Expression => Option[Expression]): AggregateFunction = { - val newChildren = af.children.zipWithIndex.map { case (c, i) => - if (c.foldable && i > 0) { - c - } else { - attrs(c).getOrElse(c) - } - } + val newChildren = af.children.map(c => attrs(c).getOrElse(c)) af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] } @@ -273,7 +267,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { }.unzip3 // Setup expand & aggregate operators for distinct aggregate expressions. - val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap + val distinctAggChildAttrLookup = distinctAggChildAttrMap.filter(!_._1.foldable).toMap val distinctAggFilterAttrLookup = Utils.toMap(distinctAggFilters, maxConds.map(_.toAttribute)) val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { case ((group, expressions), i) =>