diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index da3cf782f668..e91493188873 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -197,6 +197,26 @@ import org.apache.spark.util.collection.Utils
   * techniques.
   */
 object RewriteDistinctAggregates extends Rule[LogicalPlan] {
+  /**
+   * Returns true if the given aggregate expressions force a rewrite regardless of how many
+   * distinct groups there are. Both call sites in this rule pass the distinct aggregate
+   * expressions of the [[Aggregate]] together with its grouping expressions.
+   *
+   * @param aggregateExpressions aggregate expressions to inspect
+   * @param groupingExpressions grouping expressions of the enclosing [[Aggregate]]
+   */
+  private def mustRewrite(
+      aggregateExpressions: Seq[AggregateExpression],
+      groupingExpressions: Seq[Expression]): Boolean = {
+    // Any aggregate expression with a FILTER clause requires the rewrite.
+    val hasFilter = aggregateExpressions.exists(_.filter.isDefined)
+    // A global aggregate (no grouping expressions) also requires the rewrite as soon as ONE
+    // aggregate expression has only foldable arguments, e.g. SELECT COUNT(DISTINCT 1); the
+    // remaining aggregate expressions do not have to be foldable.
+    val hasFoldableGlobalAggregate = groupingExpressions.isEmpty &&
+      aggregateExpressions.exists(_.aggregateFunction.children.forall(_.foldable))
+    hasFilter || hasFoldableGlobalAggregate
+  }
 
   private def mayNeedtoRewrite(a: Aggregate): Boolean = {
     val aggExpressions = collectAggregateExprs(a)
@@ -205,7 +225,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     // clause for this rule because aggregation strategy can handle a single distinct aggregate
     // group without filter clause.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
-    distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
+    distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
@@ -236,7 +256,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }
 
     // Aggregation strategy can handle queries with a single distinct group without filter clause.
-    if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
+    if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = AttributeReference("gid", IntegerType, nullable = false)()
       val groupByMap = a.groupingExpressions.collect {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 620ee430cab2..90ac4f351ff4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,6 +24,7 @@ import scala.util.Random
 import org.scalatest.matchers.must.Matchers.the
 
 import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.plans.logical.Expand
 import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -2339,6 +2340,126 @@ class DataFrameAggregateSuite extends QueryTest
   test("SPARK-32761: aggregating multiple distinct CONSTANT columns") {
     checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1, 1))
   }
+
+  test("aggregating with various distinct expressions") {
+    // A query, the rows it must return after 0, 1 and 2 rows have been inserted into the
+    // table (indexed by that row count), and whether its optimized plan contains an Expand.
+    abstract class AggregateTestCaseBase(
+        val query: String,
+        val resultSeq: Seq[Seq[Row]],
+        val hasExpandNodeInPlan: Boolean)
+    case class AggregateTestCase(
+        override val query: String,
+        override val resultSeq: Seq[Seq[Row]],
+        override val hasExpandNodeInPlan: Boolean)
+      extends AggregateTestCaseBase(query, resultSeq, hasExpandNodeInPlan)
+    // Shorthand for the most common expectation: counts 0, 1, 1 and an Expand in the plan.
+    case class AggregateTestCaseDefault(
+        override val query: String)
+      extends AggregateTestCaseBase(
+        query,
+        Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = true)
+
+    val t = "t"
+    val testCases: Seq[AggregateTestCaseBase] = Seq(
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT "col") FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1) FROM $t"
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1 + 2) FROM $t"
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1, 2, 1 + 2) FROM $t"
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))),
+        hasExpandNodeInPlan = true
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT 1, "col") FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT collation("abc")) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT current_date()) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t"""
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col",
+        Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1"
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0",
+        Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"
+      ),
+      AggregateTestCase(
+        s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))),
+        hasExpandNodeInPlan = true
+      ),
+      AggregateTestCase(
+        s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 2))),
+        hasExpandNodeInPlan = true
+      )
+    )
+    withTable(t) {
+      sql(s"create table $t(col int) using parquet")
+      // rowCount doubles as the value inserted in that round and as the index into resultSeq.
+      Seq(0, 1, 2).foreach { rowCount =>
+        if (rowCount != 0) {
+          sql(s"insert into $t(col) values($rowCount)")
+        }
+        testCases.foreach { testCase =>
+          val query = sql(testCase.query)
+          checkAnswer(query, testCase.resultSeq(rowCount))
+          val hasExpandNodeInPlan = query.queryExecution.optimizedPlan.collectFirst {
+            case _: Expand => true
+          }.isDefined
+          assert(
+            hasExpandNodeInPlan == testCase.hasExpandNodeInPlan,
+            s"Unexpected Expand presence for query '${testCase.query}' with $rowCount row(s)")
+        }
+      }
+    }
+  }
 }
 
 case class B(c: Option[Double])