From 6ab256f858fab4979166a561cda35d790e3283a3 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Wed, 19 Mar 2025 14:31:06 +0000 Subject: [PATCH 1/8] Consolidate subexpression elimination for whole stage and non-whole stage --- .../expressions/codegen/CodeGenerator.scala | 174 ++++++------------ .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GeneratePredicate.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 5 +- .../expressions/CodeGenerationSuite.scala | 18 -- .../SubexpressionEliminationSuite.scala | 4 +- .../aggregate/AggregateCodegenSupport.scala | 2 +- .../aggregate/HashAggregateExec.scala | 6 +- .../execution/basicPhysicalOperators.scala | 2 +- 9 files changed, 65 insertions(+), 156 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2564d4eab9bd..173fc1857cb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -411,29 +411,11 @@ class CodegenContext extends Logging { partitionInitializationStatements.mkString("\n") } - /** - * Holds expressions that are equivalent. Used to perform subexpression elimination - * during codegen. - * - * For expressions that appear more than once, generate additional code to prevent - * recomputing the value. - * - * For example, consider two expression generated from this SQL statement: - * SELECT (col1 + col2), (col1 + col2) / col3. - * - * equivalentExpressions will match the tree containing `col1 + col2` and it will only - * be evaluated once. - */ - private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - // Foreach expression that is participating in subexpression elimination, the state to use. // Visible for testing. 
private[expressions] var subExprEliminationExprs = Map.empty[ExpressionEquals, SubExprEliminationState] - // The collection of sub-expression result resetting methods that need to be called on each row. - private val subexprFunctions = mutable.ArrayBuffer.empty[String] - val outerClassName = "OuterClass" /** @@ -1064,15 +1046,6 @@ class CodegenContext extends Logging { } } - /** - * Returns the code for subexpression elimination after splitting it if necessary. - */ - def subexprFunctionsCode: String = { - // Whole-stage codegen's subexpression elimination is handled in another code path - assert(currentVars == null || subexprFunctions.isEmpty) - splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) - } - /** * Perform a function which generates a sequence of ExprCodes with a given mapping between * expressions and common expressions, instead of using the mapping in current context. @@ -1090,25 +1063,26 @@ class CodegenContext extends Logging { genCodes } + private def collectSubExprCodes(subExprStates: Seq[SubExprEliminationState]): Seq[String] = { + subExprStates.flatMap { state => + val codes = collectSubExprCodes(state.children) :+ state.eval.code.toString() + state.eval.code = EmptyBlock + codes + } + } + /** * Evaluates a sequence of `SubExprEliminationState` which represent subexpressions. After * evaluating a subexpression, this method will clean up the code block to avoid duplicate * evaluation. 
*/ def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = { - val code = new StringBuilder() - - subExprStates.foreach { state => - val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code - code.append(currentCode + "\n") - state.eval.code = EmptyBlock - } - - code.toString() + val codes = collectSubExprCodes(subExprStates.toSeq) + splitExpressionsWithCurrentInputs(codes, "subexprFunc_split") } /** - * Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen. + * Checks and sets up the state and codegen for subexpression elimination. * * This finds the common subexpressions, generates the code snippets that evaluate those * expressions and populates the mapping of common subexpressions to the generated code snippets. @@ -1141,10 +1115,10 @@ class CodegenContext extends Logging { * (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression * evaluation, we can look for generated subexpressions and do replacement. */ - def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { + def subexpressionElimination(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val localSubExprEliminationExprsForNonSplit = + val localSubExprEliminationExprs = mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. 
@@ -1157,8 +1131,28 @@ class CodegenContext extends Logging { val nonSplitCode = { val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState] commonExprs.map { expr => - withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { + withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { val eval = expr.genCode(this) + + val value = addMutableState(javaType(expr.dataType), "subExprValue") + + val isNullLiteral = eval.isNull match { + case TrueLiteral | FalseLiteral => true + case _ => false + } + val (isNull, isNullEvalCode) = if (!isNullLiteral) { + val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") + } else { + (eval.isNull, "") + } + + val code = code""" + |${eval.code} + |$isNullEvalCode + |$value = ${eval.value}; + """ + // Collects other subexpressions from the children. val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] expr.foreach { e => @@ -1167,8 +1161,10 @@ class CodegenContext extends Logging { case _ => } } - val state = SubExprEliminationState(eval, childrenSubExprs.toSeq) - localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state) + val state = SubExprEliminationState( + ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + childrenSubExprs.toSeq) + localSubExprEliminationExprs.put(ExpressionEquals(expr), state) allStates += state Seq(eval) } @@ -1188,38 +1184,18 @@ class CodegenContext extends Logging { val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold val (subExprsMap, exprCodes) = if (needSplit) { if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { - val localSubExprEliminationExprs = - mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] commonExprs.zipWithIndex.foreach { case (expr, i) => - val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { - Seq(expr.genCode(this)) - }.head - - val 
value = addMutableState(javaType(expr.dataType), "subExprValue") - - val isNullLiteral = eval.isNull match { - case TrueLiteral | FalseLiteral => true - case _ => false - } - val (isNull, isNullEvalCode) = if (!isNullLiteral) { - val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") - } else { - (eval.isNull, "") - } - // Generate the code for this expression tree and wrap it in a function. val fnName = freshName("subExpr") val inputVars = inputVarsForAllFuncs(i) val argList = inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") + val subExprState = localSubExprEliminationExprs.remove(ExpressionEquals(expr)).get val fn = s""" |private void $fnName(${argList.mkString(", ")}) { - | ${eval.code} - | $isNullEvalCode - | $value = ${eval.value}; + | ${subExprState.eval.code} |} """.stripMargin @@ -1235,7 +1211,7 @@ class CodegenContext extends Logging { val inputVariables = inputVars.map(_.variableName).mkString(", ") val code = code"${addNewFunction(fnName, fn)}($inputVariables);" val state = SubExprEliminationState( - ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + subExprState.eval.copy(code = code), childrenSubExprs.toSeq) localSubExprEliminationExprs.put(ExpressionEquals(expr), state) } @@ -1248,67 +1224,15 @@ class CodegenContext extends Logging { throw SparkException.internalError(errMsg) } else { logInfo(errMsg) - (localSubExprEliminationExprsForNonSplit, Seq.empty) + (localSubExprEliminationExprs, Seq.empty) } } } else { - (localSubExprEliminationExprsForNonSplit, Seq.empty) + (localSubExprEliminationExprs, Seq.empty) } SubExprCodes(subExprsMap.toMap, exprCodes.flatten) } - /** - * Checks and sets up the state and codegen for subexpression elimination. This finds the - * common subexpressions, generates the functions that evaluate those expressions and populates - * the mapping of common subexpressions to the generated functions. 
- */ - private def subexpressionElimination(expressions: Seq[Expression]): Unit = { - // Add each expression tree and compute the common subexpressions. - expressions.foreach(equivalentExpressions.addExprTree(_)) - - // Get all the expressions that appear at least twice and set up the state for subexpression - // elimination. - val commonExprs = equivalentExpressions.getCommonSubexpressions - commonExprs.foreach { expr => - val fnName = freshName("subExpr") - val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - val value = addMutableState(javaType(expr.dataType), "subExprValue") - - // Generate the code for this expression tree and wrap it in a function. - val eval = expr.genCode(this) - val fn = - s""" - |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code} - | $isNull = ${eval.isNull}; - | $value = ${eval.value}; - |} - """.stripMargin - - // Add a state and a mapping of the common subexpressions that are associate with this - // state. Adding this expression to subExprEliminationExprMap means it will call `fn` - // when it is code generated. This decision should be a cost based one. - // - // The cost of doing subexpression elimination is: - // 1. Extra function call, although this is probably *good* as the JIT can decide to - // inline or not. - // The benefit doing subexpression elimination is: - // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 - // above. - // 2. Less code. - // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with - // at least two nodes) as the cost of doing it is expected to be low. - - val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - subexprFunctions += subExprCode - val state = SubExprEliminationState( - ExprCode(code"$subExprCode", - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, expr.dataType))) - subExprEliminationExprs += ExpressionEquals(expr) -> state - } - } - /** * Generates code for expressions. 
If doSubexpressionElimination is true, subexpression * elimination will be performed. Subexpression elimination assumes that the code for each @@ -1316,12 +1240,20 @@ class CodegenContext extends Logging { */ def generateExpressions( expressions: Seq[Expression], - doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { + doSubexpressionElimination: Boolean = false): (Seq[ExprCode], String) = { // We need to make sure that we do not reuse stateful expressions. This is needed for codegen // as well because some expressions may implement `CodegenFallback`. val cleanedExpressions = expressions.map(_.freshCopyIfContainsStatefulExpression()) - if (doSubexpressionElimination) subexpressionElimination(cleanedExpressions) + if (doSubexpressionElimination) { + val subExprs = subexpressionElimination(cleanedExpressions) + val generatedExprs = withSubExprEliminationExprs(subExprs.states) { cleanedExpressions.map(e => e.genCode(this)) + } + val subExprCode = evaluateSubExprEliminationState(subExprs.states.values) + (generatedExprs, subExprCode) + } else { + (cleanedExpressions.map(e => e.genCode(this)), "") + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 2e018de07101..6db00654ad1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -61,7 +61,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP case (NoOp, _) => false case _ => true } - val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) + val (exprVals, evalSubexpr) = + ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, 
isNull variable name, value variable name, column index) val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { @@ -91,9 +92,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP (code, update) } - // Evaluate all the subexpressions. - val evalSubexpr = ctx.subexprFunctionsCode - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index c246d07f189b..2383ffc0839e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,8 +38,8 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { val ctx = newCodeGenContext() // Do sub-expression elimination for predicates. 
- val eval = ctx.generateExpressions(Seq(predicate), useSubexprElimination).head - val evalSubexpr = ctx.subexprFunctionsCode + val (evalExprs, evalSubexpr) = ctx.generateExpressions(Seq(predicate), useSubexprElimination) + val eval = evalExprs.head val codeBody = s""" public SpecificPredicate generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 459c1d9a8ba1..d180215783db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -287,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { - val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) + val (exprEvals, evalSubexpr) = ctx.generateExpressions(expressions, useSubexprElimination) val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) val numVarLenFields = exprSchemas.count { @@ -299,9 +299,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") - // Evaluate all the subexpression. 
- val evalSubexpr = ctx.subexprFunctionsCode - val writeExpressions = writeExpressionsToBuffer( ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7ce14bcedf4b..4bbbc368010a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -491,24 +491,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ctx.subExprEliminationExprs.contains(wrap(ref))) assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) } - - // emulate an actual codegen workload - { - val ctx = new CodegenContext - // before - ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE - assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(wrap(ref) -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) - Seq.empty - } - // after - assert(ctx.subExprEliminationExprs.nonEmpty) - assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) - } } test("SPARK-23986: freshName can generate duplicated names") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index e9faeba2411c..dfbbaf59c075 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -278,7 +278,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel ExprCode(TrueLiteral, oneVar), ExprCode(TrueLiteral, twoVar)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) ctx.withSubExprEliminationExprs(subExprs.states) { exprs.map(_.genCode(ctx)) } @@ -408,7 +408,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val exprs = Seq(add1, add1, add2, add2) val ctx = new CodegenContext() - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) val add2State = subExprs.states(ExpressionEquals(add2)) val add1State = subExprs.states(ExpressionEquals(add1)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala index 40112979c6d4..fe01eed63363 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala @@ -210,7 +210,7 @@ trait AggregateCodegenSupport val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val subExprs = ctx.subexpressionElimination(boundUpdateExprs.flatten) val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 24528b6f4da1..5904e0c9cace 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -629,7 +629,7 @@ case class HashAggregateExec( // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( ctx, bindReferences[Expression](groupingExpressions, child.output)) - val fastRowKeys = ctx.generateExpressions( + val (fastRowKeys, _) = ctx.generateExpressions( bindReferences[Expression](groupingExpressions, child.output)) val unsafeRowKeys = unsafeRowKeyCode.value val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash") @@ -732,7 +732,7 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val subExprs = ctx.subexpressionElimination(boundUpdateExprs.flatten) val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { @@ -778,7 +778,7 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val subExprs = ctx.subexpressionElimination(boundUpdateExprs.flatten) val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 70ade390c733..cc2cfff3c73e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -69,7 +69,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) val exprs = bindReferences[Expression](projectList, child.output) val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) { // subexpression elimination - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { exprs.map(_.genCode(ctx)) } From 60ca4981005693a0e4b86a53e4793deb012ef8f0 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Sat, 7 May 2022 10:41:03 -0400 Subject: [PATCH 2/8] Add codegen support to array functions --- .../expressions/EquivalentExpressions.scala | 4 + .../expressions/codegen/CodeGenerator.scala | 34 ++ .../expressions/higherOrderFunctions.scala | 412 +++++++++++++++++- .../sql/errors/QueryExecutionErrors.scala | 9 + 4 files changed, 451 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 78f73f8778b8..43d29ab27e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -146,9 +146,13 @@ class EquivalentExpressions( // There are some special expressions that we should not recurse into all of its children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) // 2. 
ConditionalExpression: use its children that will always be evaluated. + // 3. HigherOrderFunction: lambda functions operate in the context of local lambdas and can't + // be called outside of that scope, only the arguments can be evaluated ahead of + // time. private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { case _: CodegenFallback => Nil case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut) + case h: HigherOrderFunction => h.arguments case other => skipForShortcut(other).children } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 173fc1857cb7..4d4c448cfcc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -174,6 +174,40 @@ class CodegenContext extends Logging { */ var currentVars: Seq[ExprCode] = null + /** + * Holding a map of current lambda variables. 
+ */ + var currentLambdaVars: mutable.Map[String, ExprCode] = mutable.HashMap.empty + + def withLambdaVars(namedLambdas: Seq[NamedLambdaVariable], + f: Seq[ExprCode] => ExprCode): ExprCode = { + val lambdaVars = namedLambdas.map { namedLambda => + val name = namedLambda.variableName + if (currentLambdaVars.get(name).nonEmpty) { + throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(name) + } + val isNull = if (namedLambda.nullable) { + JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull")) + } else { + FalseLiteral + } + val value = addMutableState(javaType(namedLambda.dataType), "lambdaValue") + val lambdaVar = ExprCode(isNull, JavaCode.global(value, namedLambda.dataType)) + currentLambdaVars.put(name, lambdaVar) + lambdaVar + } + + val result = f(lambdaVars) + namedLambdas.foreach(v => currentLambdaVars.remove(v.variableName)) + result + } + + def getLambdaVar(name: String): ExprCode = { + currentLambdaVars.getOrElse(name, { + throw QueryExecutionErrors.lambdaVariableNotDefinedError(name) + }) + } + /** * Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a * 2-tuple: java type, variable name. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 2a5a38e93706..72999e0725f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -81,8 +82,7 @@ case class NamedLambdaVariable( exprId: ExprId = NamedExpression.newExprId, value: AtomicReference[Any] = new AtomicReference()) extends LeafExpression - with NamedExpression - with CodegenFallback { + with NamedExpression { override def qualifier: Seq[String] = Seq.empty @@ -103,6 +103,14 @@ case class NamedLambdaVariable( override def simpleString(maxFields: Int): String = { s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}" } + + // We need to include the Expr ID in the Codegen variable name since several tests bypass + // `UnresolvedNamedLambdaVariable.freshVarName` + lazy val variableName = s"${name}_${exprId.id}" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.getLambdaVar(variableName) + } } /** @@ -114,7 +122,7 @@ case class LambdaFunction( function: Expression, arguments: Seq[NamedExpression], hidden: Boolean = false) - extends Expression with 
CodegenFallback { + extends Expression { override def children: Seq[Expression] = function +: arguments override def dataType: DataType = function.dataType @@ -132,6 +140,23 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val functionCode = function.genCode(ctx) + + if (nullable) { + ev.copy(code = code""" + |${functionCode.code} + |boolean ${ev.isNull} = ${functionCode.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; + """.stripMargin) + } else { + ev.copy(code = code""" + |${functionCode.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; + """.stripMargin, isNull = FalseLiteral) + } + } + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): LambdaFunction = copy( @@ -239,6 +264,53 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { val canonicalizedChildren = cleaned.children.map(_.canonicalized) withNewChildren(canonicalizedChildren) } + + + protected def assignAtomic(atomicRef: String, value: String, isNull: String = FalseLiteral, + nullable: Boolean = false) = { + if (nullable) { + s""" + if ($isNull) { + $atomicRef.set(null); + } else { + $atomicRef.set($value); + } + """ + } else { + s"$atomicRef.set($value);" + } + } + + protected def assignArrayElement(ctx: CodegenContext, arrayName: String, elementCode: ExprCode, + elementVar: NamedLambdaVariable, index: String): String = { + val elementType = elementVar.dataType + val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value) + val extractElement = CodeGenerator.getValue(arrayName, elementType, index) + val atomicAssign = assignAtomic(elementAtomic, elementCode.value, + elementCode.isNull, elementVar.nullable) + + if (elementVar.nullable) { + s""" + ${elementCode.value} = $extractElement; + ${elementCode.isNull} = $arrayName.isNullAt($index); + 
$atomicAssign + """ + } else { + s""" + ${elementCode.value} = $extractElement; + $atomicAssign + """ + } + } + + protected def assignIndex(ctx: CodegenContext, indexCode: ExprCode, + indexVar: NamedLambdaVariable, index: String): String = { + val indexAtomic = ctx.addReferenceObj(indexVar.variableName, indexVar.value) + s""" + ${indexCode.value} = $index; + ${assignAtomic(indexAtomic, indexCode.value)} + """ + } } /** @@ -284,6 +356,29 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr } } + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: String => String): ExprCode = { + val argumentGen = argument.genCode(ctx) + val resultCode = f(argumentGen.value) + + if (nullable) { + val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode) + ev.copy(code = code""" + |${argumentGen.code} + |boolean ${ev.isNull} = ${argumentGen.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """) + } else { + ev.copy(code = code""" + |${argumentGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """, isNull = FalseLiteral) + } + } } trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { @@ -312,7 +407,7 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { case class ArrayTransform( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction { override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) @@ -354,6 +449,49 @@ case class ArrayTransform( result } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, { lambdaExprs => + val elementCode = lambdaExprs.head + val indexCode = lambdaExprs.tail.headOption + + 
nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") + val i = ctx.freshName("i") + + val initialization = CodeGenerator.createArrayData( + arrayData, dataType.elementType, numElements, s" $prettyName failed.") + + val functionCode = function.genCode(ctx) + + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) + val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") + + // Some expressions return internal buffers that we have to copy + val copy = if (CodeGenerator.isPrimitiveType(function.dataType)) { + s"${functionCode.value}" + } else { + s"InternalRow.copyValue(${functionCode.value})" + } + val resultNull = if (function.nullable) Some(functionCode.isNull.toString) else None + val resultAssignment = CodeGenerator.setArrayElement(arrayData, dataType.elementType, + i, copy, isNull = resultNull) + + s""" + |final int $numElements = ${arg}.numElements(); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | $varAssignments + | ${functionCode.code} + | $resultAssignment + |} + |${ev.value} = $arrayData; + """.stripMargin + }) + }) + } + override def nodeName: String = "transform" override protected def withNewChildrenInternal( @@ -581,7 +719,7 @@ case class MapFilter( case class ArrayFilter( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction { override def dataType: DataType = argument.dataType @@ -622,6 +760,67 @@ case class ArrayFilter( new GenericArrayData(buffer) } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, { lambdaExprs => + val elementCode = lambdaExprs.head + val indexCode = lambdaExprs.tail.headOption + + nullSafeCodeGen(ctx, ev, arg => { + val 
numElements = ctx.freshName("numElements") + val count = ctx.freshName("count") + val arrayTracker = ctx.freshName("arrayTracker") + val arrayData = ctx.freshName("arrayData") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + + val arrayType = dataType.asInstanceOf[ArrayType] + + val trackerInit = CodeGenerator.createArrayData( + arrayTracker, BooleanType, numElements, s" $prettyName failed.") + val resultInit = CodeGenerator.createArrayData( + arrayData, arrayType.elementType, count, s" $prettyName failed.") + + val functionCode = function.genCode(ctx) + + val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) + val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") + + val resultAssignment = CodeGenerator.setArrayElement(arrayTracker, BooleanType, + i, functionCode.value, isNull = None) + + val getTrackerValue = CodeGenerator.getValue(arrayTracker, BooleanType, i) + val copy = CodeGenerator.createArrayAssignment(arrayData, arrayType.elementType, arg, + j, i, arrayType.containsNull) + + s""" + |final int $numElements = ${arg}.numElements(); + |$trackerInit + |int $count = 0; + |for (int $i = 0; $i < $numElements; $i++) { + | $varAssignments + | ${functionCode.code} + | $resultAssignment + | if ((boolean)${functionCode.value}) { + | $count++; + | } + |} + | + |$resultInit + |int $j = 0; + |for (int $i = 0; $i < $numElements; $i++) { + | if ($getTrackerValue) { + | $copy + | $j++; + | } + |} + |${ev.value} = $arrayData; + """.stripMargin + }) + }) + } + override def nodeName: String = "filter" override protected def withNewChildrenInternal( @@ -653,7 +852,7 @@ case class ArrayExists( argument: Expression, function: Expression, followThreeValuedLogic: Boolean) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback with Predicate { + 
extends ArrayBasedSimpleHigherOrderFunction with Predicate { def this(argument: Expression, function: Expression) = { this( @@ -706,6 +905,50 @@ case class ArrayExists( } } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar), { case Seq(elementCode) => + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val exists = ctx.freshName("exists") + val foundNull = ctx.freshName("foundNull") + val i = ctx.freshName("i") + + val functionCode = function.genCode(ctx) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val threeWayLogic = if (followThreeValuedLogic) TrueLiteral else FalseLiteral + + val nullCheck = if (nullable) { + s""" + if ($threeWayLogic && !$exists && $foundNull) { + ${ev.isNull} = true; + } + """ + } else { + "" + } + + s""" + |final int $numElements = ${arg}.numElements(); + |boolean $exists = false; + |boolean $foundNull = false; + |int $i = 0; + |while ($i < $numElements && !$exists) { + | $elementAssignment + | ${functionCode.code} + | if (${functionCode.isNull}) { + | $foundNull = true; + | } else if (${functionCode.value}) { + | $exists = true; + | } + | $i++; + |} + |$nullCheck + |${ev.value} = $exists; + """.stripMargin + }) + }) + } + override def nodeName: String = "exists" override protected def withNewChildrenInternal( @@ -740,7 +983,7 @@ object ArrayExists { case class ArrayForAll( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback with Predicate { + extends ArrayBasedSimpleHigherOrderFunction with Predicate { override def nullable: Boolean = super.nullable || function.nullable @@ -785,6 +1028,49 @@ case class ArrayForAll( } } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar), { case Seq(elementCode) => + nullSafeCodeGen(ctx, ev, arg => { + val numElements = 
ctx.freshName("numElements") + val forall = ctx.freshName("forall") + val foundNull = ctx.freshName("foundNull") + val i = ctx.freshName("i") + + val functionCode = function.genCode(ctx) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + + val nullCheck = if (nullable) { + s""" + if ($forall && $foundNull) { + ${ev.isNull} = true; + } + """ + } else { + "" + } + + s""" + |final int $numElements = ${arg}.numElements(); + |boolean $forall = true; + |boolean $foundNull = false; + |int $i = 0; + |while ($i < $numElements && $forall) { + | $elementAssignment + | ${functionCode.code} + | if (${functionCode.isNull}) { + | $foundNull = true; + | } else if (!${functionCode.value}) { + | $forall = false; + | } + | $i++; + |} + |$nullCheck + |${ev.value} = $forall; + """.stripMargin + }) + }) + } + override def nodeName: String = "forall" override protected def withNewChildrenInternal( @@ -816,7 +1102,7 @@ case class ArrayAggregate( zero: Expression, merge: Expression, finish: Expression) - extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] { + extends HigherOrderFunction with QuaternaryLike[Expression] { def this(argument: Expression, zero: Expression, merge: Expression) = { this(argument, zero, merge, LambdaFunction.identity) @@ -886,6 +1172,116 @@ case class ArrayAggregate( } } + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: String => String): ExprCode = { + val argumentGen = argument.genCode(ctx) + val resultCode = f(argumentGen.value) + + if (nullable) { + val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode) + ev.copy(code = code""" + |${argumentGen.code} + |boolean ${ev.isNull} = ${argumentGen.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """) + } else { + ev.copy(code = code""" + |${argumentGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """, isNull = FalseLiteral) + } + } + + protected def assignVar(varCode: ExprCode, value: String, isNull: String, + nullable: Boolean): String = { + if (nullable) { + s""" + ${varCode.value} = $value; + ${varCode.isNull} = $isNull; + """ + } else { + s""" + ${varCode.value} = $value; + """ + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar), { varCodes => + val Seq(elementCode, accForMergeCode, accForFinishCode) = varCodes + + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val i = ctx.freshName("i") + + val zeroCode = zero.genCode(ctx) + val mergeCode = merge.genCode(ctx) + val finishCode = finish.genCode(ctx) + + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val mergeAtomic = ctx.addReferenceObj(accForMergeVar.variableName, + accForMergeVar.value) + val finishAtomic = ctx.addReferenceObj(accForFinishVar.variableName, + accForFinishVar.value) + + val mergeJavaType = CodeGenerator.javaType(accForMergeVar.dataType) + val finishJavaType = CodeGenerator.javaType(accForFinishVar.dataType) + + // Some expressions return internal buffers that we have to copy + val mergeCopy = if (CodeGenerator.isPrimitiveType(merge.dataType)) { + s"${mergeCode.value}" + } else { + s"($mergeJavaType)InternalRow.copyValue(${mergeCode.value})" + } + + val nullCheck = if (nullable) { + s"${ev.isNull} = ${finishCode.isNull};" + } else { + "" + } + + val initialAssignment = assignVar(accForMergeCode, zeroCode.value, zeroCode.isNull, + zero.nullable) + val initialAtomic = assignAtomic(mergeAtomic, accForMergeCode.value, + accForMergeCode.isNull, merge.nullable) + + val mergeAssignment = assignVar(accForMergeCode, mergeCopy, + mergeCode.isNull, merge.nullable) + val mergeAtomicAssignment = assignAtomic(mergeAtomic, accForMergeCode.value, + accForMergeCode.isNull, 
merge.nullable) + + val finishAssignment = assignVar(accForFinishCode, accForMergeCode.value, + accForMergeCode.isNull, merge.nullable) + val finishAtomicAssignment = assignAtomic(finishAtomic, accForFinishCode.value, + accForFinishCode.isNull, merge.nullable) + + s""" + |final int $numElements = ${arg}.numElements(); + |${zeroCode.code} + |$initialAssignment + |$initialAtomic + | + |for (int $i = 0; $i < $numElements; $i++) { + | $elementAssignment + | ${mergeCode.code} + | $mergeAssignment + | $mergeAtomicAssignment + |} + | + |$finishAssignment + |$finishAtomicAssignment + |${finishCode.code} + |${ev.value} = ${finishCode.value}; + |$nullCheck + """.stripMargin + }) + }) + } + override def nodeName: String = "aggregate" override def first: Expression = argument diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 68d4fe690007..93e0e4b1f027 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -439,6 +439,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE s"failed to match ${toSQLId(funcName)} at `addNewFunction`.") } + def lambdaVariableAlreadyDefinedError(name: String): Throwable = { + new IllegalArgumentException(s"Lambda variable $name cannot be redefined") + } + + def lambdaVariableNotDefinedError(name: String): Throwable = { + new IllegalArgumentException( + s"Lambda variable $name is not defined in the current codegen scope") + } + def cannotGenerateCodeForIncomparableTypeError( codeType: String, dataType: DataType): Throwable = { SparkException.internalError( From 39ba077b61fe1c8a03da978e17db19cec5ffac78 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 4 Apr 2023 16:43:33 -0400 Subject: [PATCH 3/8] Remove unnecessary variableName and clean up some 
formatting --- .../expressions/codegen/CodeGenerator.scala | 19 ++++++------ .../expressions/higherOrderFunctions.scala | 30 ++++++++----------- .../sql/errors/QueryExecutionErrors.scala | 8 ++--- 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4d4c448cfcc5..d26b7f13720f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -177,14 +177,14 @@ class CodegenContext extends Logging { /** * Holding a map of current lambda variables. */ - var currentLambdaVars: mutable.Map[String, ExprCode] = mutable.HashMap.empty + var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty def withLambdaVars(namedLambdas: Seq[NamedLambdaVariable], f: Seq[ExprCode] => ExprCode): ExprCode = { val lambdaVars = namedLambdas.map { namedLambda => - val name = namedLambda.variableName - if (currentLambdaVars.get(name).nonEmpty) { - throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(name) + val id = namedLambda.exprId.id + if (currentLambdaVars.get(id).nonEmpty) { + throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(id) } val isNull = if (namedLambda.nullable) { JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull")) @@ -193,19 +193,18 @@ class CodegenContext extends Logging { } val value = addMutableState(javaType(namedLambda.dataType), "lambdaValue") val lambdaVar = ExprCode(isNull, JavaCode.global(value, namedLambda.dataType)) - currentLambdaVars.put(name, lambdaVar) + currentLambdaVars.put(id, lambdaVar) lambdaVar } val result = f(lambdaVars) - namedLambdas.foreach(v => currentLambdaVars.remove(v.variableName)) + 
namedLambdas.map(_.exprId.id).foreach(currentLambdaVars.remove) result } - def getLambdaVar(name: String): ExprCode = { - currentLambdaVars.getOrElse(name, { - throw QueryExecutionErrors.lambdaVariableNotDefinedError(name) - }) + def getLambdaVar(id: Long): ExprCode = { + currentLambdaVars.getOrElse(id, + throw QueryExecutionErrors.lambdaVariableNotDefinedError(id)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 72999e0725f5..b90a24d7b793 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -104,12 +104,8 @@ case class NamedLambdaVariable( s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}" } - // We need to include the Expr ID in the Codegen variable name since several tests bypass - // `UnresolvedNamedLambdaVariable.freshVarName` - lazy val variableName = s"${name}_${exprId.id}" - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.getLambdaVar(variableName) + ctx.getLambdaVar(exprId.id) } } @@ -284,7 +280,7 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { protected def assignArrayElement(ctx: CodegenContext, arrayName: String, elementCode: ExprCode, elementVar: NamedLambdaVariable, index: String): String = { val elementType = elementVar.dataType - val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value) + val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) val extractElement = CodeGenerator.getValue(arrayName, elementType, index) val atomicAssign = assignAtomic(elementAtomic, elementCode.value, elementCode.isNull, elementVar.nullable) @@ -305,7 +301,7 @@ trait HigherOrderFunction extends Expression with 
ExpectsInputTypes { protected def assignIndex(ctx: CodegenContext, indexCode: ExprCode, indexVar: NamedLambdaVariable, index: String): String = { - val indexAtomic = ctx.addReferenceObj(indexVar.variableName, indexVar.value) + val indexAtomic = ctx.addReferenceObj(indexVar.name, indexVar.value) s""" ${indexCode.value} = $index; ${assignAtomic(indexAtomic, indexCode.value)} @@ -450,9 +446,9 @@ case class ArrayTransform( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.withLambdaVars(Seq(elementVar) ++ indexVar, { lambdaExprs => - val elementCode = lambdaExprs.head - val indexCode = lambdaExprs.tail.headOption + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, varCodes => { + val elementCode = varCodes.head + val indexCode = varCodes.tail.headOption nullSafeCodeGen(ctx, ev, arg => { val numElements = ctx.freshName("numElements") @@ -761,9 +757,9 @@ case class ArrayFilter( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.withLambdaVars(Seq(elementVar) ++ indexVar, { lambdaExprs => - val elementCode = lambdaExprs.head - val indexCode = lambdaExprs.tail.headOption + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, varCodes => { + val elementCode = varCodes.head + val indexCode = varCodes.tail.headOption nullSafeCodeGen(ctx, ev, arg => { val numElements = ctx.freshName("numElements") @@ -782,7 +778,7 @@ case class ArrayFilter( val functionCode = function.genCode(ctx) - val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value) + val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") @@ -1211,7 +1207,7 @@ case class ArrayAggregate( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - 
ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar), { varCodes => + ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar), varCodes => { val Seq(elementCode, accForMergeCode, accForFinishCode) = varCodes nullSafeCodeGen(ctx, ev, arg => { @@ -1223,9 +1219,9 @@ case class ArrayAggregate( val finishCode = finish.genCode(ctx) val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) - val mergeAtomic = ctx.addReferenceObj(accForMergeVar.variableName, + val mergeAtomic = ctx.addReferenceObj(accForMergeVar.name, accForMergeVar.value) - val finishAtomic = ctx.addReferenceObj(accForFinishVar.variableName, + val finishAtomic = ctx.addReferenceObj(accForFinishVar.name, accForFinishVar.value) val mergeJavaType = CodeGenerator.javaType(accForMergeVar.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 93e0e4b1f027..87baa3b59033 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -439,13 +439,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE s"failed to match ${toSQLId(funcName)} at `addNewFunction`.") } - def lambdaVariableAlreadyDefinedError(name: String): Throwable = { - new IllegalArgumentException(s"Lambda variable $name cannot be redefined") + def lambdaVariableAlreadyDefinedError(id: Long): Throwable = { + new IllegalArgumentException(s"Lambda variable $id cannot be redefined") } - def lambdaVariableNotDefinedError(name: String): Throwable = { + def lambdaVariableNotDefinedError(id: Long): Throwable = { new IllegalArgumentException( - s"Lambda variable $name is not defined in the current codegen scope") + s"Lambda variable $id is not defined in the current codegen scope") } def 
cannotGenerateCodeForIncomparableTypeError( From 196d6532cf3293d1ac33ed8136c63f1442c68f13 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Sun, 21 May 2023 11:38:42 -0400 Subject: [PATCH 4/8] Remove unnecessary extra variable copies --- .../expressions/higherOrderFunctions.scala | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index b90a24d7b793..ddb97fb915c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -137,20 +137,7 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val functionCode = function.genCode(ctx) - - if (nullable) { - ev.copy(code = code""" - |${functionCode.code} - |boolean ${ev.isNull} = ${functionCode.isNull}; - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; - """.stripMargin) - } else { - ev.copy(code = code""" - |${functionCode.code} - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; - """.stripMargin, isNull = FalseLiteral) - } + function.genCode(ctx) } override protected def withNewChildrenInternal( From 05d98937a98f75443f45a460845b656b7d715e0a Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Thu, 22 Jun 2023 20:31:11 -0400 Subject: [PATCH 5/8] Improve some styling --- .../expressions/codegen/CodeGenerator.scala | 16 ++++++------ .../expressions/higherOrderFunctions.scala | 25 ++++++++++++++----- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d26b7f13720f..d1d2e2f85804 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -179,20 +179,21 @@ class CodegenContext extends Logging { */ var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty - def withLambdaVars(namedLambdas: Seq[NamedLambdaVariable], + def withLambdaVars( + namedLambdas: Seq[NamedLambdaVariable], f: Seq[ExprCode] => ExprCode): ExprCode = { - val lambdaVars = namedLambdas.map { namedLambda => - val id = namedLambda.exprId.id + val lambdaVars = namedLambdas.map { lambda => + val id = lambda.exprId.id if (currentLambdaVars.get(id).nonEmpty) { throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(id) } - val isNull = if (namedLambda.nullable) { + val isNull = if (lambda.nullable) { JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull")) } else { FalseLiteral } - val value = addMutableState(javaType(namedLambda.dataType), "lambdaValue") - val lambdaVar = ExprCode(isNull, JavaCode.global(value, namedLambda.dataType)) + val value = addMutableState(javaType(lambda.dataType), "lambdaValue") + val lambdaVar = ExprCode(isNull, JavaCode.global(value, lambda.dataType)) currentLambdaVars.put(id, lambdaVar) lambdaVar } @@ -203,7 +204,8 @@ class CodegenContext extends Logging { } def getLambdaVar(id: Long): ExprCode = { - currentLambdaVars.getOrElse(id, + currentLambdaVars.getOrElse( + id, throw QueryExecutionErrors.lambdaVariableNotDefinedError(id)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index ddb97fb915c4..4a4717ca4f3b 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -249,7 +249,10 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { } - protected def assignAtomic(atomicRef: String, value: String, isNull: String = FalseLiteral, + protected def assignAtomic( + atomicRef: String, + value: String, + isNull: String = FalseLiteral, nullable: Boolean = false) = { if (nullable) { s""" @@ -264,8 +267,12 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { } } - protected def assignArrayElement(ctx: CodegenContext, arrayName: String, elementCode: ExprCode, - elementVar: NamedLambdaVariable, index: String): String = { + protected def assignArrayElement( + ctx: CodegenContext, + arrayName: String, + elementCode: ExprCode, + elementVar: NamedLambdaVariable, + index: String): String = { val elementType = elementVar.dataType val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) val extractElement = CodeGenerator.getValue(arrayName, elementType, index) @@ -286,8 +293,11 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { } } - protected def assignIndex(ctx: CodegenContext, indexCode: ExprCode, - indexVar: NamedLambdaVariable, index: String): String = { + protected def assignIndex( + ctx: CodegenContext, + indexCode: ExprCode, + indexVar: NamedLambdaVariable, + index: String): String = { val indexAtomic = ctx.addReferenceObj(indexVar.name, indexVar.value) s""" ${indexCode.value} = $index; @@ -1179,7 +1189,10 @@ case class ArrayAggregate( } } - protected def assignVar(varCode: ExprCode, value: String, isNull: String, + protected def assignVar( + varCode: ExprCode, + value: String, + isNull: String, nullable: Boolean): String = { if (nullable) { s""" From 832d9b819b08637a71ed874b535534ed8a4a9c8a Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 1 Oct 2024 
07:13:29 -0400 Subject: [PATCH 6/8] Add tests for codegen fallback inside HOF --- .../HigherOrderFunctionsSuite.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index cc36cd73d6d7..bc608b7afecf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -149,6 +151,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val plusOne: Expression => Expression = x => x + 1 val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i + val plusOneFallback: Expression => Expression = x => CodegenFallbackExpr(x + 1) checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4)) checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5)) @@ -158,6 +161,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transform(transform(ai1, plusIndex), plusOne), Seq(2, null, 6)) checkEvaluation(transform(ain, plusOne), null) + checkEvaluation(transform(ai0, plusOneFallback), Seq(2, 3, 4)) + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = 
false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) @@ -277,6 +282,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isEven: Expression => Expression = x => x % 2 === 0 val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 } + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(filter(ai0, isEven), Seq(2)) checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3)) @@ -286,6 +292,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(filter(ain, isEven), null) checkEvaluation(filter(ain, isNullOrOdd), null) + checkEvaluation(filter(ai0, isEvenFallback), Seq(2)) + val as0 = Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) @@ -321,6 +329,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) for (followThreeValuedLogic <- Seq(false, true)) { withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key @@ -337,6 +346,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(exists(ain, isNullOrOdd), null) checkEvaluation(exists(ain, alwaysFalse), null) checkEvaluation(exists(ain, alwaysNull), null) + checkEvaluation(exists(ai0, isEvenFallback), true) } } @@ -383,6 +393,7 @@ class HigherOrderFunctionsSuite 
extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(forall(ai0, isEven), true) checkEvaluation(forall(ai0, isNullOrOdd), false) @@ -401,6 +412,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(forall(ain, alwaysFalse), null) checkEvaluation(forall(ain, alwaysNull), null) + checkEvaluation(forall(ai0, isEvenFallback), true) + val as0 = Literal.create(Seq("a0", "a1", "a2", "a3"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true)) @@ -886,3 +899,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ))) } } + +case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback { + override def nullable: Boolean = child.nullable + override def dataType: DataType = child.dataType + override lazy val resolved = child.resolved + override def eval(input: InternalRow): Any = child.eval(input) + override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpr = + copy(child = newChild) +} From 9b94470998f2fc23f8c12bf2d962d37559a36751 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Fri, 14 Mar 2025 18:28:26 +0000 Subject: [PATCH 7/8] Small cleanup --- .../expressions/higherOrderFunctions.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 4a4717ca4f3b..9222b585914e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -472,7 +472,7 @@ case class ArrayTransform( i, copy, isNull = resultNull) s""" - |final int $numElements = ${arg}.numElements(); + |final int $numElements = $arg.numElements(); |$initialization |for (int $i = 0; $i < $numElements; $i++) { | $varAssignments @@ -775,7 +775,6 @@ case class ArrayFilter( val functionCode = function.genCode(ctx) - val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") @@ -787,8 +786,14 @@ case class ArrayFilter( val copy = CodeGenerator.createArrayAssignment(arrayData, arrayType.elementType, arg, j, i, arrayType.containsNull) + // This takes two passes to avoid evaluating the predicate multiple times. + // The first pass evaluates each element in the array, tracks how many elements + // returned true, and tracks the result of each element in a boolean array `arrayTracker`. + // The second pass copies elements from the original array to the new array created + // based on the number of elements matching the first pass. 
+ s""" - |final int $numElements = ${arg}.numElements(); + |final int $numElements = $arg.numElements(); |$trackerInit |int $count = 0; |for (int $i = 0; $i < $numElements; $i++) { @@ -1191,17 +1196,21 @@ case class ArrayAggregate( protected def assignVar( varCode: ExprCode, + atomicVar: String, value: String, isNull: String, nullable: Boolean): String = { + val atomicAssign = assignAtomic(atomicVar, value, isNull, nullable) if (nullable) { s""" ${varCode.value} = $value; ${varCode.isNull} = $isNull; + $atomicAssign """ } else { s""" ${varCode.value} = $value; + $atomicAssign """ } } @@ -1240,36 +1249,27 @@ case class ArrayAggregate( "" } - val initialAssignment = assignVar(accForMergeCode, zeroCode.value, zeroCode.isNull, - zero.nullable) - val initialAtomic = assignAtomic(mergeAtomic, accForMergeCode.value, - accForMergeCode.isNull, merge.nullable) + val initialAssignment = assignVar(accForMergeCode, mergeAtomic, zeroCode.value, + zeroCode.isNull, zero.nullable) - val mergeAssignment = assignVar(accForMergeCode, mergeCopy, + val mergeAssignment = assignVar(accForMergeCode, mergeAtomic, mergeCopy, mergeCode.isNull, merge.nullable) - val mergeAtomicAssignment = assignAtomic(mergeAtomic, accForMergeCode.value, - accForMergeCode.isNull, merge.nullable) - val finishAssignment = assignVar(accForFinishCode, accForMergeCode.value, + val finishAssignment = assignVar(accForFinishCode, finishAtomic, accForMergeCode.value, accForMergeCode.isNull, merge.nullable) - val finishAtomicAssignment = assignAtomic(finishAtomic, accForFinishCode.value, - accForFinishCode.isNull, merge.nullable) s""" |final int $numElements = ${arg}.numElements(); |${zeroCode.code} |$initialAssignment - |$initialAtomic | |for (int $i = 0; $i < $numElements; $i++) { | $elementAssignment | ${mergeCode.code} | $mergeAssignment - | $mergeAtomicAssignment |} | |$finishAssignment - |$finishAtomicAssignment |${finishCode.code} |${ev.value} = ${finishCode.value}; |$nullCheck From 
8aeb23f8bc2f30af7b8eee8873b51f1bc1c627a0 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 24 Jun 2025 12:12:44 -0400 Subject: [PATCH 8/8] Add subexpression elimination to higher order functions --- .../expressions/codegen/CodeGenerator.scala | 22 +++++--- .../expressions/higherOrderFunctions.scala | 13 ++++- .../expressions/CodeGenerationSuite.scala | 30 ++++++----- .../HigherOrderFunctionsSuite.scala | 18 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 50 +++++++++++++++++++ 5 files changed, 109 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d1d2e2f85804..b87107e9a79f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1083,13 +1083,13 @@ class CodegenContext extends Logging { /** * Perform a function which generates a sequence of ExprCodes with a given mapping between - * expressions and common expressions, instead of using the mapping in current context. + * expressions and common expressions. Restores previous mapping after execution. */ def withSubExprEliminationExprs( newSubExprEliminationExprs: Map[ExpressionEquals, SubExprEliminationState])( f: => Seq[ExprCode]): Seq[ExprCode] = { val oldsubExprEliminationExprs = subExprEliminationExprs - subExprEliminationExprs = newSubExprEliminationExprs + subExprEliminationExprs = oldsubExprEliminationExprs ++ newSubExprEliminationExprs val genCodes = f @@ -1150,7 +1150,9 @@ class CodegenContext extends Logging { * (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression * evaluation, we can look for generated subexpressions and do replacement. 
*/ - def subexpressionElimination(expressions: Seq[Expression]): SubExprCodes = { + def subexpressionElimination( + expressions: Seq[Expression], + variablePrefix: String = ""): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions val localSubExprEliminationExprs = @@ -1161,7 +1163,13 @@ class CodegenContext extends Logging { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. + // + // Filter out any expressions that are already existing subexpressions. This can happen + // when finding common subexpressions inside a lambda function, and the common expression + // does not reference the lambda variables for that function, but top level attributes or + // outer lambda variables. val commonExprs = equivalentExpressions.getCommonSubexpressions + .filter(e => !subExprEliminationExprs.contains(ExpressionEquals(e))) val nonSplitCode = { val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState] @@ -1169,14 +1177,14 @@ class CodegenContext extends Logging { withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { val eval = expr.genCode(this) - val value = addMutableState(javaType(expr.dataType), "subExprValue") + val value = addMutableState(javaType(expr.dataType), s"${variablePrefix}subExprValue") val isNullLiteral = eval.isNull match { case TrueLiteral | FalseLiteral => true case _ => false } val (isNull, isNullEvalCode) = if (!isNullLiteral) { - val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + val v = addMutableState(JAVA_BOOLEAN, s"${variablePrefix}subExprIsNull") (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") } else { (eval.isNull, "") @@ -1191,7 +1199,7 @@ class CodegenContext extends Logging { // Collects other subexpressions from the children. 
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] expr.foreach { e => - subExprEliminationExprs.get(ExpressionEquals(e)) match { + localSubExprEliminationExprs.get(ExpressionEquals(e)) match { case Some(state) => childrenSubExprs += state case _ => } @@ -1282,7 +1290,7 @@ class CodegenContext extends Logging { if (doSubexpressionElimination) { val subExprs = subexpressionElimination(cleanedExpressions) val generatedExprs = withSubExprEliminationExprs(subExprs.states) { - cleanedExpressions.map(e => e.genCode(this)) + cleanedExpressions.map(e => e.genCode(this)) } val subExprCode = evaluateSubExprEliminationState(subExprs.states.values) (generatedExprs, subExprCode) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 9222b585914e..beb2a3ac490a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -137,7 +137,18 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - function.genCode(ctx) + val subExprCodes = ctx.subexpressionElimination(Seq(function), "lambda_") + + val functionCode = ctx.withSubExprEliminationExprs(subExprCodes.states) { + Seq(function.genCode(ctx)) + }.head + + val subExprEval = ctx.evaluateSubExprEliminationState(subExprCodes.states.values) + functionCode.copy(code = code""" + |// lambda common sub-expressions + |$subExprEval + |${functionCode.code} + """) } override protected def withNewChildrenInternal( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 4bbbc368010a..2b61b85ad815 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -473,24 +473,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { JavaCode.variable("dummy", BooleanType))) // raw testing of basic functionality - { - val ctx = new CodegenContext - val e = ref.genCode(ctx) - // before - ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState( - ExprCode(EmptyBlock, e.isNull, e.value)) - assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(wrap(add1))) - assert(!ctx.subExprEliminationExprs.contains(wrap(ref))) - Seq.empty - } - // after - assert(ctx.subExprEliminationExprs.nonEmpty) + val ctx = new CodegenContext + val e = ref.genCode(ctx) + // before + ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState( + ExprCode(EmptyBlock, e.isNull, e.value)) + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + // call withSubExprEliminationExprs, should now contain both + ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) { + assert(ctx.subExprEliminationExprs.contains(wrap(add1))) assert(ctx.subExprEliminationExprs.contains(wrap(ref))) - assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) + Seq.empty } + // after, should only contain the original + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(wrap(ref))) + assert(!ctx.subExprEliminationExprs.contains(wrap(add1))) } test("SPARK-23986: freshName can generate duplicated names") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index bc608b7afecf..3f3782733eda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -151,10 +151,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val plusOne: Expression => Expression = x => x + 1 val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i + val plusIndexRepeated: (Expression, Expression) => Expression = + (x, i) => plusIndex(x, i) * plusIndex(x, i) val plusOneFallback: Expression => Expression = x => CodegenFallbackExpr(x + 1) checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4)) checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5)) + checkEvaluation(transform(ai0, plusIndexRepeated), Seq(1, 9, 25)) checkEvaluation(transform(transform(ai0, plusIndex), plusOne), Seq(2, 4, 6)) checkEvaluation(transform(ai1, plusOne), Seq(2, null, 4)) checkEvaluation(transform(ai1, plusIndex), Seq(1, null, 5)) @@ -282,11 +285,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isEven: Expression => Expression = x => x % 2 === 0 val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 } + val plusIndexRepeatedEven: (Expression, Expression) => Expression = + (x, i) => ((x + i) * (x + i)) % 2 === 0 val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(filter(ai0, isEven), Seq(2)) checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3)) checkEvaluation(filter(ai0, indexIsEven), Seq(1, 3)) + checkEvaluation(filter(ai0, plusIndexRepeatedEven), Seq.empty) 
checkEvaluation(filter(ai1, isEven), Seq.empty) checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3)) checkEvaluation(filter(ain, isEven), null) @@ -329,6 +335,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val squareRepeatedEven: Expression => Expression = + x => ((x * x) + (x * x)) % 2 === 0 val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) for (followThreeValuedLogic <- Seq(false, true)) { @@ -338,6 +346,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(exists(ai0, isNullOrOdd), true) checkEvaluation(exists(ai0, alwaysFalse), false) checkEvaluation(exists(ai0, alwaysNull), if (followThreeValuedLogic) null else false) + checkEvaluation(exists(ai0, squareRepeatedEven), true) checkEvaluation(exists(ai1, isEven), if (followThreeValuedLogic) null else false) checkEvaluation(exists(ai1, isNullOrOdd), true) checkEvaluation(exists(ai1, alwaysFalse), false) @@ -393,12 +402,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val squareRepeatedEven: Expression => Expression = + x => ((x * x) + (x * x)) % 2 === 0 val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(forall(ai0, isEven), true) checkEvaluation(forall(ai0, isNullOrOdd), false) checkEvaluation(forall(ai0, alwaysFalse), false) checkEvaluation(forall(ai0, alwaysNull), null) + checkEvaluation(forall(ai0, squareRepeatedEven), true) checkEvaluation(forall(ai1, isEven), false) 
checkEvaluation(forall(ai1, isNullOrOdd), true) checkEvaluation(forall(ai1, alwaysFalse), false) @@ -441,6 +453,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc => acc * 10), 40) checkEvaluation(aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10), 0) checkEvaluation(aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10), null) + checkEvaluation(aggregate( + ai0, + 1, + (acc, elem) => (acc * elem) + (acc * elem), + acc => (acc * acc) + (acc * acc) + ), 4608) val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fc6d3023ed07..13b524646ea7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3790,6 +3790,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { testArrayOfPrimitiveTypeContainsNull() } + test("transform function - subexpression elimination") { + val df = Seq[Seq[Integer]]( + Seq(1, 2, 3, 4, 5) + ).toDF("i") + + var count = spark.sparkContext.longAccumulator + val func = udf((x: Integer) => { + count.add(1) + x + }) + + val result = df.select( + transform(col("i"), x => func(x) + func(x)) + ) + + // Run it once to verify the count of UDF calls + result.collect() + assert(count.value == 5) + + checkAnswer(result, Seq(Row(Seq(2, 4, 6, 8, 10)))) + } + + test("transform function - subexpression elimination inside and outside lambda") { + val df = spark.read.json(Seq( + """ + { + "outer": { + "inner": { + "a": 1, + "b": 2, + "c": 3 + } + }, + "arr": [ + 1, + 2, + 3 + ] + } + """).toDS()) + + val 
result = df.select( + col("outer.inner.b"), + col("outer.inner.c"), + transform(col("arr"), x => x + col("outer.inner.a") + col("outer.inner.a")) + ) + + checkAnswer(result, Seq(Row(2, 3, Seq(3, 4, 5)))) + } + test("transform function - array for non-primitive type") { val df = Seq( Seq("c", "a", "b"),