From 1dc222a7701b306448837318219c85ec638a23fa Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 24 Jun 2025 10:51:25 -0400 Subject: [PATCH] Consolidate subexpression elimination for whole stage and non-whole stage --- .../expressions/codegen/CodeGenerator.scala | 180 ++++++------------ .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GeneratePredicate.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 5 +- .../expressions/CodeGenerationSuite.scala | 18 -- .../SubexpressionEliminationSuite.scala | 11 +- .../aggregate/AggregateCodegenSupport.scala | 5 +- .../aggregate/HashAggregateExec.scala | 12 +- .../execution/basicPhysicalOperators.scala | 5 +- 9 files changed, 70 insertions(+), 176 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2564d4eab9bd..f2da04c9558d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -111,6 +111,7 @@ object SubExprEliminationState { */ case class SubExprCodes( states: Map[ExpressionEquals, SubExprEliminationState], + subExprCode: String, exprCodesNeedEvaluate: Seq[ExprCode]) /** @@ -411,29 +412,11 @@ class CodegenContext extends Logging { partitionInitializationStatements.mkString("\n") } - /** - * Holds expressions that are equivalent. Used to perform subexpression elimination - * during codegen. - * - * For expressions that appear more than once, generate additional code to prevent - * recomputing the value. - * - * For example, consider two expression generated from this SQL statement: - * SELECT (col1 + col2), (col1 + col2) / col3. - * - * equivalentExpressions will match the tree containing `col1 + col2` and it will only - * be evaluated once. - */ - private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - // Foreach expression that is participating in subexpression elimination, the state to use. // Visible for testing. private[expressions] var subExprEliminationExprs = Map.empty[ExpressionEquals, SubExprEliminationState] - // The collection of sub-expression result resetting methods that need to be called on each row. - private val subexprFunctions = mutable.ArrayBuffer.empty[String] - val outerClassName = "OuterClass" /** @@ -1064,15 +1047,6 @@ class CodegenContext extends Logging { } } - /** - * Returns the code for subexpression elimination after splitting it if necessary. - */ - def subexprFunctionsCode: String = { - // Whole-stage codegen's subexpression elimination is handled in another code path - assert(currentVars == null || subexprFunctions.isEmpty) - splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) - } - /** * Perform a function which generates a sequence of ExprCodes with a given mapping between * expressions and common expressions, instead of using the mapping in current context. @@ -1095,20 +1069,17 @@ class CodegenContext extends Logging { * evaluating a subexpression, this method will clean up the code block to avoid duplicate * evaluation. */ - def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = { - val code = new StringBuilder() - - subExprStates.foreach { state => - val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code - code.append(currentCode + "\n") + private def collectSubExprCodes( + subExprStates: Iterable[SubExprEliminationState]): Seq[String] = { + subExprStates.flatMap { state => + val codes = collectSubExprCodes(state.children) :+ state.eval.code.toString() state.eval.code = EmptyBlock - } - - code.toString() + codes + }.toSeq } /** - * Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen. + * Checks and sets up the state and codegen for subexpression elimination. * * This finds the common subexpressions, generates the code snippets that evaluate those * expressions and populates the mapping of common subexpressions to the generated code snippets. @@ -1120,10 +1091,6 @@ class CodegenContext extends Logging { * Besides, this also returns a sequences of `ExprCode` which are expression codes that need to * be evaluated (as their input parameters) before evaluating subexpressions. * - * To evaluate the returned subexpressions, please call `evaluateSubExprEliminationState` with - * the `SubExprEliminationState`s to be evaluated. During generating the code, it will cleanup - * the states to avoid duplicate evaluation. - * * The details of subexpression generation: * 1. Gets subexpression set. See `EquivalentExpressions`. * 2. Generate code of subexpressions as a whole block of code (non-split case) @@ -1141,10 +1108,10 @@ class CodegenContext extends Logging { * (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression * evaluation, we can look for generated subexpressions and do replacement. */ - def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { + def subexpressionElimination(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val localSubExprEliminationExprsForNonSplit = + val localSubExprEliminationExprs = mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. @@ -1157,8 +1124,28 @@ class CodegenContext extends Logging { val nonSplitCode = { val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState] commonExprs.map { expr => - withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { + withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { val eval = expr.genCode(this) + + val value = addMutableState(javaType(expr.dataType), "subExprValue") + + val isNullLiteral = eval.isNull match { + case TrueLiteral | FalseLiteral => true + case _ => false + } + val (isNull, isNullEvalCode) = if (!isNullLiteral) { + val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") + } else { + (eval.isNull, "") + } + + val code = code""" + |${eval.code} + |$isNullEvalCode + |$value = ${eval.value}; + """ + // Collects other subexpressions from the children. val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] expr.foreach { e => @@ -1167,8 +1154,10 @@ class CodegenContext extends Logging { case _ => } } - val state = SubExprEliminationState(eval, childrenSubExprs.toSeq) - localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state) + val state = SubExprEliminationState( + ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + childrenSubExprs.toSeq) + localSubExprEliminationExprs.put(ExpressionEquals(expr), state) allStates += state Seq(eval) } @@ -1188,38 +1177,18 @@ class CodegenContext extends Logging { val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold val (subExprsMap, exprCodes) = if (needSplit) { if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { - val localSubExprEliminationExprs = - mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] commonExprs.zipWithIndex.foreach { case (expr, i) => - val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { - Seq(expr.genCode(this)) - }.head - - val value = addMutableState(javaType(expr.dataType), "subExprValue") - - val isNullLiteral = eval.isNull match { - case TrueLiteral | FalseLiteral => true - case _ => false - } - val (isNull, isNullEvalCode) = if (!isNullLiteral) { - val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") - } else { - (eval.isNull, "") - } - // Generate the code for this expression tree and wrap it in a function. val fnName = freshName("subExpr") val inputVars = inputVarsForAllFuncs(i) val argList = inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") + val subExprState = localSubExprEliminationExprs.remove(ExpressionEquals(expr)).get val fn = s""" |private void $fnName(${argList.mkString(", ")}) { - | ${eval.code} - | $isNullEvalCode - | $value = ${eval.value}; + | ${subExprState.eval.code} |} """.stripMargin @@ -1235,7 +1204,7 @@ class CodegenContext extends Logging { val inputVariables = inputVars.map(_.variableName).mkString(", ") val code = code"${addNewFunction(fnName, fn)}($inputVariables);" val state = SubExprEliminationState( - ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + subExprState.eval.copy(code = code), childrenSubExprs.toSeq) localSubExprEliminationExprs.put(ExpressionEquals(expr), state) } @@ -1248,65 +1217,15 @@ class CodegenContext extends Logging { throw SparkException.internalError(errMsg) } else { logInfo(errMsg) - (localSubExprEliminationExprsForNonSplit, Seq.empty) + (localSubExprEliminationExprs, Seq.empty) } } } else { - (localSubExprEliminationExprsForNonSplit, Seq.empty) - } - SubExprCodes(subExprsMap.toMap, exprCodes.flatten) - } - - /** - * Checks and sets up the state and codegen for subexpression elimination. This finds the - * common subexpressions, generates the functions that evaluate those expressions and populates - * the mapping of common subexpressions to the generated functions. - */ - private def subexpressionElimination(expressions: Seq[Expression]): Unit = { - // Add each expression tree and compute the common subexpressions. - expressions.foreach(equivalentExpressions.addExprTree(_)) - - // Get all the expressions that appear at least twice and set up the state for subexpression - // elimination. - val commonExprs = equivalentExpressions.getCommonSubexpressions - commonExprs.foreach { expr => - val fnName = freshName("subExpr") - val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - val value = addMutableState(javaType(expr.dataType), "subExprValue") - - // Generate the code for this expression tree and wrap it in a function. - val eval = expr.genCode(this) - val fn = - s""" - |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code} - | $isNull = ${eval.isNull}; - | $value = ${eval.value}; - |} - """.stripMargin - - // Add a state and a mapping of the common subexpressions that are associate with this - // state. Adding this expression to subExprEliminationExprMap means it will call `fn` - // when it is code generated. This decision should be a cost based one. - // - // The cost of doing subexpression elimination is: - // 1. Extra function call, although this is probably *good* as the JIT can decide to - // inline or not. - // The benefit doing subexpression elimination is: - // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 - // above. - // 2. Less code. - // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with - // at least two nodes) as the cost of doing it is expected to be low. - - val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - subexprFunctions += subExprCode - val state = SubExprEliminationState( - ExprCode(code"$subExprCode", - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, expr.dataType))) - subExprEliminationExprs += ExpressionEquals(expr) -> state + (localSubExprEliminationExprs, Seq.empty) } + val subExprCode = splitExpressionsWithCurrentInputs( + collectSubExprCodes(subExprsMap.values), "subexprFunc_split") + SubExprCodes(subExprsMap.toMap, subExprCode, exprCodes.flatten) } /** @@ -1316,12 +1235,19 @@ class CodegenContext extends Logging { */ def generateExpressions( expressions: Seq[Expression], - doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { + doSubexpressionElimination: Boolean = false): (Seq[ExprCode], String) = { // We need to make sure that we do not reuse stateful expressions. This is needed for codegen // as well because some expressions may implement `CodegenFallback`. val cleanedExpressions = expressions.map(_.freshCopyIfContainsStatefulExpression()) - if (doSubexpressionElimination) subexpressionElimination(cleanedExpressions) - cleanedExpressions.map(e => e.genCode(this)) + if (doSubexpressionElimination) { + val subExprs = subexpressionElimination(cleanedExpressions) + val generatedExprs = withSubExprEliminationExprs(subExprs.states) { + cleanedExpressions.map(e => e.genCode(this)) + } + (generatedExprs, subExprs.subExprCode) + } else { + (cleanedExpressions.map(e => e.genCode(this)), "") + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 2e018de07101..6db00654ad1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -61,7 +61,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP case (NoOp, _) => false case _ => true } - val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) + val (exprVals, evalSubexpr) = + ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { @@ -91,9 +92,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP (code, update) } - // Evaluate all the subexpressions. - val evalSubexpr = ctx.subexprFunctionsCode - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index c246d07f189b..2383ffc0839e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,8 +38,8 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { val ctx = newCodeGenContext() // Do sub-expression elimination for predicates. - val eval = ctx.generateExpressions(Seq(predicate), useSubexprElimination).head - val evalSubexpr = ctx.subexprFunctionsCode + val (evalExprs, evalSubexpr) = ctx.generateExpressions(Seq(predicate), useSubexprElimination) + val eval = evalExprs.head val codeBody = s""" public SpecificPredicate generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 459c1d9a8ba1..d180215783db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -287,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { - val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) + val (exprEvals, evalSubexpr) = ctx.generateExpressions(expressions, useSubexprElimination) val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) val numVarLenFields = exprSchemas.count { @@ -299,9 +299,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") - // Evaluate all the subexpression. - val evalSubexpr = ctx.subexprFunctionsCode - val writeExpressions = writeExpressionsToBuffer( ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7ce14bcedf4b..4bbbc368010a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -491,24 +491,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ctx.subExprEliminationExprs.contains(wrap(ref))) assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) } - - // emulate an actual codegen workload - { - val ctx = new CodegenContext - // before - ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE - assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(wrap(ref) -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) - Seq.empty - } - // after - assert(ctx.subExprEliminationExprs.nonEmpty) - assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) - } } test("SPARK-23986: freshName can generate duplicated names") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index e9faeba2411c..a94173d8a1f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -278,11 +278,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel ExprCode(TrueLiteral, oneVar), ExprCode(TrueLiteral, twoVar)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) ctx.withSubExprEliminationExprs(subExprs.states) { exprs.map(_.genCode(ctx)) } - val subExprsCode = ctx.evaluateSubExprEliminationState(subExprs.states.values) val codeBody = s""" public java.lang.Object generate(Object[] references) { @@ -296,7 +295,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel } public void initialize(int partitionIndex) { - ${subExprsCode} + ${subExprs.subExprCode} } ${ctx.declareAddedFunctions()} @@ -408,16 +407,12 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val exprs = Seq(add1, add1, add2, add2) val ctx = new CodegenContext() - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) val add2State = subExprs.states(ExpressionEquals(add2)) val add1State = subExprs.states(ExpressionEquals(add1)) assert(add2State.children.contains(add1State)) - subExprs.states.values.foreach { state => - assert(state.eval.code != EmptyBlock) - } - ctx.evaluateSubExprEliminationState(subExprs.states.values) subExprs.states.values.foreach { state => assert(state.eval.code == EmptyBlock) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala index 40112979c6d4..3cadd319b2ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala @@ -210,8 +210,7 @@ trait AggregateCodegenSupport val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) + val subExprs = ctx.subexpressionElimination(boundUpdateExprs.flatten) val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExprsForOneFunc.map(_.genCode(ctx)) @@ -243,7 +242,7 @@ trait AggregateCodegenSupport s""" |// do aggregate |// common sub-expressions - |$effectiveCodes + |${subExprs.subExprCode} |// evaluate aggregate functions and update aggregation buffers |$codeToEvalAggFuncs """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 24528b6f4da1..f6d3bec4ee4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -629,7 +629,7 @@ case class HashAggregateExec( // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( ctx, bindReferences[Expression](groupingExpressions, child.output)) - val fastRowKeys = ctx.generateExpressions( + val (fastRowKeys, _) = ctx.generateExpressions( bindReferences[Expression](groupingExpressions, child.output)) val unsafeRowKeys = unsafeRowKeyCode.value val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash") @@ -732,8 +732,7 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) + val subExprs = ctx.subexpressionElimination(boundUpdateExprs.flatten) val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExprsForOneFunc.map(_.genCode(ctx)) @@ -765,7 +764,7 @@ case class HashAggregateExec( ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs) s""" |// common sub-expressions - |$effectiveCodes + |${subExprs.subExprCode} |// evaluate aggregate functions and update aggregation buffers |$codeToEvalAggFuncs """.stripMargin @@ -778,8 +777,7 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) + val subExprs = ctx.subexpressionElimination(boundUpdateExprs.flatten) val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExprsForOneFunc.map(_.genCode(ctx)) @@ -815,7 +813,7 @@ case class HashAggregateExec( s""" |if ($fastRowBuffer != null) { | // common sub-expressions - | $effectiveCodes + | ${subExprs.subExprCode} | // evaluate aggregate functions and update aggregation buffers | $codeToEvalAggFuncs |} else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 70ade390c733..1e7b922e4e4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -69,12 +69,11 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) val exprs = bindReferences[Expression](projectList, child.output) val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) { // subexpression elimination - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { exprs.map(_.genCode(ctx)) } - (ctx.evaluateSubExprEliminationState(subExprs.states.values), genVars, - subExprs.exprCodesNeedEvaluate) + (subExprs.subExprCode, genVars, subExprs.exprCodesNeedEvaluate) } else { ("", exprs.map(_.genCode(ctx)), Seq.empty) }