Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ import org.apache.spark.util.collection.Utils
* techniques.
*/
object RewriteDistinctAggregates extends Rule[LogicalPlan] {
private def mustRewrite(
aggregateExpressions: Seq[AggregateExpression],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/aggregateExpressions/distinctAggs/

groupingExpressions: Seq[Expression]): Boolean = {
// If there are any AggregateExpressions with filter, we need to rewrite the query.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/any/any distinct/

// Also, if there are no grouping expressions and all aggregate expressions are foldable,
// we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1).
Copy link
Member

@viirya viirya Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compared to the comment in mayNeedtoRewrite which explains why rewriting is necessary. This comment doesn't do any explanation but just claims it needs to rewrite the query. This comment simply describes what the code does and it is obvious.

To better improve the code readability, it would be better to explain why the rewriting is needed for the case.

aggregateExpressions.exists(_.filter.isDefined) || (groupingExpressions.isEmpty &&
aggregateExpressions.exists(_.aggregateFunction.children.forall(_.foldable)))
}

private def mayNeedtoRewrite(a: Aggregate): Boolean = {
val aggExpressions = collectAggregateExprs(a)
Expand All @@ -205,7 +214,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// clause for this rule because aggregation strategy can handle a single distinct aggregate
// group without filter clause.
// This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
Comment on lines 214 to 216
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to update the comment.

distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
Expand Down Expand Up @@ -236,7 +245,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}

// Aggregation strategy can handle queries with a single distinct group without filter clause.
if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) {
// Create the attributes for the grouping id and the group by clause.
val gid = AttributeReference("gid", IntegerType, nullable = false)()
val groupByMap = a.groupingExpressions.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import scala.util.Random
import org.scalatest.matchers.must.Matchers.the

import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
Expand Down Expand Up @@ -2339,6 +2340,119 @@ class DataFrameAggregateSuite extends QueryTest
test("SPARK-32761: aggregating multiple distinct CONSTANT columns") {
checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1, 1))
}

test("aggregating with various distinct expressions") {
abstract class AggregateTestCaseBase(
val query: String,
val resultSeq: Seq[Seq[Row]],
val hasExpandNodeInPlan: Boolean)
case class AggregateTestCase(
override val query: String,
override val resultSeq: Seq[Seq[Row]],
override val hasExpandNodeInPlan: Boolean)
extends AggregateTestCaseBase(query, resultSeq, hasExpandNodeInPlan)
case class AggregateTestCaseDefault(
override val query: String)
extends AggregateTestCaseBase(
query,
Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1))),
hasExpandNodeInPlan = true)

val t = "t"
val testCases: Seq[AggregateTestCaseBase] = Seq(
AggregateTestCaseDefault(
s"""SELECT COUNT(DISTINCT "col") FROM $t"""
),
AggregateTestCaseDefault(
s"SELECT COUNT(DISTINCT 1) FROM $t"
),
AggregateTestCaseDefault(
s"SELECT COUNT(DISTINCT 1 + 2) FROM $t"
),
AggregateTestCaseDefault(
s"SELECT COUNT(DISTINCT 1, 2, 1 + 2) FROM $t"
),
AggregateTestCase(
s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t",
Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))),
hasExpandNodeInPlan = true
),
AggregateTestCaseDefault(
s"""SELECT COUNT(DISTINCT 1, "col") FROM $t"""
),
AggregateTestCaseDefault(
s"""SELECT COUNT(DISTINCT collation("abc")) FROM $t"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case cannot be merged into branch-3.5, as collation is a new function added in Spark 4.0.
cc @nikolamand-db @cloud-fan
also cc @yaooqinn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, collation doesn't exist in older version - so this test will need to be excluded

I can take care of that in a follow-up

),
AggregateTestCaseDefault(
s"""SELECT COUNT(DISTINCT current_date()) FROM $t"""
),
AggregateTestCaseDefault(
s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t"""
),
AggregateTestCaseDefault(
s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t"""
),
AggregateTestCaseDefault(
s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t"""
),
AggregateTestCase(
s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col",
Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))),
hasExpandNodeInPlan = false
),
AggregateTestCaseDefault(
s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1"
),
AggregateTestCase(
s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0",
Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))),
hasExpandNodeInPlan = false
),
AggregateTestCase(
s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
hasExpandNodeInPlan = false
),
AggregateTestCase(
s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)",
Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
hasExpandNodeInPlan = false
),
AggregateTestCase(
s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
hasExpandNodeInPlan = false
),
AggregateTestCaseDefault(
s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"),
AggregateTestCase(
s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""",
Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))),
hasExpandNodeInPlan = true
),
AggregateTestCase(
s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""",
Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 2))),
hasExpandNodeInPlan = true
)
)
withTable(t) {
sql(s"create table $t(col int) using parquet")
Seq(0, 1, 2).foreach(columnValue => {
if (columnValue != 0) {
sql(s"insert into $t(col) values($columnValue)")
}
testCases.foreach(testCase => {
val query = sql(testCase.query)
checkAnswer(query, testCase.resultSeq(columnValue))
val hasExpandNodeInPlan = query.queryExecution.optimizedPlan.collectFirst {
case _: Expand => true
}.nonEmpty
assert(hasExpandNodeInPlan == testCase.hasExpandNodeInPlan)
})
})
}
}
}

case class B(c: Option[Double])
Expand Down