diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index b029a3b0ce917..f1801fa272683 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.planning
 
+import scala.collection.mutable
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
@@ -296,12 +298,17 @@
       // build a set of semantically distinct aggregate expressions and re-write expressions so
       // that they reference the single copy of the aggregate function which actually gets computed.
       // Non-deterministic aggregate expressions are not deduplicated.
-      val equivalentAggregateExpressions = new EquivalentExpressions
+      val equivalentAggregateExpressions = mutable.Map.empty[Expression, Expression]
       val aggregateExpressions = resultExpressions.flatMap { expr =>
         expr.collect {
-          // addExpr() always returns false for non-deterministic expressions and do not add them.
           case a
-            if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
+            if AggregateExpression.isAggregate(a) && (!a.deterministic ||
+              (if (equivalentAggregateExpressions.contains(a.canonicalized)) {
+                false
+              } else {
+                equivalentAggregateExpressions += a.canonicalized -> a
+                true
+              })) =>
             a
         }
       }
@@ -328,12 +335,12 @@
           case ae: AggregateExpression =>
             // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
             // so replace each aggregate expression by its corresponding attribute in the set:
-            equivalentAggregateExpressions.getExprState(ae).map(_.expr)
-              .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
+            equivalentAggregateExpressions.getOrElse(ae.canonicalized, ae)
+              .asInstanceOf[AggregateExpression].resultAttribute
           // Similar to AggregateExpression
           case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
-            equivalentAggregateExpressions.getExprState(ue).map(_.expr)
-              .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
+            equivalentAggregateExpressions.getOrElse(ue.canonicalized, ue)
+              .asInstanceOf[PythonUDF].resultAttribute
           case expression if !expression.foldable =>
             // Since we're using `namedGroupingAttributes` to extract the grouping key
             // columns, we need to replace grouping key expressions with their corresponding
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index b16629f59aa2d..f093f87c96965 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType}
 
 class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("Semantic equals and hash") {
@@ -449,6 +449,20 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     assert(e2.getCommonSubexpressions.size == 1)
     assert(e2.getCommonSubexpressions.head == add)
   }
+
+  test("SPARK-42851: Handle supportExpressions consistently across add and get") {
+    val tx = {
+      val arr = Literal(Array(1, 2))
+      val ArrayType(et, cn) = arr.dataType
+      val lv = NamedLambdaVariable("x", et, cn)
+      val lambda = LambdaFunction(lv, Seq(lv))
+      ArrayTransform(arr, lambda)
+    }
+    val equivalence = new EquivalentExpressions
+    val isNewExpr = !equivalence.addExpr(tx)
+    val cseState = equivalence.getExprState(tx)
+    assert(isNewExpr == cseState.isDefined)
+  }
 }
 
 case class CodegenFallbackExpression(child: Expression)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 737d31cc6e913..2ba9039166f48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1538,6 +1538,13 @@ class DataFrameAggregateSuite extends QueryTest
     )
     checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil)
   }
+
+  test("SPARK-42851: common subexpression should consistently handle aggregate and result exprs") {
+    val res = sql(
+      "select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)"
+    )
+    checkAnswer(res, Row(Array(1), Array(1)))
+  }
 }
 
 case class B(c: Option[Double])
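
Reviewer note: the PhysicalAggregation hunk above replaces EquivalentExpressions with a plain map keyed on the canonicalized expression, so the structure that decides whether an aggregate is a duplicate is the same one consulted when result expressions are rewritten to the surviving copy's resultAttribute. Below is a minimal, self-contained Scala sketch of that pattern; Expr, canonical, and dedupAggregates are illustrative stand-ins, not Spark Catalyst classes.

import scala.collection.mutable

object DedupSketch {
  // Toy expression: `id` distinguishes textual copies, `body` stands in for
  // the semantics-only form that Catalyst's `canonicalized` provides.
  final case class Expr(id: Int, body: String, deterministic: Boolean = true) {
    def canonical: String = body
  }

  // Keep the first occurrence of each semantically distinct deterministic
  // expression; never deduplicate non-deterministic ones. The returned map is
  // what the rewrite side uses to find the single kept copy.
  def dedupAggregates(exprs: Seq[Expr]): (Seq[Expr], Map[String, Expr]) = {
    val seen = mutable.Map.empty[String, Expr]
    val kept = exprs.filter { e =>
      !e.deterministic || {
        val isNew = !seen.contains(e.canonical)
        if (isNew) seen += e.canonical -> e
        isNew
      }
    }
    (kept, seen.toMap)
  }

  def main(args: Array[String]): Unit = {
    // Two semantically equal aggregates, as in the regression test where
    // max(transform(array(id), x -> x)) appears twice: only one is planned,
    // and both occurrences resolve to that single copy at rewrite time.
    val twice = Seq(Expr(1, "max(transform(...))"), Expr(2, "max(transform(...))"))
    val (kept, lookup) = dedupAggregates(twice)
    assert(kept.map(_.id) == Seq(1))
    assert(lookup(twice(1).canonical).id == 1)
  }
}

The property the patch restores is that the add side and the lookup side always agree: an aggregate may be dropped from the planned list only if the rewrite step can find its surviving twin under the same canonical key, which, per the tests added here, the previous mix of add and get paths did not guarantee for lambda-bearing expressions like transform().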