From 63accf804059a5b6ea9ad359b980472024f3b4fd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 10 Nov 2016 04:57:35 +0000 Subject: [PATCH 1/3] Evaluate common subexpression like lazy variable with a function approach. --- .../expressions/codegen/CodeGenerator.scala | 100 +++++++++++++++++- .../aggregate/HashAggregateExec.scala | 18 ++-- 2 files changed, 106 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9c3c6d3b2a7f..dad39c4e4960 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -67,7 +67,8 @@ case class SubExprEliminationState(isNull: String, value: String) /** * Codes and common subexpressions mapping used for subexpression elimination. * - * @param codes Strings representing the codes that evaluate common subexpressions. + * @param codes Strings representing the codes that reset the initialization status of + * common subexpression evaluation. * @param states Foreach expression that is participating in subexpression elimination, * the state to use. */ @@ -680,6 +681,47 @@ class CodegenContext { genCodes } + /** + * A private helper function used to construct the parameter list for subexpression elimination + * evaluation functions. + * + * @param expression The subexpression to evaluate. + * @param caller True to build the argument list for a call site, false for a declaration. 
+ */ + private def genFunctionParamsListForSubExprEliminate( + expression: Expression, + caller: Boolean): String = { + val boundRefs = expression.collect { + case b: BoundReference => b + }.distinct + if (currentVars == null) { + if (caller) INPUT_ROW else s"InternalRow $INPUT_ROW" + } else { + val boundRefsInCurrentVars = boundRefs.filter(b => currentVars(b.ordinal) != null) + val currentVarsParams = boundRefsInCurrentVars.map { bound => + val paramType = javaType(bound.dataType) + val variable = currentVars(bound.ordinal).value + val isNull = currentVars(bound.ordinal).isNull + if (caller) { + if (isNull == "false") variable else s"$variable, $isNull" + } else { + if (isNull == "false") { + s"$paramType $variable" + } else { + s"$paramType $variable, boolean $isNull" + } + } + } + + if (boundRefsInCurrentVars.size == boundRefs.size) { + currentVarsParams.mkString(", ") + } else { + val rowParam = if (caller) INPUT_ROW else s"InternalRow $INPUT_ROW" + (Seq(rowParam) ++ currentVarsParams).mkString(", ") + } + } + } + /** * Checks and sets up the state and codegen for subexpression elimination. This finds the * common subexpressions, generates the code snippets that evaluate those expressions and @@ -700,11 +742,63 @@ class CodegenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) val codes = commonExprs.map { e => val expr = e.head + val fnName = freshName("evalSubExpr") + val isNull = s"${fnName}IsNull" + val value = s"${fnName}Value" + val isInitialized = s"${fnName}IsInitialized" + + val functionParams = genFunctionParamsListForSubExprEliminate(expr, false) + val callerParams = genFunctionParamsListForSubExprEliminate(expr, true) + // Generate the code for this expression tree. 
val code = expr.genCode(this) - val state = SubExprEliminationState(code.isNull, code.value) + val returnType = javaType(expr.dataType) + val fn = + s""" + |private void $fnName($functionParams) { + | ${code.code.trim} + | $isNull = ${code.isNull}; + | $value = ${code.value}; + | $isInitialized = true; + |} + """.stripMargin + + val valueFnName = s"${fnName}ForValue" + val valueFn = + s""" + |private $returnType $valueFnName($functionParams) { + | if (!$isInitialized) { + | $fnName($callerParams); + | } + | return $value; + |} + """.stripMargin + + val isNullFnName = s"${fnName}ForIsNull" + val isNullFn = + s""" + |private boolean $isNullFnName($functionParams) { + | if (!$isInitialized) { + | $fnName($callerParams); + | } + | return $isNull; + |} + """.stripMargin + + addNewFunction(fnName, fn) + addNewFunction(valueFnName, valueFn) + addNewFunction(isNullFnName, isNullFn) + + addMutableState("boolean", isNull, s"$isNull = false;") + addMutableState("boolean", isInitialized, s"$isInitialized = false;") + addMutableState(returnType, value, s"$value = ${defaultValue(expr.dataType)};") + + val state = SubExprEliminationState( + isNull = s"$isNullFnName($callerParams)", + value = s"$valueFnName($callerParams)") + e.foreach(subExprEliminationExprs.put(_, state)) - code.code.trim + s"$isInitialized = false;" } SubExprCodes(codes, subExprEliminationExprs.toMap) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4529ed067e56..7e34e8b6c47c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -253,7 +253,7 @@ case class HashAggregateExec( ctx.currentVars = bufVars ++ input val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") + val resetSubExprEvaluation = subExprs.codes.mkString("\n") val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } @@ -266,8 +266,8 @@ case class HashAggregateExec( } s""" | // do aggregate - | // common sub-expressions - | $effectiveCodes + | // reset the initialization status for common sub-expressions + | $resetSubExprEvaluation | // evaluate aggregate function | ${evaluateVariables(aggVals)} | // update aggregation buffer @@ -758,7 +758,7 @@ case class HashAggregateExec( ctx.INPUT_ROW = fastRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") + val resetSubExprEvaluation = subExprs.codes.mkString("\n") val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } @@ -768,8 +768,8 @@ case class HashAggregateExec( } Option( s""" - |// common sub-expressions - |$effectiveCodes + |// reset the initialization status for common sub-expressions + |$resetSubExprEvaluation |// evaluate aggregate function |${evaluateVariables(fastRowEvals)} |// update fast row @@ -814,7 +814,7 @@ case class HashAggregateExec( ctx.INPUT_ROW = unsafeRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") + val resetSubExprEvaluation = subExprs.codes.mkString("\n") val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } @@ -823,8 +823,8 @@ case class HashAggregateExec( ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) } s""" - |// common sub-expressions - 
|$effectiveCodes + |// reset the initialization status for common sub-expressions + |$resetSubExprEvaluation |// evaluate aggregate function |${evaluateVariables(unsafeRowBufferEvals)} |// update unsafe row buffer From 3274e91e670e2218bf4a1d9da0c19c25b09b0881 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 10 Nov 2016 07:41:32 +0000 Subject: [PATCH 2/3] For non-wholestage codegen subexpression elimination. --- .../expressions/codegen/CodeGenerator.scala | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index dad39c4e4960..1cdf1ec39a30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -820,19 +820,46 @@ class CodegenContext { val fnName = freshName("evalExpr") val isNull = s"${fnName}IsNull" val value = s"${fnName}Value" + val isInitialized = s"${fnName}IsInitialized" // Generate the code for this expression tree and wrap it in a function. 
val code = expr.genCode(this) + val returnType = javaType(expr.dataType) val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { | ${code.code.trim} | $isNull = ${code.isNull}; | $value = ${code.value}; + | $isInitialized = true; + |} + """.stripMargin + + val valueFnName = s"${fnName}ForValue" + val valueFn = + s""" + |private $returnType $valueFnName(InternalRow $INPUT_ROW) { + | if (!$isInitialized) { + | $fnName($INPUT_ROW); + | } + | return $value; + |} + """.stripMargin + + val isNullFnName = s"${fnName}ForIsNull" + val isNullFn = + s""" + |private boolean $isNullFnName(InternalRow $INPUT_ROW) { + | if (!$isInitialized) { + | $fnName($INPUT_ROW); + | } + | return $isNull; |} """.stripMargin addNewFunction(fnName, fn) + addNewFunction(valueFnName, valueFn) + addNewFunction(isNullFnName, isNullFn) // Add a state and a mapping of the common subexpressions that are associate with this // state. Adding this expression to subExprEliminationExprMap means it will call `fn` @@ -851,11 +878,14 @@ class CodegenContext { // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. addMutableState("boolean", isNull, s"$isNull = false;") - addMutableState(javaType(expr.dataType), value, + addMutableState("boolean", isInitialized, s"$isInitialized = false;") + addMutableState(returnType, value, s"$value = ${defaultValue(expr.dataType)};") - subexprFunctions += s"$fnName($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + subexprFunctions += s"$isInitialized = false;" + val state = SubExprEliminationState( + isNull = s"$isNullFnName($INPUT_ROW)", + value = s"$valueFnName($INPUT_ROW)") e.foreach(subExprEliminationExprs.put(_, state)) } } From 3ae70e8bb400c178049a1bba0523ce942ced8a8a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 10 Nov 2016 08:41:54 +0000 Subject: [PATCH 3/3] Add test. 
---
 .../SubexpressionEliminationSuite.scala       | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 1e39b24fe877..351cc6c630b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -17,7 +17,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
 class SubexpressionEliminationSuite extends SparkFunSuite {
   test("Semantic equals and hash") {
@@ -171,4 +173,19 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
     assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3)  // add, two, explode
   }
+
+  test("SPARK-18395: evaluate subexpressions like lazy variables") {
+    val row = new GenericInternalRow(Seq(null, new java.lang.Integer(10)).toArray[Any])
+    val bound = BoundReference(0, IntegerType, true)
+    val assertNotNull = AssertNotNull(bound, Seq.empty[String])
+    val expr = If(
+      IsNull(bound),
+      Literal(1),
+      Add(assertNotNull, Add(assertNotNull, Literal(1))))
+    val schema = StructType(Seq(StructField("int", IntegerType))).toAttributes
+    val projection =
+      GenerateMutableProjection.generate(Seq(expr), schema, useSubexprElimination = true)
+    // Field 0 is null: If yields 1 and the shared AssertNotNull must never be evaluated.
+    assert(projection(row).getInt(0) == 1)
+  }
 }