File: org/apache/spark/sql/catalyst/planning/patterns.scala

```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.planning
 
+import scala.collection.mutable
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
```
```diff
@@ -296,12 +298,17 @@ object PhysicalAggregation {
       // build a set of semantically distinct aggregate expressions and re-write expressions so
       // that they reference the single copy of the aggregate function which actually gets computed.
       // Non-deterministic aggregate expressions are not deduplicated.
-      val equivalentAggregateExpressions = new EquivalentExpressions
+      val equivalentAggregateExpressions = mutable.Map.empty[Expression, Expression]
      val aggregateExpressions = resultExpressions.flatMap { expr =>
        expr.collect {
-          // addExpr() always returns false for non-deterministic expressions and do not add them.
          case a
-            if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
+            if AggregateExpression.isAggregate(a) && (!a.deterministic ||
+              (if (equivalentAggregateExpressions.contains(a.canonicalized)) {
+                false
+              } else {
+                equivalentAggregateExpressions += a.canonicalized -> a
+                true
+              })) =>
            a
        }
      }
```

Review thread on this change:

Contributor: What's wrong with addExpr here? It does simplify the code, IMO.

Contributor: The line of thought would be: adding the supportedExpression guard to addExpr() would cause a performance regression, so let's just close our eyes and make the only remaining use of addExpr break away and do its own deduplication in the old logic, without taking things like NamedLambdaVariable into account -- which is the way it's been for quite a few releases. This PR essentially inlines the addExpr path of the old EquivalentExpressions into PhysicalAggregation to recover what it used to do.

@peter-toth (author, Mar 21, 2023): Besides the above: although .addExpr() fits here well and does the job, isn't it a bit weird that an add-like method of a collection-like object doesn't return true when a new item was added, but actually flips the meaning of its return value? If it were used in multiple places I would keep it, but we use it only here. Maybe I'm just nitpicking... Anyway, I'm OK with #40473 too.
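To make the inlined logic concrete, here is a small standalone sketch (not Spark code: `Expr` and its `key` field are hypothetical stand-ins for catalyst expressions and their `canonicalized` form) of the first-wins, map-based deduplication the new guard performs:

```scala
import scala.collection.mutable

// Stand-in for a catalyst expression; `key` plays the role of `canonicalized`.
case class Expr(key: String, deterministic: Boolean = true)

object DedupSketch extends App {
  val seen = mutable.Map.empty[String, Expr]

  // Mirrors the new guard: non-deterministic expressions are always kept
  // (never deduplicated); a deterministic expression is kept only the first
  // time its canonical key is seen.
  def keep(e: Expr): Boolean =
    !e.deterministic || (if (seen.contains(e.key)) {
      false
    } else {
      seen += e.key -> e
      true
    })

  val exprs = Seq(
    Expr("max(transform(array(id), x -> x))"),
    Expr("max(transform(array(id), x -> x))"), // semantically equal duplicate
    Expr("rand()", deterministic = false))
  println(exprs.filter(keep)) // duplicate dropped, rand() kept
}
```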
```diff
@@ -328,12 +335,12 @@
          case ae: AggregateExpression =>
            // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
            // so replace each aggregate expression by its corresponding attribute in the set:
-            equivalentAggregateExpressions.getExprState(ae).map(_.expr)
-              .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
+            equivalentAggregateExpressions.getOrElse(ae.canonicalized, ae)
+              .asInstanceOf[AggregateExpression].resultAttribute
          // Similar to AggregateExpression
          case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
-            equivalentAggregateExpressions.getExprState(ue).map(_.expr)
-              .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
+            equivalentAggregateExpressions.getOrElse(ue.canonicalized, ue)
+              .asInstanceOf[PythonUDF].resultAttribute
          case expression if !expression.foldable =>
            // Since we're using `namedGroupingAttributes` to extract the grouping key
            // columns, we need to replace grouping key expressions with their corresponding
```
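The rewrite side is the mirror image: every occurrence of an aggregate resolves, via its canonical key, to the one copy that was kept, and falls back to itself when it was never added (e.g. a non-deterministic aggregate). Continuing the standalone sketch above:

```scala
// Resolve an occurrence to its deduplicated representative, or to itself
// if it was never tracked (e.g. a non-deterministic aggregate).
def resolve(e: Expr, seen: mutable.Map[String, Expr]): Expr =
  seen.getOrElse(e.key, e)
```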
File: org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala

```diff
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType}
 
 class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("Semantic equals and hash") {
```
```diff
@@ -449,6 +449,20 @@
     assert(e2.getCommonSubexpressions.size == 1)
     assert(e2.getCommonSubexpressions.head == add)
   }
+
+  test("SPARK-42851: Handle supportExpressions consistently across add and get") {
+    val tx = {
+      val arr = Literal(Array(1, 2))
+      val ArrayType(et, cn) = arr.dataType
+      val lv = NamedLambdaVariable("x", et, cn)
+      val lambda = LambdaFunction(lv, Seq(lv))
+      ArrayTransform(arr, lambda)
+    }
+    val equivalence = new EquivalentExpressions
+    val isNewExpr = !equivalence.addExpr(tx)
+    val cseState = equivalence.getExprState(tx)
+    assert(isNewExpr == cseState.isDefined)
+  }
 }
 
 case class CodegenFallbackExpression(child: Expression)
```
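The point of the lambda in this test: EquivalentExpressions declines to track some expressions (those containing lambda variables among them, per the review thread above), and the fix requires addExpr and getExprState to apply that check consistently. For contrast, a hedged companion check with a plainly supported expression (a hypothetical addition, assuming this suite's scope, where Add and Literal are already in use):

```scala
// Companion check (hypothetical, same invariant): for a supported expression
// both paths agree -- addExpr returns false on first insertion (no equivalent
// expression was present yet) and getExprState then finds the tracked state.
val simple = Add(Literal(1), Literal(2))
val eq = new EquivalentExpressions
assert(!eq.addExpr(simple)) // first add: nothing equivalent was present
assert(eq.getExprState(simple).isDefined) // and the get path sees it
```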
File: org/apache/spark/sql/DataFrameAggregateSuite.scala

```diff
@@ -1538,6 +1538,13 @@ class DataFrameAggregateSuite extends QueryTest
     )
     checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil)
   }
+
+  test("SPARK-42851: common subexpression should consistently handle aggregate and result exprs") {
+    val res = sql(
+      "select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)"
+    )
+    checkAnswer(res, Row(Array(1), Array(1)))
+  }
 }
 
 case class B(c: Option[Double])
```
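For reference, the same regression shape expressed through the DataFrame API rather than SQL (a hedged sketch, not part of the PR; assumes a SparkSession named `spark`, as in this suite):

```scala
import org.apache.spark.sql.functions._

// Two semantically identical aggregates over a lambda-bearing expression;
// before the fix, the add and lookup paths could disagree about such
// lambda-containing expressions during aggregate planning.
val res = spark.range(2).agg(
  max(transform(array(col("id")), x => x)),
  max(transform(array(col("id")), x => x)))
res.show(false) // expected single row: [[1], [1]]
```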