From ee6d85dfdff072257cf92ccda5b6f0f8dbfe7d94 Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 29 Jul 2024 23:47:52 +0200 Subject: [PATCH 1/3] Fix `select count(distinct 1) from t` where t is empty table by expanding RewriteDistinctAggregates --- .../optimizer/RewriteDistinctAggregates.scala | 6 +- .../spark/sql/DataFrameAggregateSuite.scala | 286 ++++++++++++++++++ 2 files changed, 290 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index da3cf782f668..15379474a0a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -205,7 +205,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // clause for this rule because aggregation strategy can handle a single distinct aggregate // group without filter clause. // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). - distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined) + distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined) || + distinctAggs.exists(_.aggregateFunction.children.forall(_.foldable)) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( @@ -236,7 +237,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group without filter clause. - if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) { + if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined) || + distinctAggs.exists(_.aggregateFunction.children.forall(_.foldable))) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 620ee430cab2..7edb75a9284c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2339,6 +2339,292 @@ class DataFrameAggregateSuite extends QueryTest test("SPARK-32761: aggregating multiple distinct CONSTANT columns") { checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1, 1)) } + + test("aggregating single distinct column with empty and non-empty table") { + val tableName = "t" + withTable(tableName) { + // Original table now has 0 rows. + sql(s"create table $tableName(col int) using parquet") + + // Count function. + checkAnswer(sql(s"select count(1) from $tableName"), Row(0)) + checkAnswer(sql(s"select count(col) from $tableName"), Row(0)) + checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(0)) + checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(0)) + // Sum function. + checkAnswer(sql(s"select sum(1) from $tableName"), Row(null)) + checkAnswer(sql(s"select sum(col) from $tableName"), Row(null)) + checkAnswer(sql(s"select sum(distinct 1) from $tableName"), Row(null)) + checkAnswer(sql(s"select sum(distinct col) from $tableName"), Row(null)) + + // Original table now has 1 row. + sql(s"insert into $tableName(col) values(1)") + + // Count function. + checkAnswer(sql(s"select count(1) from $tableName"), Row(1)) + checkAnswer(sql(s"select count(col) from $tableName"), Row(1)) + checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(1)) + checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(1)) + // Sum function. + checkAnswer(sql(s"select sum(1) from $tableName"), Row(1)) + checkAnswer(sql(s"select sum(col) from $tableName"), Row(1)) + checkAnswer(sql(s"select sum(distinct 1) from $tableName"), Row(1)) + checkAnswer(sql(s"select sum(distinct col) from $tableName"), Row(1)) + + // Original table now has 2 rows. + sql(s"insert into $tableName(col) values(2)") + + // Count function. + checkAnswer(sql(s"select count(1) from $tableName"), Row(2)) + checkAnswer(sql(s"select count(col) from $tableName"), Row(2)) + checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(1)) + checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(2)) + // Sum function. + checkAnswer(sql(s"select sum(1) from $tableName"), Row(2)) + checkAnswer(sql(s"select sum(col) from $tableName"), Row(3)) + checkAnswer(sql(s"select sum(distinct 1) from $tableName"), Row(1)) + checkAnswer(sql(s"select sum(distinct col) from $tableName"), Row(3)) + + // Original table now has 3 rows. + sql(s"insert into $tableName(col) values(3)") + + // Count function. + checkAnswer(sql(s"select count(1) from $tableName"), Row(3)) + checkAnswer(sql(s"select count(col) from $tableName"), Row(3)) + checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(1)) + checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(3)) + // Sum function. + checkAnswer(sql(s"select sum(1) from $tableName"), Row(3)) + checkAnswer(sql(s"select sum(col) from $tableName"), Row(6)) + checkAnswer(sql(s"select sum(distinct 1) from $tableName"), Row(1)) + checkAnswer(sql(s"select sum(distinct col) from $tableName"), Row(6)) + } + } + + test("aggregating single distinct column with table") { + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName(col int) using parquet") + + // Column `col` does exist in the table. + checkAnswer(sql(s"select max(col) from $tableName"), Row(null)) + checkAnswer(sql(s"select count(col) from $tableName"), Row(0)) + checkAnswer(sql(s"select max(distinct col) from $tableName"), Row(null)) + checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(0)) + checkAnswer(sql(s"select max(`col`) from $tableName"), Row(null)) + checkAnswer(sql(s"select count(`col`) from $tableName"), Row(0)) + checkAnswer(sql(s"select max(distinct `col`) from $tableName"), Row(null)) + checkAnswer(sql(s"select count(distinct `col`) from $tableName"), Row(0)) + // But here 'col' is a string literal, not a column name. + checkAnswer(sql(s"select max('col') from $tableName"), Row(null)) + checkAnswer(sql(s"select count('col') from $tableName"), Row(0)) + checkAnswer(sql(s"select max(distinct 'col') from $tableName"), Row(null)) + checkAnswer(sql(s"select count(distinct 'col') from $tableName"), Row(0)) + checkAnswer(sql(s"""select max("col") from $tableName"""), Row(null)) + checkAnswer(sql(s"""select count("col") from $tableName"""), Row(0)) + checkAnswer(sql(s"""select max(distinct "col") from $tableName"""), Row(null)) + checkAnswer(sql(s"""select count(distinct "col") from $tableName"""), Row(0)) + // Works the same for any string literal. + checkAnswer(sql(s"select max('hello') from $tableName"), Row(null)) + checkAnswer(sql(s"select count('hello') from $tableName"), Row(0)) + checkAnswer(sql(s"select max(distinct 'hello') from $tableName"), Row(null)) + checkAnswer(sql(s"select count(distinct 'hello') from $tableName"), Row(0)) + checkAnswer(sql(s"""select max("hello") from $tableName"""), Row(null)) + checkAnswer(sql(s"""select count("hello") from $tableName"""), Row(0)) + checkAnswer(sql(s"""select max(distinct "hello") from $tableName"""), Row(null)) + checkAnswer(sql(s"""select count(distinct "hello") from $tableName"""), Row(0)) + // Or any other kind of literal for that matter. + checkAnswer(sql(s"select max(1) from $tableName"), Row(null)) + checkAnswer(sql(s"select count(1) from $tableName"), Row(0)) + checkAnswer(sql(s"""select max(distinct 1) from $tableName"""), Row(null)) + checkAnswer(sql(s"""select count(distinct 1) from $tableName"""), Row(0)) + } + } + + test("selecting single distinct column with table") { + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName(col int) using parquet") + + // Column `col` does exist in the table. + checkAnswer(sql(s"select col from $tableName"), Seq()) + checkAnswer(sql(s"select distinct col from $tableName"), Seq()) + checkAnswer(sql(s"select `col` from $tableName"), Seq()) + checkAnswer(sql(s"select distinct `col` from $tableName"), Seq()) + // But here 'col' is a string literal, not a column name. + checkAnswer(sql(s"select 'col' from $tableName"), Seq()) + checkAnswer(sql(s"select distinct 'col' from $tableName"), Seq()) + checkAnswer(sql(s"""select "col" from $tableName"""), Seq()) + checkAnswer(sql(s"""select distinct "col" from $tableName"""), Seq()) + // Works the same for any string literal. + checkAnswer(sql(s"select 'hello' from $tableName"), Seq()) + checkAnswer(sql(s"select distinct 'hello' from $tableName"), Seq()) + checkAnswer(sql(s"""select "hello" from $tableName"""), Seq()) + checkAnswer(sql(s"""select distinct "hello" from $tableName"""), Seq()) + // Or any other literal for that matter. + checkAnswer(sql(s"select 1 from $tableName"), Seq()) + checkAnswer(sql(s"select distinct 1 from $tableName"), Seq()) + checkAnswer(sql(s"""select 1 from $tableName"""), Seq()) + checkAnswer(sql(s"""select distinct 1 from $tableName"""), Seq()) + } + } + + test("selecting single distinct column with subquery") { + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName(col int) using parquet") + + val querySumDistinctCountDistinct = + s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $tableName)" + val querySumDistinctCount = + s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $tableName)" + val querySumCountDistinct = + s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $tableName)" + val querySumRefCountDistinct = + s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $tableName)" + + checkAnswer(sql(querySumDistinctCountDistinct), Row(1)) + checkAnswer(sql(querySumDistinctCount), Row(1)) + checkAnswer(sql(querySumCountDistinct), Row(1)) + checkAnswer(sql(querySumRefCountDistinct), Row(0)) + + sql(s"insert into $tableName(col) values(1)") + + checkAnswer(sql(querySumDistinctCountDistinct), Row(1)) + checkAnswer(sql(querySumDistinctCount), Row(1)) + checkAnswer(sql(querySumCountDistinct), Row(1)) + checkAnswer(sql(querySumRefCountDistinct), Row(1)) + + sql(s"insert into $tableName(col) values(2)") + + checkAnswer(sql(querySumDistinctCountDistinct), Row(1)) + checkAnswer(sql(querySumDistinctCount), Row(1)) + checkAnswer(sql(querySumCountDistinct), Row(1)) + checkAnswer(sql(querySumRefCountDistinct), Row(1)) + } + } + + test("selecting single distinct column with table and grouping expression") { + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName(col int) using parquet") + val query = s"SELECT COUNT(DISTINCT 1) FROM $tableName GROUP BY col" + checkAnswer(sql(query), Seq()) + sql(s"insert into $tableName(col) values(1)") + checkAnswer(sql(query), Row(1)) + sql(s"insert into $tableName(col) values(2)") + checkAnswer(sql(query), Seq(Row(1), Row(1))) + } + } + + test("selecting complex literal with table") { + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName(col int) using parquet") + + val columnValues = Seq(0, 1, 1, 2) + columnValues.foreach { + columnValue => { + if (columnValue != 0) sql(s"insert into $tableName(col) values($columnValue)") + val result = Row(if (columnValue == 0) 0 else 1) + + // Integer literals. + checkAnswer(sql(s"""select count(distinct 1 + 2) from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct 1, 2) from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct 1, 1 + 2) from $tableName"""), result) + + // String literals. + checkAnswer(sql(s"""select count(distinct "hello") from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct "col") from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct "") from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct collation("abc")) from $tableName"""), result) + checkAnswer(sql( + s""" + |select count(distinct collation("abc" collate utf8_lcase)) from + |$tableName""".stripMargin), result) + + // Other special cases. + checkAnswer(sql(s"""select count(distinct current_date()) from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct 1, "x", current_date()) from $tableName"""), + result) + + // Complex types. + checkAnswer(sql(s"""select count(distinct array(1, 2)) from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct map(1, 2)) from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct struct(1, 2)) from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct named_struct("a", 1)) from $tableName"""), + result) + + // Field extraction. + checkAnswer(sql(s"""select count(distinct array(1, 2)[1]) from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct map(1, 2)[1]) from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct struct(1, 2).col1) from $tableName"""), result) + checkAnswer(sql(s"""select count(distinct named_struct("a", 1).a) from $tableName"""), + result) + } + } + } + } + + test("selecting multiple distinct column with table") { + val tableName = "t" + withTable(tableName) { + // Original table now has 0 rows. + sql(s"create table $tableName(col int) using parquet") + + checkAnswer(sql(s"""select count(distinct 1), count(distinct 2) from $tableName"""), + Row(0, 0)) + checkAnswer(sql(s"""select count(distinct 1), count(distinct col) from $tableName"""), + Row(0, 0)) + checkAnswer(sql(s"""select count(distinct col), count(distinct col) from $tableName"""), + Row(0, 0)) + + // Original table now has 1 row. + sql(s"insert into $tableName(col) values(1)") + + checkAnswer(sql(s"""select count(distinct 2), count(distinct 1) from $tableName"""), + Row(1, 1)) + checkAnswer(sql(s"""select count(distinct col), count(distinct 1) from $tableName"""), + Row(1, 1)) + checkAnswer(sql(s"""select count(distinct col), count(distinct col) from $tableName"""), + Row(1, 1)) + + // Original table now has 2 rows. + sql(s"insert into $tableName(col) values(2)") + + checkAnswer(sql(s"""select count(distinct 1), count(distinct 2) from $tableName"""), + Row(1, 1)) + checkAnswer(sql(s"""select count(distinct 1), count(distinct col) from $tableName"""), + Row(1, 2)) + checkAnswer(sql(s"""select count(distinct col), count(distinct col) from $tableName"""), + Row(2, 2)) + + // Original table now has 3 rows. + sql(s"insert into $tableName(col) values(3)") + + checkAnswer(sql(s"""select count(distinct 2), count(distinct 1) from $tableName"""), + Row(1, 1)) + checkAnswer(sql(s"""select count(distinct col), count(distinct 1) from $tableName"""), + Row(3, 1)) + checkAnswer(sql(s"""select count(distinct col), count(distinct col) from $tableName"""), + Row(3, 3)) + } + } + + test("non distinct and distinct aggregate expressions") { + val tableName = "test" + withTable(tableName) { + sql(s"create table $tableName(col int) using parquet") + val query = s"select count(1), count(distinct 1) from $tableName" + checkAnswer(sql(query), Row(0, 0)) + sql(s"insert into $tableName values 1") + checkAnswer(sql(query), Row(1, 1)) + sql(s"insert into $tableName values 1") + checkAnswer(sql(query), Row(2, 1)) + sql(s"insert into $tableName values 2") + checkAnswer(sql(query), Row(3, 1)) + } + } } case class B(c: Option[Double]) From 51e2edd4f2810c09bb574b082c35c66adfa16a99 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 30 Jul 2024 19:52:58 +0200 Subject: [PATCH 2/3] Fixes --- .../optimizer/RewriteDistinctAggregates.scala | 15 +- .../spark/sql/DataFrameAggregateSuite.scala | 392 +++++------------- 2 files changed, 121 insertions(+), 286 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 15379474a0a5..99bfcc0c309b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -197,6 +197,15 @@ import org.apache.spark.util.collection.Utils * techniques. */ object RewriteDistinctAggregates extends Rule[LogicalPlan] { + private def getRewriteCondition( + aggregateExpressions: Seq[AggregateExpression], + groupingExpressions: Seq[Expression]): Boolean = { + // If there are any AggregateExpressions with filter, we need to rewrite the query. + // Also, if there are no grouping expressions and all aggregate expressions are foldable, + // we can rewrite the query, e.g. SELECT COUNT(DISTINCT 1). + aggregateExpressions.exists(_.filter.isDefined) || (groupingExpressions.isEmpty && + aggregateExpressions.exists(_.aggregateFunction.children.forall(_.foldable))) + } private def mayNeedtoRewrite(a: Aggregate): Boolean = { val aggExpressions = collectAggregateExprs(a) @@ -205,8 +214,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // clause for this rule because aggregation strategy can handle a single distinct aggregate // group without filter clause. // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). - distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined) || - distinctAggs.exists(_.aggregateFunction.children.forall(_.foldable)) + distinctAggs.size > 1 || getRewriteCondition(distinctAggs, a.groupingExpressions) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( @@ -237,8 +245,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group without filter clause. - if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined) || - distinctAggs.exists(_.aggregateFunction.children.forall(_.foldable))) { + if (distinctAggGroups.size > 1 || getRewriteCondition(distinctAggs, a.groupingExpressions)) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7edb75a9284c..90ac4f351ff4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} +import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -2340,289 +2341,116 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1, 1)) } - test("aggregating single distinct column with empty and non-empty table") { - val tableName = "t" - withTable(tableName) { - // Original table now has 0 rows. - sql(s"create table $tableName(col int) using parquet") - - // Count function. - checkAnswer(sql(s"select count(1) from $tableName"), Row(0)) - checkAnswer(sql(s"select count(col) from $tableName"), Row(0)) - checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(0)) - checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(0)) - // Sum function. - checkAnswer(sql(s"select sum(1) from $tableName"), Row(null)) - checkAnswer(sql(s"select sum(col) from $tableName"), Row(null)) - checkAnswer(sql(s"select sum(distinct 1) from $tableName"), Row(null)) - checkAnswer(sql(s"select sum(distinct col) from $tableName"), Row(null)) - - // Original table now has 1 row. - sql(s"insert into $tableName(col) values(1)") - - // Count function. - checkAnswer(sql(s"select count(1) from $tableName"), Row(1)) - checkAnswer(sql(s"select count(col) from $tableName"), Row(1)) - checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(1)) - checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(1)) - // Sum function. - checkAnswer(sql(s"select sum(1) from $tableName"), Row(1)) - checkAnswer(sql(s"select sum(col) from $tableName"), Row(1)) - checkAnswer(sql(s"select sum(distinct 1) from $tableName"), Row(1)) - checkAnswer(sql(s"select sum(distinct col) from $tableName"), Row(1)) - - // Original table now has 2 rows. - sql(s"insert into $tableName(col) values(2)") - - // Count function. - checkAnswer(sql(s"select count(1) from $tableName"), Row(2)) - checkAnswer(sql(s"select count(col) from $tableName"), Row(2)) - checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(1)) - checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(2)) - // Sum function. - checkAnswer(sql(s"select sum(1) from $tableName"), Row(2)) - checkAnswer(sql(s"select sum(col) from $tableName"), Row(3)) - checkAnswer(sql(s"select sum(distinct 1) from $tableName"), Row(1)) - checkAnswer(sql(s"select sum(distinct col) from $tableName"), Row(3)) - - // Original table now has 3 rows. - sql(s"insert into $tableName(col) values(3)") - - // Count function. - checkAnswer(sql(s"select count(1) from $tableName"), Row(3)) - checkAnswer(sql(s"select count(col) from $tableName"), Row(3)) - checkAnswer(sql(s"select count(distinct 1) from $tableName"), Row(1)) - checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(3)) - // Sum function. - checkAnswer(sql(s"select sum(1) from $tableName"), Row(3)) - checkAnswer(sql(s"select sum(col) from $tableName"), Row(6)) - checkAnswer(sql(s"select sum(distinct 1) from $tableName"), Row(1)) - checkAnswer(sql(s"select sum(distinct col) from $tableName"), Row(6)) - } - } - - test("aggregating single distinct column with table") { - val tableName = "t" - withTable(tableName) { - sql(s"create table $tableName(col int) using parquet") - - // Column `col` does exist in the table. - checkAnswer(sql(s"select max(col) from $tableName"), Row(null)) - checkAnswer(sql(s"select count(col) from $tableName"), Row(0)) - checkAnswer(sql(s"select max(distinct col) from $tableName"), Row(null)) - checkAnswer(sql(s"select count(distinct col) from $tableName"), Row(0)) - checkAnswer(sql(s"select max(`col`) from $tableName"), Row(null)) - checkAnswer(sql(s"select count(`col`) from $tableName"), Row(0)) - checkAnswer(sql(s"select max(distinct `col`) from $tableName"), Row(null)) - checkAnswer(sql(s"select count(distinct `col`) from $tableName"), Row(0)) - // But here 'col' is a string literal, not a column name. - checkAnswer(sql(s"select max('col') from $tableName"), Row(null)) - checkAnswer(sql(s"select count('col') from $tableName"), Row(0)) - checkAnswer(sql(s"select max(distinct 'col') from $tableName"), Row(null)) - checkAnswer(sql(s"select count(distinct 'col') from $tableName"), Row(0)) - checkAnswer(sql(s"""select max("col") from $tableName"""), Row(null)) - checkAnswer(sql(s"""select count("col") from $tableName"""), Row(0)) - checkAnswer(sql(s"""select max(distinct "col") from $tableName"""), Row(null)) - checkAnswer(sql(s"""select count(distinct "col") from $tableName"""), Row(0)) - // Works the same for any string literal. - checkAnswer(sql(s"select max('hello') from $tableName"), Row(null)) - checkAnswer(sql(s"select count('hello') from $tableName"), Row(0)) - checkAnswer(sql(s"select max(distinct 'hello') from $tableName"), Row(null)) - checkAnswer(sql(s"select count(distinct 'hello') from $tableName"), Row(0)) - checkAnswer(sql(s"""select max("hello") from $tableName"""), Row(null)) - checkAnswer(sql(s"""select count("hello") from $tableName"""), Row(0)) - checkAnswer(sql(s"""select max(distinct "hello") from $tableName"""), Row(null)) - checkAnswer(sql(s"""select count(distinct "hello") from $tableName"""), Row(0)) - // Or any other kind of literal for that matter. - checkAnswer(sql(s"select max(1) from $tableName"), Row(null)) - checkAnswer(sql(s"select count(1) from $tableName"), Row(0)) - checkAnswer(sql(s"""select max(distinct 1) from $tableName"""), Row(null)) - checkAnswer(sql(s"""select count(distinct 1) from $tableName"""), Row(0)) - } - } - - test("selecting single distinct column with table") { - val tableName = "t" - withTable(tableName) { - sql(s"create table $tableName(col int) using parquet") - - // Column `col` does exist in the table. - checkAnswer(sql(s"select col from $tableName"), Seq()) - checkAnswer(sql(s"select distinct col from $tableName"), Seq()) - checkAnswer(sql(s"select `col` from $tableName"), Seq()) - checkAnswer(sql(s"select distinct `col` from $tableName"), Seq()) - // But here 'col' is a string literal, not a column name. - checkAnswer(sql(s"select 'col' from $tableName"), Seq()) - checkAnswer(sql(s"select distinct 'col' from $tableName"), Seq()) - checkAnswer(sql(s"""select "col" from $tableName"""), Seq()) - checkAnswer(sql(s"""select distinct "col" from $tableName"""), Seq()) - // Works the same for any string literal. - checkAnswer(sql(s"select 'hello' from $tableName"), Seq()) - checkAnswer(sql(s"select distinct 'hello' from $tableName"), Seq()) - checkAnswer(sql(s"""select "hello" from $tableName"""), Seq()) - checkAnswer(sql(s"""select distinct "hello" from $tableName"""), Seq()) - // Or any other literal for that matter. - checkAnswer(sql(s"select 1 from $tableName"), Seq()) - checkAnswer(sql(s"select distinct 1 from $tableName"), Seq()) - checkAnswer(sql(s"""select 1 from $tableName"""), Seq()) - checkAnswer(sql(s"""select distinct 1 from $tableName"""), Seq()) - } - } - - test("selecting single distinct column with subquery") { - val tableName = "t" - withTable(tableName) { - sql(s"create table $tableName(col int) using parquet") - - val querySumDistinctCountDistinct = - s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $tableName)" - val querySumDistinctCount = - s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $tableName)" - val querySumCountDistinct = - s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $tableName)" - val querySumRefCountDistinct = - s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $tableName)" - - checkAnswer(sql(querySumDistinctCountDistinct), Row(1)) - checkAnswer(sql(querySumDistinctCount), Row(1)) - checkAnswer(sql(querySumCountDistinct), Row(1)) - checkAnswer(sql(querySumRefCountDistinct), Row(0)) - - sql(s"insert into $tableName(col) values(1)") - - checkAnswer(sql(querySumDistinctCountDistinct), Row(1)) - checkAnswer(sql(querySumDistinctCount), Row(1)) - checkAnswer(sql(querySumCountDistinct), Row(1)) - checkAnswer(sql(querySumRefCountDistinct), Row(1)) - - sql(s"insert into $tableName(col) values(2)") - - checkAnswer(sql(querySumDistinctCountDistinct), Row(1)) - checkAnswer(sql(querySumDistinctCount), Row(1)) - checkAnswer(sql(querySumCountDistinct), Row(1)) - checkAnswer(sql(querySumRefCountDistinct), Row(1)) - } - } - - test("selecting single distinct column with table and grouping expression") { - val tableName = "t" - withTable(tableName) { - sql(s"create table $tableName(col int) using parquet") - val query = s"SELECT COUNT(DISTINCT 1) FROM $tableName GROUP BY col" - checkAnswer(sql(query), Seq()) - sql(s"insert into $tableName(col) values(1)") - checkAnswer(sql(query), Row(1)) - sql(s"insert into $tableName(col) values(2)") - checkAnswer(sql(query), Seq(Row(1), Row(1))) - } - } - - test("selecting complex literal with table") { - val tableName = "t" - withTable(tableName) { - sql(s"create table $tableName(col int) using parquet") - - val columnValues = Seq(0, 1, 1, 2) - columnValues.foreach { - columnValue => { - if (columnValue != 0) sql(s"insert into $tableName(col) values($columnValue)") - val result = Row(if (columnValue == 0) 0 else 1) - - // Integer literals. - checkAnswer(sql(s"""select count(distinct 1 + 2) from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct 1, 2) from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct 1, 1 + 2) from $tableName"""), result) - - // String literals. - checkAnswer(sql(s"""select count(distinct "hello") from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct "col") from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct "") from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct collation("abc")) from $tableName"""), result) - checkAnswer(sql( - s""" - |select count(distinct collation("abc" collate utf8_lcase)) from - |$tableName""".stripMargin), result) - - // Other special cases. - checkAnswer(sql(s"""select count(distinct current_date()) from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct 1, "x", current_date()) from $tableName"""), - result) - - // Complex types. - checkAnswer(sql(s"""select count(distinct array(1, 2)) from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct map(1, 2)) from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct struct(1, 2)) from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct named_struct("a", 1)) from $tableName"""), - result) - - // Field extraction. - checkAnswer(sql(s"""select count(distinct array(1, 2)[1]) from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct map(1, 2)[1]) from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct struct(1, 2).col1) from $tableName"""), result) - checkAnswer(sql(s"""select count(distinct named_struct("a", 1).a) from $tableName"""), - result) + test("aggregating with various distinct expressions") { + abstract class AggregateTestCaseBase( + val query: String, + val resultSeq: Seq[Seq[Row]], + val hasExpandNodeInPlan: Boolean) + case class AggregateTestCase( + override val query: String, + override val resultSeq: Seq[Seq[Row]], + override val hasExpandNodeInPlan: Boolean) + extends AggregateTestCaseBase(query, resultSeq, hasExpandNodeInPlan) + case class AggregateTestCaseDefault( + override val query: String) + extends AggregateTestCaseBase( + query, + Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = true) + + val t = "t" + val testCases: Seq[AggregateTestCaseBase] = Seq( + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT "col") FROM $t""" + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1) FROM $t" + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1 + 2) FROM $t" + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1, 2, 1 + 2) FROM $t" + ), + AggregateTestCase( + s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t", + Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))), + hasExpandNodeInPlan = true + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT 1, "col") FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT collation("abc")) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT current_date()) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t""" + ), + AggregateTestCaseDefault( + s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t""" + ), + AggregateTestCase( + s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col", + Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCaseDefault( + s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1" + ), + AggregateTestCase( + s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0", + Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))), + hasExpandNodeInPlan = false + ), + AggregateTestCase( + s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)", + Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCase( + s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)", + Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCase( + s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)", + Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))), + hasExpandNodeInPlan = false + ), + AggregateTestCaseDefault( + s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"), + AggregateTestCase( + s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""", + Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))), + hasExpandNodeInPlan = true + ), + AggregateTestCase( + s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""", + Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 2))), + hasExpandNodeInPlan = true + ) + ) + withTable(t) { + sql(s"create table $t(col int) using parquet") + Seq(0, 1, 2).foreach(columnValue => { + if (columnValue != 0) { + sql(s"insert into $t(col) values($columnValue)") } - } - } - } - - test("selecting multiple distinct column with table") { - val tableName = "t" - withTable(tableName) { - // Original table now has 0 rows. - sql(s"create table $tableName(col int) using parquet") - - checkAnswer(sql(s"""select count(distinct 1), count(distinct 2) from $tableName"""), - Row(0, 0)) - checkAnswer(sql(s"""select count(distinct 1), count(distinct col) from $tableName"""), - Row(0, 0)) - checkAnswer(sql(s"""select count(distinct col), count(distinct col) from $tableName"""), - Row(0, 0)) - - // Original table now has 1 row. - sql(s"insert into $tableName(col) values(1)") - - checkAnswer(sql(s"""select count(distinct 2), count(distinct 1) from $tableName"""), - Row(1, 1)) - checkAnswer(sql(s"""select count(distinct col), count(distinct 1) from $tableName"""), - Row(1, 1)) - checkAnswer(sql(s"""select count(distinct col), count(distinct col) from $tableName"""), - Row(1, 1)) - - // Original table now has 2 rows. - sql(s"insert into $tableName(col) values(2)") - - checkAnswer(sql(s"""select count(distinct 1), count(distinct 2) from $tableName"""), - Row(1, 1)) - checkAnswer(sql(s"""select count(distinct 1), count(distinct col) from $tableName"""), - Row(1, 2)) - checkAnswer(sql(s"""select count(distinct col), count(distinct col) from $tableName"""), - Row(2, 2)) - - // Original table now has 3 rows. - sql(s"insert into $tableName(col) values(3)") - - checkAnswer(sql(s"""select count(distinct 2), count(distinct 1) from $tableName"""), - Row(1, 1)) - checkAnswer(sql(s"""select count(distinct col), count(distinct 1) from $tableName"""), - Row(3, 1)) - checkAnswer(sql(s"""select count(distinct col), count(distinct col) from $tableName"""), - Row(3, 3)) - } - } - - test("non distinct and distinct aggregate expressions") { - val tableName = "test" - withTable(tableName) { - sql(s"create table $tableName(col int) using parquet") - val query = s"select count(1), count(distinct 1) from $tableName" - checkAnswer(sql(query), Row(0, 0)) - sql(s"insert into $tableName values 1") - checkAnswer(sql(query), Row(1, 1)) - sql(s"insert into $tableName values 1") - checkAnswer(sql(query), Row(2, 1)) - sql(s"insert into $tableName values 2") - checkAnswer(sql(query), Row(3, 1)) + testCases.foreach(testCase => { + val query = sql(testCase.query) + checkAnswer(query, testCase.resultSeq(columnValue)) + val hasExpandNodeInPlan = query.queryExecution.optimizedPlan.collectFirst { + case _: Expand => true + }.nonEmpty + assert(hasExpandNodeInPlan == testCase.hasExpandNodeInPlan) + }) + }) } } } From 1b1f8f9c26c9f3bd5c648c7da5a57ba54b6d7bdd Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:01:56 +0200 Subject: [PATCH 3/3] Fixes --- .../catalyst/optimizer/RewriteDistinctAggregates.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 99bfcc0c309b..e91493188873 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -197,12 +197,12 @@ import org.apache.spark.util.collection.Utils * techniques. */ object RewriteDistinctAggregates extends Rule[LogicalPlan] { - private def getRewriteCondition( + private def mustRewrite( aggregateExpressions: Seq[AggregateExpression], groupingExpressions: Seq[Expression]): Boolean = { // If there are any AggregateExpressions with filter, we need to rewrite the query. // Also, if there are no grouping expressions and all aggregate expressions are foldable, - // we can rewrite the query, e.g. SELECT COUNT(DISTINCT 1). + // we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1). aggregateExpressions.exists(_.filter.isDefined) || (groupingExpressions.isEmpty && aggregateExpressions.exists(_.aggregateFunction.children.forall(_.foldable))) } @@ -214,7 +214,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // clause for this rule because aggregation strategy can handle a single distinct aggregate // group without filter clause. // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). - distinctAggs.size > 1 || getRewriteCondition(distinctAggs, a.groupingExpressions) + distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( @@ -245,7 +245,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group without filter clause. - if (distinctAggGroups.size > 1 || getRewriteCondition(distinctAggs, a.groupingExpressions)) { + if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect {