Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ case class SubExprEliminationState(isNull: String, value: String)
/**
* Codes and common subexpressions mapping used for subexpression elimination.
*
* @param codes Strings representing the codes that evaluate common subexpressions.
* @param codes Strings representing the codes that reset the initialization status of
* common subexpression evaluation.
* @param states Foreach expression that is participating in subexpression elimination,
* the state to use.
*/
Expand Down Expand Up @@ -713,6 +714,47 @@ class CodegenContext {
genCodes
}

/**
* A private helper function used to construct the parameter list for subexpression elimination
* evaluation functions
*
* @param expression The subexpression to evaluate.
* @param caller Indicating to construct parameter list for function caller.
*/
private def genFunctionParamsListForSubExprEliminate(
expression: Expression,
caller: Boolean): String = {
val boundRefs = expression.collect {
case b: BoundReference => b
}.distinct
if (currentVars == null) {
if (caller) INPUT_ROW else s"InternalRow $INPUT_ROW"
} else {
val boundRefsInCurrentVars = boundRefs.filter(b => currentVars(b.ordinal) != null)
val currentVarsParams = boundRefsInCurrentVars.map { bound =>
val paramType = javaType(bound.dataType)
val variable = currentVars(bound.ordinal).value
val isNull = currentVars(bound.ordinal).isNull
if (caller) {
if (isNull == "false") variable else s"$variable, $isNull"
} else {
if (isNull == "false") {
s"$paramType $variable"
} else {
s"$paramType $variable, boolean $isNull"
}
}
}

if (boundRefsInCurrentVars.size == boundRefs.size) {
currentVarsParams.mkString(", ")
} else {
val rowParam = if (caller) INPUT_ROW else s"InternalRow $INPUT_ROW"
(Seq(rowParam) ++ currentVarsParams).mkString(", ")
}
}
}

