diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f8f868b59b967..4dc3d48eb6f61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -67,7 +67,8 @@ case class SubExprEliminationState(isNull: String, value: String)
 /**
  * Codes and common subexpressions mapping used for subexpression elimination.
  *
- * @param codes Strings representing the codes that evaluate common subexpressions.
+ * @param codes Strings representing the codes that reset the initialization status of
+ *              common subexpression evaluation.
  * @param states Foreach expression that is participating in subexpression elimination,
  *               the state to use.
  */
@@ -713,6 +714,47 @@ class CodegenContext {
     genCodes
   }
 
+  /**
+   * A private helper function used to construct the parameter list for subexpression elimination
+   * evaluation functions.
+   *
+   * @param expression The subexpression to evaluate.
+   * @param caller If true, build call-site arguments; if false, build the typed declaration.
+   */
+  private def genFunctionParamsListForSubExprEliminate(
+      expression: Expression,
+      caller: Boolean): String = {
+    val boundRefs = expression.collect {
+      case b: BoundReference => b
+    }.distinct
+    if (currentVars == null) {
+      if (caller) INPUT_ROW else s"InternalRow $INPUT_ROW"
+    } else {
+      val boundRefsInCurrentVars = boundRefs.filter(b => currentVars(b.ordinal) != null)
+      val currentVarsParams = boundRefsInCurrentVars.map { bound =>
+        val paramType = javaType(bound.dataType)
+        val variable = currentVars(bound.ordinal).value
+        val isNull = currentVars(bound.ordinal).isNull
+        if (caller) {
+          if (isNull == "false") variable else s"$variable, $isNull"
+        } else {
+          if (isNull == "false") {
+            s"$paramType $variable"
+          } else {
+            s"$paramType $variable, boolean $isNull"
+          }
+        }
+      }
+
+      if (boundRefsInCurrentVars.size == boundRefs.size) {
+        currentVarsParams.mkString(", ")
+      } else {
+        val rowParam = if (caller) INPUT_ROW else s"InternalRow $INPUT_ROW"
+        (Seq(rowParam) ++ currentVarsParams).mkString(", ")
+      }
+    }
+  }
+
   /**
    * Checks and sets up the state and codegen for subexpression elimination. This finds the
    * common subexpressions, generates the code snippets that evaluate those expressions and
@@ -733,11 +775,63 @@ class CodegenContext {
     val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
     val codes = commonExprs.map { e =>
       val expr = e.head
+      val fnName = freshName("evalSubExpr")
+      val isNull = s"${fnName}IsNull"
+      val value = s"${fnName}Value"
+      val isInitialized = s"${fnName}IsInitialized"
+
+      val functionParams = genFunctionParamsListForSubExprEliminate(expr, false)
+      val callerParams = genFunctionParamsListForSubExprEliminate(expr, true)
+
       // Generate the code for this expression tree.
       val code = expr.genCode(this)
-      val state = SubExprEliminationState(code.isNull, code.value)
+      val returnType = javaType(expr.dataType)
+      val fn =
+        s"""
+           |private void $fnName($functionParams) {
+           |  ${code.code.trim}
+           |  $isNull = ${code.isNull};
+           |  $value = ${code.value};
+           |  $isInitialized = true;
+           |}
+           """.stripMargin
+
+      val valueFnName = s"${fnName}ForValue"
+      val valueFn =
+        s"""
+           |private $returnType $valueFnName($functionParams) {
+           |  if (!$isInitialized) {
+           |    $fnName($callerParams);
+           |  }
+           |  return $value;
+           |}
+           """.stripMargin
+
+      val isNullFnName = s"${fnName}ForIsNull"
+      val isNullFn =
+        s"""
+           |private boolean $isNullFnName($functionParams) {
+           |  if (!$isInitialized) {
+           |    $fnName($callerParams);
+           |  }
+           |  return $isNull;
+           |}
+           """.stripMargin
+
+      addNewFunction(fnName, fn)
+      addNewFunction(valueFnName, valueFn)
+      addNewFunction(isNullFnName, isNullFn)
+
+      addMutableState("boolean", isNull, s"$isNull = false;")
+      addMutableState("boolean", isInitialized, s"$isInitialized = false;")
+      addMutableState(returnType, value, s"$value = ${defaultValue(expr.dataType)};")
+
+      val state = SubExprEliminationState(
+        isNull = s"$isNullFnName($callerParams)",
+        value = s"$valueFnName($callerParams)")
+
       e.foreach(subExprEliminationExprs.put(_, state))
-      code.code.trim
+      s"$isInitialized = false;"
     }
     SubExprCodes(codes, subExprEliminationExprs.toMap)
   }
@@ -759,19 +853,46 @@ class CodegenContext {
       val fnName = freshName("evalExpr")
       val isNull = s"${fnName}IsNull"
       val value = s"${fnName}Value"
+      val isInitialized = s"${fnName}IsInitialized"
 
       // Generate the code for this expression tree and wrap it in a function.
       val code = expr.genCode(this)
+      val returnType = javaType(expr.dataType)
       val fn =
         s"""
            |private void $fnName(InternalRow $INPUT_ROW) {
            |  ${code.code.trim}
            |  $isNull = ${code.isNull};
            |  $value = ${code.value};
+           |  $isInitialized = true;
+           |}
+           """.stripMargin
+
+      val valueFnName = s"${fnName}ForValue"
+      val valueFn =
+        s"""
+           |private $returnType $valueFnName(InternalRow $INPUT_ROW) {
+           |  if (!$isInitialized) {
+           |    $fnName($INPUT_ROW);
+           |  }
+           |  return $value;
+           |}
+           """.stripMargin
+
+      val isNullFnName = s"${fnName}ForIsNull"
+      val isNullFn =
+        s"""
+           |private boolean $isNullFnName(InternalRow $INPUT_ROW) {
+           |  if (!$isInitialized) {
+           |    $fnName($INPUT_ROW);
+           |  }
+           |  return $isNull;
            |}
            """.stripMargin
 
       addNewFunction(fnName, fn)
+      addNewFunction(valueFnName, valueFn)
+      addNewFunction(isNullFnName, isNullFn)
 
       // Add a state and a mapping of the common subexpressions that are associate with this
       // state. Adding this expression to subExprEliminationExprMap means it will call `fn`
@@ -790,11 +911,14 @@ class CodegenContext {
       // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
       // at least two nodes) as the cost of doing it is expected to be low.
       addMutableState("boolean", isNull, s"$isNull = false;")
-      addMutableState(javaType(expr.dataType), value,
+      addMutableState("boolean", isInitialized, s"$isInitialized = false;")
+      addMutableState(returnType, value,
         s"$value = ${defaultValue(expr.dataType)};")
 
-      subexprFunctions += s"$fnName($INPUT_ROW);"
-      val state = SubExprEliminationState(isNull, value)
+      subexprFunctions += s"$isInitialized = false;"
+      val state = SubExprEliminationState(
+        isNull = s"$isNullFnName($INPUT_ROW)",
+        value = s"$valueFnName($INPUT_ROW)")
       e.foreach(subExprEliminationExprs.put(_, state))
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 2db2a043e546a..630a430fb5514 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -17,8 +17,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GenerateMutableProjection}
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
+import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}
 
 class SubexpressionEliminationSuite extends SparkFunSuite {
   test("Semantic equals and hash") {
@@ -172,6 +173,21 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
     assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3)  // add, two, explode
   }
+
+  test("SPARK-18395: evaluate subexpressions like lazy variables") {
+    val row = new GenericInternalRow(Seq(null, java.lang.Integer.valueOf(10)).toArray[Any])
+    val bound = BoundReference(0, IntegerType, true)
+    val add = Add(bound, Literal(1))
+    val assertNotNull = AssertNotNull(bound, Seq.empty[String])
+    val expr = If(
+      IsNull(bound),
+      Literal(1),
+      Add(assertNotNull, Add(assertNotNull, Literal(1))))
+    val schema = StructType(Seq(StructField("int", IntegerType))).toAttributes
+    val projection =
+      GenerateMutableProjection.generate(Seq(expr), schema, useSubexprElimination = true)
+    assert(projection(row).getInt(0) == 1)
+  }
 }
 
 case class CodegenFallbackExpression(child: Expression)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4529ed067e565..7e34e8b6c47cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -253,7 +253,7 @@ case class HashAggregateExec(
     ctx.currentVars = bufVars ++ input
     val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
-    val effectiveCodes = subExprs.codes.mkString("\n")
+    val resetSubExprEvaluation = subExprs.codes.mkString("\n")
     val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
       boundUpdateExpr.map(_.genCode(ctx))
     }
@@ -266,8 +266,8 @@ case class HashAggregateExec(
     }
     s"""
        | // do aggregate
-       | // common sub-expressions
-       | $effectiveCodes
+       | // reset the initialization status for common sub-expressions
+       | $resetSubExprEvaluation
        | // evaluate aggregate function
        | ${evaluateVariables(aggVals)}
        | // update aggregation buffer
@@ -758,7 +758,7 @@ case class HashAggregateExec(
         ctx.INPUT_ROW = fastRowBuffer
         val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
         val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
-        val effectiveCodes = subExprs.codes.mkString("\n")
+        val resetSubExprEvaluation = subExprs.codes.mkString("\n")
         val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
           boundUpdateExpr.map(_.genCode(ctx))
         }
@@ -768,8 +768,8 @@ case class HashAggregateExec(
         }
         Option(
           s"""
-             |// common sub-expressions
-             |$effectiveCodes
+             |// reset the initialization status for common sub-expressions
+             |$resetSubExprEvaluation
              |// evaluate aggregate function
              |${evaluateVariables(fastRowEvals)}
              |// update fast row
@@ -814,7 +814,7 @@ case class HashAggregateExec(
       ctx.INPUT_ROW = unsafeRowBuffer
       val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
-      val effectiveCodes = subExprs.codes.mkString("\n")
+      val resetSubExprEvaluation = subExprs.codes.mkString("\n")
       val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
         boundUpdateExpr.map(_.genCode(ctx))
       }
@@ -823,8 +823,8 @@ case class HashAggregateExec(
         ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
       }
       s"""
-         |// common sub-expressions
-         |$effectiveCodes
+         |// reset the initialization status for common sub-expressions
+         |$resetSubExprEvaluation
          |// evaluate aggregate function
          |${evaluateVariables(unsafeRowBufferEvals)}
          |// update unsafe row buffer