From a9dbcd509159486924e26855dd02c95e5dbb9e4b Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Sat, 7 May 2022 10:41:03 -0400 Subject: [PATCH 1/8] Add codegen support to array functions --- .../expressions/EquivalentExpressions.scala | 4 + .../expressions/codegen/CodeGenerator.scala | 34 ++ .../expressions/higherOrderFunctions.scala | 412 +++++++++++++++++- .../sql/errors/QueryExecutionErrors.scala | 9 + 4 files changed, 451 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 78f73f8778b8..43d29ab27e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -146,9 +146,13 @@ class EquivalentExpressions( // There are some special expressions that we should not recurse into all of its children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) // 2. ConditionalExpression: use its children that will always be evaluated. + // 3. HigherOrderFunction: lambda functions operate in the context of local lambdas and can't + // be called outside of that scope, only the arguments can be evaluated ahead of + // time. 
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { case _: CodegenFallback => Nil case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut) + case h: HigherOrderFunction => h.arguments case other => skipForShortcut(other).children } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 13b1d329f7ec..beeee7cbae7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -174,6 +174,40 @@ class CodegenContext extends Logging { */ var currentVars: Seq[ExprCode] = null + /** + * Holding a map of current lambda variables. + */ + var currentLambdaVars: mutable.Map[String, ExprCode] = mutable.HashMap.empty + + def withLambdaVars(namedLambdas: Seq[NamedLambdaVariable], + f: Seq[ExprCode] => ExprCode): ExprCode = { + val lambdaVars = namedLambdas.map { namedLambda => + val name = namedLambda.variableName + if (currentLambdaVars.get(name).nonEmpty) { + throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(name) + } + val isNull = if (namedLambda.nullable) { + JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull")) + } else { + FalseLiteral + } + val value = addMutableState(javaType(namedLambda.dataType), "lambdaValue") + val lambdaVar = ExprCode(isNull, JavaCode.global(value, namedLambda.dataType)) + currentLambdaVars.put(name, lambdaVar) + lambdaVar + } + + val result = f(lambdaVars) + namedLambdas.foreach(v => currentLambdaVars.remove(v.variableName)) + result + } + + def getLambdaVar(name: String): ExprCode = { + currentLambdaVars.getOrElse(name, { + throw QueryExecutionErrors.lambdaVariableNotDefinedError(name) + }) + } + /** * Holding expressions' inlined mutable states like 
`MonotonicallyIncreasingID.count` as a * 2-tuple: java type, variable name. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 2a5a38e93706..72999e0725f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -81,8 +82,7 @@ case class NamedLambdaVariable( exprId: ExprId = NamedExpression.newExprId, value: AtomicReference[Any] = new AtomicReference()) extends LeafExpression - with NamedExpression - with CodegenFallback { + with NamedExpression { override def qualifier: Seq[String] = Seq.empty @@ -103,6 +103,14 @@ case class NamedLambdaVariable( override def simpleString(maxFields: Int): String = { s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}" } + + // We need to include the Expr ID in the Codegen variable name since several tests bypass + // `UnresolvedNamedLambdaVariable.freshVarName` + lazy val variableName = s"${name}_${exprId.id}" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.getLambdaVar(variableName) + } } /** @@ -114,7 +122,7 @@ case class LambdaFunction( function: Expression, arguments: 
Seq[NamedExpression], hidden: Boolean = false) - extends Expression with CodegenFallback { + extends Expression { override def children: Seq[Expression] = function +: arguments override def dataType: DataType = function.dataType @@ -132,6 +140,23 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val functionCode = function.genCode(ctx) + + if (nullable) { + ev.copy(code = code""" + |${functionCode.code} + |boolean ${ev.isNull} = ${functionCode.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; + """.stripMargin) + } else { + ev.copy(code = code""" + |${functionCode.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; + """.stripMargin, isNull = FalseLiteral) + } + } + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): LambdaFunction = copy( @@ -239,6 +264,53 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { val canonicalizedChildren = cleaned.children.map(_.canonicalized) withNewChildren(canonicalizedChildren) } + + + protected def assignAtomic(atomicRef: String, value: String, isNull: String = FalseLiteral, + nullable: Boolean = false) = { + if (nullable) { + s""" + if ($isNull) { + $atomicRef.set(null); + } else { + $atomicRef.set($value); + } + """ + } else { + s"$atomicRef.set($value);" + } + } + + protected def assignArrayElement(ctx: CodegenContext, arrayName: String, elementCode: ExprCode, + elementVar: NamedLambdaVariable, index: String): String = { + val elementType = elementVar.dataType + val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value) + val extractElement = CodeGenerator.getValue(arrayName, elementType, index) + val atomicAssign = assignAtomic(elementAtomic, elementCode.value, + elementCode.isNull, elementVar.nullable) + + if (elementVar.nullable) { + s""" + ${elementCode.value} = 
$extractElement; + ${elementCode.isNull} = $arrayName.isNullAt($index); + $atomicAssign + """ + } else { + s""" + ${elementCode.value} = $extractElement; + $atomicAssign + """ + } + } + + protected def assignIndex(ctx: CodegenContext, indexCode: ExprCode, + indexVar: NamedLambdaVariable, index: String): String = { + val indexAtomic = ctx.addReferenceObj(indexVar.variableName, indexVar.value) + s""" + ${indexCode.value} = $index; + ${assignAtomic(indexAtomic, indexCode.value)} + """ + } } /** @@ -284,6 +356,29 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr } } + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: String => String): ExprCode = { + val argumentGen = argument.genCode(ctx) + val resultCode = f(argumentGen.value) + + if (nullable) { + val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode) + ev.copy(code = code""" + |${argumentGen.code} + |boolean ${ev.isNull} = ${argumentGen.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """) + } else { + ev.copy(code = code""" + |${argumentGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """, isNull = FalseLiteral) + } + } } trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { @@ -312,7 +407,7 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { case class ArrayTransform( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction { override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) @@ -354,6 +449,49 @@ case class ArrayTransform( result } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, { lambdaExprs => + val elementCode = 
lambdaExprs.head + val indexCode = lambdaExprs.tail.headOption + + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") + val i = ctx.freshName("i") + + val initialization = CodeGenerator.createArrayData( + arrayData, dataType.elementType, numElements, s" $prettyName failed.") + + val functionCode = function.genCode(ctx) + + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) + val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") + + // Some expressions return internal buffers that we have to copy + val copy = if (CodeGenerator.isPrimitiveType(function.dataType)) { + s"${functionCode.value}" + } else { + s"InternalRow.copyValue(${functionCode.value})" + } + val resultNull = if (function.nullable) Some(functionCode.isNull.toString) else None + val resultAssignment = CodeGenerator.setArrayElement(arrayData, dataType.elementType, + i, copy, isNull = resultNull) + + s""" + |final int $numElements = ${arg}.numElements(); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | $varAssignments + | ${functionCode.code} + | $resultAssignment + |} + |${ev.value} = $arrayData; + """.stripMargin + }) + }) + } + override def nodeName: String = "transform" override protected def withNewChildrenInternal( @@ -581,7 +719,7 @@ case class MapFilter( case class ArrayFilter( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction { override def dataType: DataType = argument.dataType @@ -622,6 +760,67 @@ case class ArrayFilter( new GenericArrayData(buffer) } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, { lambdaExprs => + val elementCode = lambdaExprs.head + val indexCode = 
lambdaExprs.tail.headOption + + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val count = ctx.freshName("count") + val arrayTracker = ctx.freshName("arrayTracker") + val arrayData = ctx.freshName("arrayData") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + + val arrayType = dataType.asInstanceOf[ArrayType] + + val trackerInit = CodeGenerator.createArrayData( + arrayTracker, BooleanType, numElements, s" $prettyName failed.") + val resultInit = CodeGenerator.createArrayData( + arrayData, arrayType.elementType, count, s" $prettyName failed.") + + val functionCode = function.genCode(ctx) + + val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) + val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") + + val resultAssignment = CodeGenerator.setArrayElement(arrayTracker, BooleanType, + i, functionCode.value, isNull = None) + + val getTrackerValue = CodeGenerator.getValue(arrayTracker, BooleanType, i) + val copy = CodeGenerator.createArrayAssignment(arrayData, arrayType.elementType, arg, + j, i, arrayType.containsNull) + + s""" + |final int $numElements = ${arg}.numElements(); + |$trackerInit + |int $count = 0; + |for (int $i = 0; $i < $numElements; $i++) { + | $varAssignments + | ${functionCode.code} + | $resultAssignment + | if ((boolean)${functionCode.value}) { + | $count++; + | } + |} + | + |$resultInit + |int $j = 0; + |for (int $i = 0; $i < $numElements; $i++) { + | if ($getTrackerValue) { + | $copy + | $j++; + | } + |} + |${ev.value} = $arrayData; + """.stripMargin + }) + }) + } + override def nodeName: String = "filter" override protected def withNewChildrenInternal( @@ -653,7 +852,7 @@ case class ArrayExists( argument: Expression, function: Expression, followThreeValuedLogic: Boolean) - extends 
ArrayBasedSimpleHigherOrderFunction with CodegenFallback with Predicate { + extends ArrayBasedSimpleHigherOrderFunction with Predicate { def this(argument: Expression, function: Expression) = { this( @@ -706,6 +905,50 @@ case class ArrayExists( } } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar), { case Seq(elementCode) => + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val exists = ctx.freshName("exists") + val foundNull = ctx.freshName("foundNull") + val i = ctx.freshName("i") + + val functionCode = function.genCode(ctx) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val threeWayLogic = if (followThreeValuedLogic) TrueLiteral else FalseLiteral + + val nullCheck = if (nullable) { + s""" + if ($threeWayLogic && !$exists && $foundNull) { + ${ev.isNull} = true; + } + """ + } else { + "" + } + + s""" + |final int $numElements = ${arg}.numElements(); + |boolean $exists = false; + |boolean $foundNull = false; + |int $i = 0; + |while ($i < $numElements && !$exists) { + | $elementAssignment + | ${functionCode.code} + | if (${functionCode.isNull}) { + | $foundNull = true; + | } else if (${functionCode.value}) { + | $exists = true; + | } + | $i++; + |} + |$nullCheck + |${ev.value} = $exists; + """.stripMargin + }) + }) + } + override def nodeName: String = "exists" override protected def withNewChildrenInternal( @@ -740,7 +983,7 @@ object ArrayExists { case class ArrayForAll( argument: Expression, function: Expression) - extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback with Predicate { + extends ArrayBasedSimpleHigherOrderFunction with Predicate { override def nullable: Boolean = super.nullable || function.nullable @@ -785,6 +1028,49 @@ case class ArrayForAll( } } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar), { case Seq(elementCode) => 
+ nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val forall = ctx.freshName("forall") + val foundNull = ctx.freshName("foundNull") + val i = ctx.freshName("i") + + val functionCode = function.genCode(ctx) + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + + val nullCheck = if (nullable) { + s""" + if ($forall && $foundNull) { + ${ev.isNull} = true; + } + """ + } else { + "" + } + + s""" + |final int $numElements = ${arg}.numElements(); + |boolean $forall = true; + |boolean $foundNull = false; + |int $i = 0; + |while ($i < $numElements && $forall) { + | $elementAssignment + | ${functionCode.code} + | if (${functionCode.isNull}) { + | $foundNull = true; + | } else if (!${functionCode.value}) { + | $forall = false; + | } + | $i++; + |} + |$nullCheck + |${ev.value} = $forall; + """.stripMargin + }) + }) + } + override def nodeName: String = "forall" override protected def withNewChildrenInternal( @@ -816,7 +1102,7 @@ case class ArrayAggregate( zero: Expression, merge: Expression, finish: Expression) - extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] { + extends HigherOrderFunction with QuaternaryLike[Expression] { def this(argument: Expression, zero: Expression, merge: Expression) = { this(argument, zero, merge, LambdaFunction.identity) @@ -886,6 +1172,116 @@ case class ArrayAggregate( } } + protected def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: String => String): ExprCode = { + val argumentGen = argument.genCode(ctx) + val resultCode = f(argumentGen.value) + + if (nullable) { + val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode) + ev.copy(code = code""" + |${argumentGen.code} + |boolean ${ev.isNull} = ${argumentGen.isNull}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """) + } else { + ev.copy(code = code""" + |${argumentGen.code} + 
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """, isNull = FalseLiteral) + } + } + + protected def assignVar(varCode: ExprCode, value: String, isNull: String, + nullable: Boolean): String = { + if (nullable) { + s""" + ${varCode.value} = $value; + ${varCode.isNull} = $isNull; + """ + } else { + s""" + ${varCode.value} = $value; + """ + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar), { varCodes => + val Seq(elementCode, accForMergeCode, accForFinishCode) = varCodes + + nullSafeCodeGen(ctx, ev, arg => { + val numElements = ctx.freshName("numElements") + val i = ctx.freshName("i") + + val zeroCode = zero.genCode(ctx) + val mergeCode = merge.genCode(ctx) + val finishCode = finish.genCode(ctx) + + val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) + val mergeAtomic = ctx.addReferenceObj(accForMergeVar.variableName, + accForMergeVar.value) + val finishAtomic = ctx.addReferenceObj(accForFinishVar.variableName, + accForFinishVar.value) + + val mergeJavaType = CodeGenerator.javaType(accForMergeVar.dataType) + val finishJavaType = CodeGenerator.javaType(accForFinishVar.dataType) + + // Some expressions return internal buffers that we have to copy + val mergeCopy = if (CodeGenerator.isPrimitiveType(merge.dataType)) { + s"${mergeCode.value}" + } else { + s"($mergeJavaType)InternalRow.copyValue(${mergeCode.value})" + } + + val nullCheck = if (nullable) { + s"${ev.isNull} = ${finishCode.isNull};" + } else { + "" + } + + val initialAssignment = assignVar(accForMergeCode, zeroCode.value, zeroCode.isNull, + zero.nullable) + val initialAtomic = assignAtomic(mergeAtomic, accForMergeCode.value, + accForMergeCode.isNull, merge.nullable) + + val mergeAssignment = assignVar(accForMergeCode, mergeCopy, + mergeCode.isNull, merge.nullable) + val mergeAtomicAssignment = assignAtomic(mergeAtomic, 
accForMergeCode.value, + accForMergeCode.isNull, merge.nullable) + + val finishAssignment = assignVar(accForFinishCode, accForMergeCode.value, + accForMergeCode.isNull, merge.nullable) + val finishAtomicAssignment = assignAtomic(finishAtomic, accForFinishCode.value, + accForFinishCode.isNull, merge.nullable) + + s""" + |final int $numElements = ${arg}.numElements(); + |${zeroCode.code} + |$initialAssignment + |$initialAtomic + | + |for (int $i = 0; $i < $numElements; $i++) { + | $elementAssignment + | ${mergeCode.code} + | $mergeAssignment + | $mergeAtomicAssignment + |} + | + |$finishAssignment + |$finishAtomicAssignment + |${finishCode.code} + |${ev.value} = ${finishCode.value}; + |$nullCheck + """.stripMargin + }) + }) + } + override def nodeName: String = "aggregate" override def first: Expression = argument diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index fd687cc6d5c5..11b76c3cdd9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -447,6 +447,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE s"failed to match ${toSQLId(funcName)} at `addNewFunction`.") } + def lambdaVariableAlreadyDefinedError(name: String): Throwable = { + new IllegalArgumentException(s"Lambda variable $name cannot be redefined") + } + + def lambdaVariableNotDefinedError(name: String): Throwable = { + new IllegalArgumentException( + s"Lambda variable $name is not defined in the current codegen scope") + } + def cannotGenerateCodeForIncomparableTypeError( codeType: String, dataType: DataType): Throwable = { SparkException.internalError( From dfd63811bc801f38a435c0e025640b12d0f7f6bc Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 4 Apr 2023 16:43:33 -0400 Subject: [PATCH 2/8] 
Remove unnecessary variableName and clean up some formatting --- .../expressions/codegen/CodeGenerator.scala | 19 ++++++------ .../expressions/higherOrderFunctions.scala | 30 ++++++++----------- .../sql/errors/QueryExecutionErrors.scala | 8 ++--- 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index beeee7cbae7d..d0c0247ea299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -177,14 +177,14 @@ class CodegenContext extends Logging { /** * Holding a map of current lambda variables. */ - var currentLambdaVars: mutable.Map[String, ExprCode] = mutable.HashMap.empty + var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty def withLambdaVars(namedLambdas: Seq[NamedLambdaVariable], f: Seq[ExprCode] => ExprCode): ExprCode = { val lambdaVars = namedLambdas.map { namedLambda => - val name = namedLambda.variableName - if (currentLambdaVars.get(name).nonEmpty) { - throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(name) + val id = namedLambda.exprId.id + if (currentLambdaVars.get(id).nonEmpty) { + throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(id) } val isNull = if (namedLambda.nullable) { JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull")) @@ -193,19 +193,18 @@ class CodegenContext extends Logging { } val value = addMutableState(javaType(namedLambda.dataType), "lambdaValue") val lambdaVar = ExprCode(isNull, JavaCode.global(value, namedLambda.dataType)) - currentLambdaVars.put(name, lambdaVar) + currentLambdaVars.put(id, lambdaVar) lambdaVar } val result = f(lambdaVars) - namedLambdas.foreach(v => currentLambdaVars.remove(v.variableName)) + 
namedLambdas.map(_.exprId.id).foreach(currentLambdaVars.remove) result } - def getLambdaVar(name: String): ExprCode = { - currentLambdaVars.getOrElse(name, { - throw QueryExecutionErrors.lambdaVariableNotDefinedError(name) - }) + def getLambdaVar(id: Long): ExprCode = { + currentLambdaVars.getOrElse(id, + throw QueryExecutionErrors.lambdaVariableNotDefinedError(id)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 72999e0725f5..b90a24d7b793 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -104,12 +104,8 @@ case class NamedLambdaVariable( s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}" } - // We need to include the Expr ID in the Codegen variable name since several tests bypass - // `UnresolvedNamedLambdaVariable.freshVarName` - lazy val variableName = s"${name}_${exprId.id}" - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.getLambdaVar(variableName) + ctx.getLambdaVar(exprId.id) } } @@ -284,7 +280,7 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { protected def assignArrayElement(ctx: CodegenContext, arrayName: String, elementCode: ExprCode, elementVar: NamedLambdaVariable, index: String): String = { val elementType = elementVar.dataType - val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value) + val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) val extractElement = CodeGenerator.getValue(arrayName, elementType, index) val atomicAssign = assignAtomic(elementAtomic, elementCode.value, elementCode.isNull, elementVar.nullable) @@ -305,7 +301,7 @@ trait HigherOrderFunction extends Expression with 
ExpectsInputTypes { protected def assignIndex(ctx: CodegenContext, indexCode: ExprCode, indexVar: NamedLambdaVariable, index: String): String = { - val indexAtomic = ctx.addReferenceObj(indexVar.variableName, indexVar.value) + val indexAtomic = ctx.addReferenceObj(indexVar.name, indexVar.value) s""" ${indexCode.value} = $index; ${assignAtomic(indexAtomic, indexCode.value)} @@ -450,9 +446,9 @@ case class ArrayTransform( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.withLambdaVars(Seq(elementVar) ++ indexVar, { lambdaExprs => - val elementCode = lambdaExprs.head - val indexCode = lambdaExprs.tail.headOption + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, varCodes => { + val elementCode = varCodes.head + val indexCode = varCodes.tail.headOption nullSafeCodeGen(ctx, ev, arg => { val numElements = ctx.freshName("numElements") @@ -761,9 +757,9 @@ case class ArrayFilter( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.withLambdaVars(Seq(elementVar) ++ indexVar, { lambdaExprs => - val elementCode = lambdaExprs.head - val indexCode = lambdaExprs.tail.headOption + ctx.withLambdaVars(Seq(elementVar) ++ indexVar, varCodes => { + val elementCode = varCodes.head + val indexCode = varCodes.tail.headOption nullSafeCodeGen(ctx, ev, arg => { val numElements = ctx.freshName("numElements") @@ -782,7 +778,7 @@ case class ArrayFilter( val functionCode = function.genCode(ctx) - val elementAtomic = ctx.addReferenceObj(elementVar.variableName, elementVar.value) + val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") @@ -1211,7 +1207,7 @@ case class ArrayAggregate( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - 
ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar), { varCodes => + ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar), varCodes => { val Seq(elementCode, accForMergeCode, accForFinishCode) = varCodes nullSafeCodeGen(ctx, ev, arg => { @@ -1223,9 +1219,9 @@ case class ArrayAggregate( val finishCode = finish.genCode(ctx) val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) - val mergeAtomic = ctx.addReferenceObj(accForMergeVar.variableName, + val mergeAtomic = ctx.addReferenceObj(accForMergeVar.name, accForMergeVar.value) - val finishAtomic = ctx.addReferenceObj(accForFinishVar.variableName, + val finishAtomic = ctx.addReferenceObj(accForFinishVar.name, accForFinishVar.value) val mergeJavaType = CodeGenerator.javaType(accForMergeVar.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 11b76c3cdd9f..4ab76c0c684c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -447,13 +447,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE s"failed to match ${toSQLId(funcName)} at `addNewFunction`.") } - def lambdaVariableAlreadyDefinedError(name: String): Throwable = { - new IllegalArgumentException(s"Lambda variable $name cannot be redefined") + def lambdaVariableAlreadyDefinedError(id: Long): Throwable = { + new IllegalArgumentException(s"Lambda variable $id cannot be redefined") } - def lambdaVariableNotDefinedError(name: String): Throwable = { + def lambdaVariableNotDefinedError(id: Long): Throwable = { new IllegalArgumentException( - s"Lambda variable $name is not defined in the current codegen scope") + s"Lambda variable $id is not defined in the current codegen scope") } def 
cannotGenerateCodeForIncomparableTypeError( From 0c0e0378f25f6252987d96ecab6444b0cc13f363 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Sun, 21 May 2023 11:38:42 -0400 Subject: [PATCH 3/8] Remove unnecessary extra variable copies --- .../expressions/higherOrderFunctions.scala | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index b90a24d7b793..ddb97fb915c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -137,20 +137,7 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val functionCode = function.genCode(ctx) - - if (nullable) { - ev.copy(code = code""" - |${functionCode.code} - |boolean ${ev.isNull} = ${functionCode.isNull}; - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; - """.stripMargin) - } else { - ev.copy(code = code""" - |${functionCode.code} - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${functionCode.value}; - """.stripMargin, isNull = FalseLiteral) - } + function.genCode(ctx) } override protected def withNewChildrenInternal( From 40c7246ac6b9aab038dc09a23bd2179889425696 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Thu, 22 Jun 2023 20:31:11 -0400 Subject: [PATCH 4/8] Improve some styling --- .../expressions/codegen/CodeGenerator.scala | 16 ++++++------ .../expressions/higherOrderFunctions.scala | 25 ++++++++++++++----- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d0c0247ea299..d2568a5903c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -179,20 +179,21 @@ class CodegenContext extends Logging { */ var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty - def withLambdaVars(namedLambdas: Seq[NamedLambdaVariable], + def withLambdaVars( + namedLambdas: Seq[NamedLambdaVariable], f: Seq[ExprCode] => ExprCode): ExprCode = { - val lambdaVars = namedLambdas.map { namedLambda => - val id = namedLambda.exprId.id + val lambdaVars = namedLambdas.map { lambda => + val id = lambda.exprId.id if (currentLambdaVars.get(id).nonEmpty) { throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(id) } - val isNull = if (namedLambda.nullable) { + val isNull = if (lambda.nullable) { JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull")) } else { FalseLiteral } - val value = addMutableState(javaType(namedLambda.dataType), "lambdaValue") - val lambdaVar = ExprCode(isNull, JavaCode.global(value, namedLambda.dataType)) + val value = addMutableState(javaType(lambda.dataType), "lambdaValue") + val lambdaVar = ExprCode(isNull, JavaCode.global(value, lambda.dataType)) currentLambdaVars.put(id, lambdaVar) lambdaVar } @@ -203,7 +204,8 @@ class CodegenContext extends Logging { } def getLambdaVar(id: Long): ExprCode = { - currentLambdaVars.getOrElse(id, + currentLambdaVars.getOrElse( + id, throw QueryExecutionErrors.lambdaVariableNotDefinedError(id)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index ddb97fb915c4..4a4717ca4f3b 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -249,7 +249,10 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { } - protected def assignAtomic(atomicRef: String, value: String, isNull: String = FalseLiteral, + protected def assignAtomic( + atomicRef: String, + value: String, + isNull: String = FalseLiteral, nullable: Boolean = false) = { if (nullable) { s""" @@ -264,8 +267,12 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { } } - protected def assignArrayElement(ctx: CodegenContext, arrayName: String, elementCode: ExprCode, - elementVar: NamedLambdaVariable, index: String): String = { + protected def assignArrayElement( + ctx: CodegenContext, + arrayName: String, + elementCode: ExprCode, + elementVar: NamedLambdaVariable, + index: String): String = { val elementType = elementVar.dataType val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) val extractElement = CodeGenerator.getValue(arrayName, elementType, index) @@ -286,8 +293,11 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { } } - protected def assignIndex(ctx: CodegenContext, indexCode: ExprCode, - indexVar: NamedLambdaVariable, index: String): String = { + protected def assignIndex( + ctx: CodegenContext, + indexCode: ExprCode, + indexVar: NamedLambdaVariable, + index: String): String = { val indexAtomic = ctx.addReferenceObj(indexVar.name, indexVar.value) s""" ${indexCode.value} = $index; @@ -1179,7 +1189,10 @@ case class ArrayAggregate( } } - protected def assignVar(varCode: ExprCode, value: String, isNull: String, + protected def assignVar( + varCode: ExprCode, + value: String, + isNull: String, nullable: Boolean): String = { if (nullable) { s""" From a8ea40cee60482d732504e1c104c0d4e63707c58 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 1 Oct 2024 
07:13:29 -0400 Subject: [PATCH 5/8] Add tests for codegen fallback inside HOF --- .../HigherOrderFunctionsSuite.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index cc36cd73d6d7..bc608b7afecf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -149,6 +151,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val plusOne: Expression => Expression = x => x + 1 val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i + val plusOneFallback: Expression => Expression = x => CodegenFallbackExpr(x + 1) checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4)) checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5)) @@ -158,6 +161,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transform(transform(ai1, plusIndex), plusOne), Seq(2, null, 6)) checkEvaluation(transform(ain, plusOne), null) + checkEvaluation(transform(ai0, plusOneFallback), Seq(2, 3, 4)) + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = 
false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) @@ -277,6 +282,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isEven: Expression => Expression = x => x % 2 === 0 val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 } + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(filter(ai0, isEven), Seq(2)) checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3)) @@ -286,6 +292,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(filter(ain, isEven), null) checkEvaluation(filter(ain, isNullOrOdd), null) + checkEvaluation(filter(ai0, isEvenFallback), Seq(2)) + val as0 = Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) @@ -321,6 +329,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) for (followThreeValuedLogic <- Seq(false, true)) { withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key @@ -337,6 +346,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(exists(ain, isNullOrOdd), null) checkEvaluation(exists(ain, alwaysFalse), null) checkEvaluation(exists(ain, alwaysNull), null) + checkEvaluation(exists(ai0, isEvenFallback), true) } } @@ -383,6 +393,7 @@ class HigherOrderFunctionsSuite 
extends SparkFunSuite with ExpressionEvalHelper val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 val alwaysFalse: Expression => Expression = _ => Literal.FalseLiteral val alwaysNull: Expression => Expression = _ => Literal(null, BooleanType) + val isEvenFallback: Expression => Expression = x => CodegenFallbackExpr(x % 2 === 0) checkEvaluation(forall(ai0, isEven), true) checkEvaluation(forall(ai0, isNullOrOdd), false) @@ -401,6 +412,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(forall(ain, alwaysFalse), null) checkEvaluation(forall(ain, alwaysNull), null) + checkEvaluation(forall(ai0, isEvenFallback), true) + val as0 = Literal.create(Seq("a0", "a1", "a2", "a3"), ArrayType(StringType, containsNull = false)) val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true)) @@ -886,3 +899,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ))) } } + +case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback { + override def nullable: Boolean = child.nullable + override def dataType: DataType = child.dataType + override lazy val resolved = child.resolved + override def eval(input: InternalRow): Any = child.eval(input) + override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpr = + copy(child = newChild) +} From d72a5a099897dbd0f19a376394ed0a07b22d9f88 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Fri, 14 Mar 2025 18:28:26 +0000 Subject: [PATCH 6/8] Small cleanup --- .../expressions/higherOrderFunctions.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 4a4717ca4f3b..9222b585914e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -472,7 +472,7 @@ case class ArrayTransform( i, copy, isNull = resultNull) s""" - |final int $numElements = ${arg}.numElements(); + |final int $numElements = $arg.numElements(); |$initialization |for (int $i = 0; $i < $numElements; $i++) { | $varAssignments @@ -775,7 +775,6 @@ case class ArrayFilter( val functionCode = function.genCode(ctx) - val elementAtomic = ctx.addReferenceObj(elementVar.name, elementVar.value) val elementAssignment = assignArrayElement(ctx, arg, elementCode, elementVar, i) val indexAssignment = indexCode.map(c => assignIndex(ctx, c, indexVar.get, i)) val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n") @@ -787,8 +786,14 @@ case class ArrayFilter( val copy = CodeGenerator.createArrayAssignment(arrayData, arrayType.elementType, arg, j, i, arrayType.containsNull) + + // This takes two passes to avoid evaluating the predicate multiple times. + // The first pass evaluates each element in the array, tracks how many elements + // returned true, and tracks the result of each element in a boolean array `arrayTracker`. + // The second pass copies the matching elements from the original array into the new + // array, which is sized using the count from the first pass.
+ s""" - |final int $numElements = ${arg}.numElements(); + |final int $numElements = $arg.numElements(); |$trackerInit |int $count = 0; |for (int $i = 0; $i < $numElements; $i++) { @@ -1191,17 +1196,21 @@ case class ArrayAggregate( protected def assignVar( varCode: ExprCode, + atomicVar: String, value: String, isNull: String, nullable: Boolean): String = { + val atomicAssign = assignAtomic(atomicVar, value, isNull, nullable) if (nullable) { s""" ${varCode.value} = $value; ${varCode.isNull} = $isNull; + $atomicAssign """ } else { s""" ${varCode.value} = $value; + $atomicAssign """ } } @@ -1240,36 +1249,27 @@ case class ArrayAggregate( "" } - val initialAssignment = assignVar(accForMergeCode, zeroCode.value, zeroCode.isNull, - zero.nullable) - val initialAtomic = assignAtomic(mergeAtomic, accForMergeCode.value, - accForMergeCode.isNull, merge.nullable) + val initialAssignment = assignVar(accForMergeCode, mergeAtomic, zeroCode.value, + zeroCode.isNull, zero.nullable) - val mergeAssignment = assignVar(accForMergeCode, mergeCopy, + val mergeAssignment = assignVar(accForMergeCode, mergeAtomic, mergeCopy, mergeCode.isNull, merge.nullable) - val mergeAtomicAssignment = assignAtomic(mergeAtomic, accForMergeCode.value, - accForMergeCode.isNull, merge.nullable) - val finishAssignment = assignVar(accForFinishCode, accForMergeCode.value, + val finishAssignment = assignVar(accForFinishCode, finishAtomic, accForMergeCode.value, accForMergeCode.isNull, merge.nullable) - val finishAtomicAssignment = assignAtomic(finishAtomic, accForFinishCode.value, - accForFinishCode.isNull, merge.nullable) s""" |final int $numElements = ${arg}.numElements(); |${zeroCode.code} |$initialAssignment - |$initialAtomic | |for (int $i = 0; $i < $numElements; $i++) { | $elementAssignment | ${mergeCode.code} | $mergeAssignment - | $mergeAtomicAssignment |} | |$finishAssignment - |$finishAtomicAssignment |${finishCode.code} |${ev.value} = ${finishCode.value}; |$nullCheck From 
0b448ecc97ec60dfb15fe2678dff1d31451152f6 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Wed, 19 Mar 2025 07:13:54 -0400 Subject: [PATCH 7/8] Add benchmark --- .../HigherOrderFunctionsBenchmark.scala | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala new file mode 100644 index 000000000000..837f530d32b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.functions._ + +/** + * Synthetic benchmark for higher order functions. + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/Test/runMain " + * 3. 
generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain " + * Results will be written to "benchmarks/HigherOrderFunctionsBenchmark-results.txt". + * }}} + */ +object HigherOrderFunctionsBenchmark extends SqlBasedBenchmark { + private val N = 100_000_00 + private val M = 10 + + private val df = spark.range(N).select(array(col("id"), col("id"), col("id")).alias("arr")) + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Higher order functions") { + var benchmark = new Benchmark("transform", N, output = output) + benchmark.addCase("codegen", M) { _ => + df.select(transform(col("arr"), x => x + 1)).noop() + } + benchmark.addCase("interpreted", M) { _ => + withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { + df.select(transform(col("arr"), x => x + 1)).noop() + } + } + benchmark.run() + + benchmark = new Benchmark("filter", N, output = output) + benchmark.addCase("codegen", M) { _ => + df.select(filter(col("arr"), x => x > 1)).noop() + } + benchmark.addCase("interpreted", M) { _ => + withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { + df.select(filter(col("arr"), x => x > 1)).noop() + } + } + benchmark.run() + + benchmark = new Benchmark("forall - fast", N, output = output) + benchmark.addCase("codegen", M) { _ => + df.select(forall(col("arr"), x => x < 0)).noop() + } + benchmark.addCase("interpreted", M) { _ => + withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { + df.select(forall(col("arr"), x => x < 0)).noop() + } + } + benchmark.run() + + benchmark = new Benchmark("forall - slow", N, output = output) + benchmark.addCase("codegen", M) { _ => + df.select(forall(col("arr"), x => x >= 0)).noop() + } + benchmark.addCase("interpreted", M) { _ => + withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { + df.select(forall(col("arr"), x => x >= 0)).noop() + } + } + benchmark.run() + + benchmark = new Benchmark("exists - fast", N, output = output) + 
benchmark.addCase("codegen", M) { _ => + df.select(exists(col("arr"), x => x >= 0)).noop() + } + benchmark.addCase("interpreted", M) { _ => + withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { + df.select(exists(col("arr"), x => x >= 0)).noop() + } + } + benchmark.run() + + benchmark = new Benchmark("exists - slow", N, output = output) + benchmark.addCase("codegen", M) { _ => + df.select(exists(col("arr"), x => x < 0)).noop() + } + benchmark.addCase("interpreted", M) { _ => + withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { + df.select(exists(col("arr"), x => x < 0)).noop() + } + } + benchmark.run() + + benchmark = new Benchmark("aggregate", N, output = output) + benchmark.addCase("codegen", M) { _ => + df.select(aggregate(col("arr"), lit(0L), (acc, x) => acc + x)).noop() + } + benchmark.addCase("interpreted", M) { _ => + withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { + df.select(aggregate(col("arr"), lit(0L), (acc, x) => acc + x)).noop() + } + } + benchmark.run() + } + } +} From 35cd592afd93bf065f3c869ff5523a311c937b82 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 1 Apr 2025 15:10:08 +0000 Subject: [PATCH 8/8] Simplify benchmark --- .../HigherOrderFunctionsBenchmark.scala | 94 +++++-------------- 1 file changed, 23 insertions(+), 71 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala index 837f530d32b9..960607625559 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HigherOrderFunctionsBenchmark.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.Column +import
org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf /** * Synthetic benchmark for higher order functions. @@ -40,82 +43,31 @@ object HigherOrderFunctionsBenchmark extends SqlBasedBenchmark { override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Higher order functions") { - var benchmark = new Benchmark("transform", N, output = output) - benchmark.addCase("codegen", M) { _ => - df.select(transform(col("arr"), x => x + 1)).noop() - } - benchmark.addCase("interpreted", M) { _ => - withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { - df.select(transform(col("arr"), x => x + 1)).noop() - } - } - benchmark.run() - - benchmark = new Benchmark("filter", N, output = output) - benchmark.addCase("codegen", M) { _ => - df.select(filter(col("arr"), x => x > 1)).noop() - } - benchmark.addCase("interpreted", M) { _ => - withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { - df.select(filter(col("arr"), x => x > 1)).noop() - } - } - benchmark.run() - - benchmark = new Benchmark("forall - fast", N, output = output) - benchmark.addCase("codegen", M) { _ => - df.select(forall(col("arr"), x => x < 0)).noop() - } - benchmark.addCase("interpreted", M) { _ => - withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { - df.select(forall(col("arr"), x => x < 0)).noop() - } - } - benchmark.run() - - benchmark = new Benchmark("forall - slow", N, output = output) - benchmark.addCase("codegen", M) { _ => - df.select(forall(col("arr"), x => x >= 0)).noop() - } - benchmark.addCase("interpreted", M) { _ => - withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { - df.select(forall(col("arr"), x => x >= 0)).noop() + def benchFunction(name: String, col: Column) = { + var benchmark = new Benchmark(name, N, output = output) + benchmark.addCase("codegen", M) { _ => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> + 
CodegenObjectFactoryMode.CODEGEN_ONLY.toString()) { + df.select(col).noop() + } } - } - benchmark.run() - benchmark = new Benchmark("exists - fast", N, output = output) - benchmark.addCase("codegen", M) { _ => - df.select(exists(col("arr"), x => x >= 0)).noop() - } - benchmark.addCase("interpreted", M) { _ => - withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { - df.select(exists(col("arr"), x => x >= 0)).noop() + benchmark.addCase("interpreted", M) { _ => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> + CodegenObjectFactoryMode.NO_CODEGEN.toString()) { + df.select(col).noop() + } } + benchmark.run() } - benchmark.run() - benchmark = new Benchmark("exists - slow", N, output = output) - benchmark.addCase("codegen", M) { _ => - df.select(exists(col("arr"), x => x < 0)).noop() - } - benchmark.addCase("interpreted", M) { _ => - withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { - df.select(exists(col("arr"), x => x < 0)).noop() - } - } - benchmark.run() - - benchmark = new Benchmark("aggregate", N, output = output) - benchmark.addCase("codegen", M) { _ => - df.select(aggregate(col("arr"), lit(0L), (acc, x) => acc + x)).noop() - } - benchmark.addCase("interpreted", M) { _ => - withSQLConf("spark.sql.codegen.factoryMode" -> "NO_CODEGEN") { - df.select(aggregate(col("arr"), lit(0L), (acc, x) => acc + x)).noop() - } - } - benchmark.run() + benchFunction("transform", transform(col("arr"), x => x + 1)) + benchFunction("filter", filter(col("arr"), x => x > 1)) + benchFunction("forall - fast", forall(col("arr"), x => x < 0)) + benchFunction("forall - slow", forall(col("arr"), x => x >= 0)) + benchFunction("exists - fast", exists(col("arr"), x => x >= 0)) + benchFunction("exists - slow", exists(col("arr"), x => x < 0)) + benchFunction("aggregate", aggregate(col("arr"), lit(0L), (acc, x) => acc + x)) } } }