/**
* Checks and sets up the state and codegen for subexpression elimination. This finds the
* common subexpressions, generates the code snippets that evaluate those expressions and
Expand All @@ -733,11 +775,63 @@ class CodegenContext {
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
val codes = commonExprs.map { e =>
val expr = e.head
val fnName = freshName("evalSubExpr")
val isNull = s"${fnName}IsNull"
val value = s"${fnName}Value"
val isInitialized = s"${fnName}IsInitialized"

val functionParams = genFunctionParamsListForSubExprEliminate(expr, false)
val callerParams = genFunctionParamsListForSubExprEliminate(expr, true)

// Generate the code for this expression tree.
val code = expr.genCode(this)
val state = SubExprEliminationState(code.isNull, code.value)
val returnType = javaType(expr.dataType)
val fn =
s"""
|private void $fnName($functionParams) {
| ${code.code.trim}
| $isNull = ${code.isNull};
| $value = ${code.value};
| $isInitialized = true;
|}
""".stripMargin

val valueFnName = s"${fnName}ForValue"
val valueFn =
s"""
|private $returnType $valueFnName($functionParams) {
| if (!$isInitialized) {
| $fnName($callerParams);
| }
| return $value;
|}
""".stripMargin

val isNullFnName = s"${fnName}ForIsNull"
val isNullFn =
s"""
|private boolean $isNullFnName($functionParams) {
| if (!$isInitialized) {
| $fnName($callerParams);
| }
| return $isNull;
|}
""".stripMargin

addNewFunction(fnName, fn)
addNewFunction(valueFnName, valueFn)
addNewFunction(isNullFnName, isNullFn)

addMutableState("boolean", isNull, s"$isNull = false;")
addMutableState("boolean", isInitialized, s"$isInitialized = false;")
addMutableState(returnType, value, s"$value = ${defaultValue(expr.dataType)};")

val state = SubExprEliminationState(
isNull = s"$isNullFnName($callerParams)",
value = s"$valueFnName($callerParams)")

e.foreach(subExprEliminationExprs.put(_, state))
code.code.trim
s"$isInitialized = false;"
}
SubExprCodes(codes, subExprEliminationExprs.toMap)
}
Expand All @@ -759,19 +853,46 @@ class CodegenContext {
val fnName = freshName("evalExpr")
val isNull = s"${fnName}IsNull"
val value = s"${fnName}Value"
val isInitialized = s"${fnName}IsInitialized"

// Generate the code for this expression tree and wrap it in a function.
val code = expr.genCode(this)
val returnType = javaType(expr.dataType)
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
| ${code.code.trim}
| $isNull = ${code.isNull};
| $value = ${code.value};
| $isInitialized = true;
|}
""".stripMargin

val valueFnName = s"${fnName}ForValue"
val valueFn =
s"""
|private $returnType $valueFnName(InternalRow $INPUT_ROW) {
| if (!$isInitialized) {
| $fnName($INPUT_ROW);
| }
| return $value;
|}
""".stripMargin

val isNullFnName = s"${fnName}ForIsNull"
val isNullFn =
s"""
|private boolean $isNullFnName(InternalRow $INPUT_ROW) {
| if (!$isInitialized) {
| $fnName($INPUT_ROW);
| }
| return $isNull;
|}
""".stripMargin

addNewFunction(fnName, fn)
addNewFunction(valueFnName, valueFn)
addNewFunction(isNullFnName, isNullFn)

// Add a state and a mapping of the common subexpressions that are associate with this
// state. Adding this expression to subExprEliminationExprMap means it will call `fn`
Expand All @@ -790,11 +911,14 @@ class CodegenContext {
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low.
addMutableState("boolean", isNull, s"$isNull = false;")
addMutableState(javaType(expr.dataType), value,
addMutableState("boolean", isInitialized, s"$isInitialized = false;")
addMutableState(returnType, value,
s"$value = ${defaultValue(expr.dataType)};")

subexprFunctions += s"$fnName($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
subexprFunctions += s"$isInitialized = false;"
val state = SubExprEliminationState(
isNull = s"$isNullFnName($INPUT_ROW)",
value = s"$valueFnName($INPUT_ROW)")
e.foreach(subExprEliminationExprs.put(_, state))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.{DataType, IntegerType}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GenerateMutableProjection}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}

class SubexpressionEliminationSuite extends SparkFunSuite {
test("Semantic equals and hash") {
Expand Down Expand Up @@ -172,6 +173,21 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}

test("SPARK-18395: evaluate subexpressions like lazy variables") {
val row = new GenericInternalRow(Seq(null, new java.lang.Integer(10)).toArray[Any])
val bound = BoundReference(0, IntegerType, true)
val add = Add(bound, Literal(1))
val assertNotNull = AssertNotNull(bound, Seq.empty[String])
val expr = If(
IsNull(bound),
Literal(1),
Add(assertNotNull, Add(assertNotNull, Literal(1))))
val schema = StructType(Seq(StructField("int", IntegerType))).toAttributes
val projection =
GenerateMutableProjection.generate(Seq(expr), schema, useSubexprElimination = true)
projection(row).getInt(0)
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ case class HashAggregateExec(
ctx.currentVars = bufVars ++ input
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val resetSubExprEvaluation = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
Expand All @@ -266,8 +266,8 @@ case class HashAggregateExec(
}
s"""
| // do aggregate
| // common sub-expressions
| $effectiveCodes
| // reset the initialization status for common sub-expressions
| $resetSubExprEvaluation
| // evaluate aggregate function
| ${evaluateVariables(aggVals)}
| // update aggregation buffer
Expand Down Expand Up @@ -758,7 +758,7 @@ case class HashAggregateExec(
ctx.INPUT_ROW = fastRowBuffer
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val resetSubExprEvaluation = subExprs.codes.mkString("\n")
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
Expand All @@ -768,8 +768,8 @@ case class HashAggregateExec(
}
Option(
s"""
|// common sub-expressions
|$effectiveCodes
|// reset the initialization status for common sub-expressions
|$resetSubExprEvaluation
|// evaluate aggregate function
|${evaluateVariables(fastRowEvals)}
|// update fast row
Expand Down Expand Up @@ -814,7 +814,7 @@ case class HashAggregateExec(
ctx.INPUT_ROW = unsafeRowBuffer
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val resetSubExprEvaluation = subExprs.codes.mkString("\n")
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
Expand All @@ -823,8 +823,8 @@ case class HashAggregateExec(
ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
}
s"""
|// common sub-expressions
|$effectiveCodes
|// reset the initialization status for common sub-expressions
|$resetSubExprEvaluation
|// evaluate aggregate function
|${evaluateVariables(unsafeRowBufferEvals)}
|// update unsafe row buffer
Expand Down