[SPARK-49000][SQL] Fix "select count(distinct 1) from t" where t is empty table by expanding RewriteDistinctAggregates #47525
Changed file: RewriteDistinctAggregates (Catalyst optimizer rule)
@@ -197,6 +197,15 @@ import org.apache.spark.util.collection.Utils
  * techniques.
  */
 object RewriteDistinctAggregates extends Rule[LogicalPlan] {
+  private def mustRewrite(
+      aggregateExpressions: Seq[AggregateExpression],
+      groupingExpressions: Seq[Expression]): Boolean = {
+    // If there are any AggregateExpressions with filter, we need to rewrite the query.
+    // Also, if there are no grouping expressions and all aggregate expressions are foldable,
+    // we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1).
+    aggregateExpressions.exists(_.filter.isDefined) || (groupingExpressions.isEmpty &&
+      aggregateExpressions.exists(_.aggregateFunction.children.forall(_.foldable)))
+  }

   private def mayNeedtoRewrite(a: Aggregate): Boolean = {
     val aggExpressions = collectAggregateExprs(a)

Member comment (on the new line `// If there are any AggregateExpressions with filter ...`): s/any/any distinct/

Member comment (on the new comment about foldable aggregate expressions): Compared to the comment in [...], to improve code readability it would be better to explain why the rewriting is needed for this case.
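For context, the scenario this new check targets can be reproduced outside the test suite. A minimal sketch, assuming a local SparkSession and a scratch Parquet table named `t` (none of this is part of the PR):

```scala
import org.apache.spark.sql.SparkSession

object CountDistinctConstantRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("repro").getOrCreate()
    try {
      spark.sql("CREATE TABLE t(col INT) USING parquet")
      // No grouping expressions, and the aggregate's only child is the foldable
      // literal 1, so mustRewrite now forces the Expand-based rewrite; on the
      // empty table the query should return 0, matching the new test's expectation.
      spark.sql("SELECT COUNT(DISTINCT 1) FROM t").show()
    } finally {
      spark.sql("DROP TABLE IF EXISTS t")
      spark.stop()
    }
  }
}
```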
@@ -205,7 +214,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     // clause for this rule because aggregation strategy can handle a single distinct aggregate
     // group without filter clause.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
-    distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
+    distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)
   }

   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(

Member comment (on lines 214 to 216, the false-positives note): It is better to update the comment.
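To make the false-positive note above concrete, here is a hedged sketch that can be pasted into spark-shell (assuming an active session `spark` and a hypothetical table `t` with an integer column `a`): `SUM(DISTINCT a)` and `COUNT(DISTINCT a)` are two distinct aggregates, so `mayNeedtoRewrite` fires, but they share a single distinct group, which the aggregation strategy is expected to handle without the Expand rewrite.

```scala
import org.apache.spark.sql.catalyst.plans.logical.Expand

// Two distinct aggregates over the same child expression form one distinct group.
val df = spark.sql("SELECT SUM(DISTINCT a), COUNT(DISTINCT a) FROM t")
val hasExpand = df.queryExecution.optimizedPlan.collectFirst {
  case _: Expand => true
}.nonEmpty
println(s"Expand in optimized plan: $hasExpand") // expected: false for this query
```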
@@ -236,7 +245,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan]
     }

     // Aggregation strategy can handle queries with a single distinct group without filter clause.
-    if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
+    if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = AttributeReference("gid", IntegerType, nullable = false)()
       val groupByMap = a.groupingExpressions.collect {
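Conversely, a quick way to see the rewrite kick in for the fixed query, again as a spark-shell sketch (reusing the `spark` session and empty table `t` assumed above; the printed plans vary by version):

```scala
import org.apache.spark.sql.catalyst.plans.logical.Expand

val q = spark.sql("SELECT COUNT(DISTINCT 1) FROM t")
q.explain(true) // prints parsed, analyzed, and optimized logical plans plus the physical plan
// With the fix, the optimized logical plan contains an Expand node introduced by
// RewriteDistinctAggregates (the rewrite that creates the gid grouping attribute).
val rewritten = q.queryExecution.optimizedPlan.collectFirst { case _: Expand => true }.nonEmpty
assert(rewritten, "expected the Expand-based distinct rewrite for COUNT(DISTINCT 1)")
```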
Changed file: DataFrameAggregateSuite (SQL aggregate tests)

@@ -24,6 +24,7 @@ import scala.util.Random
 import org.scalatest.matchers.must.Matchers.the

 import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.plans.logical.Expand
 import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -2339,6 +2340,119 @@ class DataFrameAggregateSuite extends QueryTest
   test("SPARK-32761: aggregating multiple distinct CONSTANT columns") {
     checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1, 1))
   }
+
+  test("aggregating with various distinct expressions") {
+    abstract class AggregateTestCaseBase(
+        val query: String,
+        val resultSeq: Seq[Seq[Row]],
+        val hasExpandNodeInPlan: Boolean)
+    case class AggregateTestCase(
+        override val query: String,
+        override val resultSeq: Seq[Seq[Row]],
+        override val hasExpandNodeInPlan: Boolean)
+      extends AggregateTestCaseBase(query, resultSeq, hasExpandNodeInPlan)
+    case class AggregateTestCaseDefault(
+        override val query: String)
+      extends AggregateTestCaseBase(
+        query,
+        Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = true)
+
+    val t = "t"
+    val testCases: Seq[AggregateTestCaseBase] = Seq(
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT "col") FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1) FROM $t"
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1 + 2) FROM $t"
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1, 2, 1 + 2) FROM $t"
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))),
+        hasExpandNodeInPlan = true
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT 1, "col") FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT collation("abc")) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT current_date()) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t"""
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col",
+        Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1"
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0",
+        Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"),
+      AggregateTestCase(
+        s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))),
+        hasExpandNodeInPlan = true
+      ),
+      AggregateTestCase(
+        s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 2))),
+        hasExpandNodeInPlan = true
+      )
+    )
+    withTable(t) {
+      sql(s"create table $t(col int) using parquet")
+      Seq(0, 1, 2).foreach(columnValue => {
+        if (columnValue != 0) {
+          sql(s"insert into $t(col) values($columnValue)")
+        }
+        testCases.foreach(testCase => {
+          val query = sql(testCase.query)
+          checkAnswer(query, testCase.resultSeq(columnValue))
+          val hasExpandNodeInPlan = query.queryExecution.optimizedPlan.collectFirst {
+            case _: Expand => true
+          }.nonEmpty
+          assert(hasExpandNodeInPlan == testCase.hasExpandNodeInPlan)
+        })
+      })
+    }
+  }

   case class B(c: Option[Double])

Contributor comment (on the `COUNT(DISTINCT collation("abc"))` test case): This test case cannot be merged into branch-3.5, as collation is a new function added in Spark 4.0.

Contributor comment (reply): Yes, collation doesn't exist in older versions, so this test case will need to be excluded; I can take care of that in a follow-up.
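The new test drives each query against the table in three states (empty, one row, two rows) and checks both the answer and whether an Expand node appears in the optimized plan. A standalone variant of that pattern for the headline query, as a sketch only (not part of the PR; assumes a local SparkSession):

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.Expand

object DistinctConstantCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("check").getOrCreate()
    // Expected answers for COUNT(DISTINCT 1): 0 on the empty table, then 1 after each insert.
    val expected = Seq(Row(0L), Row(1L), Row(1L))
    spark.sql("CREATE TABLE t(col INT) USING parquet")
    try {
      Seq(0, 1, 2).foreach { v =>
        if (v != 0) spark.sql(s"INSERT INTO t(col) VALUES($v)")
        val df = spark.sql("SELECT COUNT(DISTINCT 1) FROM t")
        assert(df.collect().toSeq == Seq(expected(v)))
        // Mirrors the suite's check that the Expand-based rewrite was applied.
        assert(df.queryExecution.optimizedPlan.collectFirst { case _: Expand => true }.nonEmpty)
      }
    } finally {
      spark.sql("DROP TABLE IF EXISTS t")
      spark.stop()
    }
  }
}
```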
Member comment (on the new `mustRewrite` parameter `aggregateExpressions`): s/aggregateExpressions/distinctAggs/