Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -462,34 +462,9 @@ abstract class BinaryExpression extends Expression {
ctx: CodegenContext,
ev: ExprCode,
f: (String, String) => String): ExprCode = {
val leftGen = left.genCode(ctx)
val rightGen = right.genCode(ctx)
val resultCode = f(leftGen.value, rightGen.value)

if (nullable) {
val nullSafeEval =
leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
}
}

ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${leftGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
val (childrenEval, childrenGen) = ExprCodegen.genCodeWithChildren(ctx, this)
val resultCode = f(childrenGen(0).value, childrenGen(1).value)
ExprCodegen.nullSafeCodeGen(ctx, ev, this, childrenEval, resultCode)
}
}

Expand Down Expand Up @@ -583,9 +558,9 @@ abstract class TernaryExpression extends Expression {
* @param f accepts three variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): ExprCode = {
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => {
s"${ev.value} = ${f(eval1, eval2, eval3)};"
})
Expand All @@ -600,38 +575,125 @@ abstract class TernaryExpression extends Expression {
* and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): ExprCode = {
val leftGen = children(0).genCode(ctx)
val midGen = children(1).genCode(ctx)
val rightGen = children(2).genCode(ctx)
val resultCode = f(leftGen.value, midGen.value, rightGen.value)
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): ExprCode = {
val (childrenEval, childrenGen) = ExprCodegen.genCodeWithChildren(ctx, this)
val resultCode = f(childrenGen(0).value, childrenGen(1).value, childrenGen(2).value)
ExprCodegen.nullSafeCodeGen(ctx, ev, this, childrenEval, resultCode)
}
}

if (nullable) {
val nullSafeEval =
leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) {
midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) {
rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) {
s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
}
object ExprCodegen {

private val placeHolderForResultCode = "__PLACEHOLDER__"

def isNotWholeStageCodegen(ctx: CodegenContext): Boolean =
ctx.INPUT_ROW != null && ctx.currentVars == null

// Moves generated codes for child expression into a function.
// Only supports non whole stage codegen case.
def genChildCodeInFunction(
ctx: CodegenContext,
child: Expression,
childExprCode: ExprCode): Unit = {
if (isNotWholeStageCodegen(ctx)) {
val setIsNull = if (childExprCode.isNull != "false" && childExprCode.isNull != "true") {
val globalIsNull = ctx.freshName("globalIsNull")
ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;")
val localIsNull = childExprCode.isNull
childExprCode.isNull = globalIsNull
s"$globalIsNull = $localIsNull;"
} else {
""
}

val javaType = ctx.javaType(child.dataType)
val newValue = ctx.freshName("value")

val funcName = ctx.freshName(child.nodeName)
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
| ${childExprCode.code.trim}
| $setIsNull
| return ${childExprCode.value};
|}
""".stripMargin)

childExprCode.value = newValue
childExprCode.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}

// Generating evaluation code for children expressions. We fold the code for each child
// expression from right to left. Once code length will be larger than the threshold, we
// put the code from next child expression into a method to prevent possible 64k limit.
def genCodeWithChildren(
ctx: CodegenContext,
expr: Expression): (String, Seq[ExprCode]) = {
val genCodeForChildren = expr.children.map(_.genCode(ctx))

// For nullable expression, the code of children expression is wrapped in "if" blocks
// for null check. We leave a special placeholder string which will be replaced with
// evaluation code of this expression later.
val initCode = if (expr.nullable) {
placeHolderForResultCode
} else {
""
}

var childIdx = expr.children.length - 1
val childrenEval = genCodeForChildren.foldRight(initCode) { case (childCode, curCode) =>
if (curCode.length + childCode.code.trim.length > 1024) {
genChildCodeInFunction(ctx, expr.children(childIdx), childCode)
}
val code = if (expr.nullable) {
childCode.code +
ctx.nullSafeExec(expr.children(childIdx).nullable, childCode.isNull) {
curCode
}
} else {
childCode.code + "\n" + curCode
}
childIdx -= 1
code
}
(childrenEval, genCodeForChildren)
}

/**
* Generating evaluation code for binary and ternary expressions for now.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param childrenEval Evaluation code for all sub-expressions.
* @param resultCode Evaluation code for the current expression. The evaluation is based on
* the values of `childrenEval`.
*/
def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
expr: Expression,
childrenEval: String,
resultCode: String): ExprCode = {
if (expr.nullable) {
val insertCode = s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
val nullSafeEval = childrenEval.replace(placeHolderForResultCode, insertCode)

ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$nullSafeEval""")
${ctx.javaType(expr.dataType)} ${ev.value} = ${ctx.defaultValue(expr.dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${leftGen.code}
${midGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$childrenEval
${ctx.javaType(expr.dataType)} ${ev.value} = ${ctx.defaultValue(expr.dataType)};
$resultCode""", isNull = "false")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,16 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
}
}

test("SPARK-22551: Prevent possible 64kb compile error for common expression types") {
val N = 1800
var addedExpr: Expression = Literal(1)
var expected = 1
for (i <- 0 until N) {
addedExpr = Add(Literal(i), addedExpr)
expected += i
}

checkEvaluation(Add(addedExpr, addedExpr), expected * 2, EmptyRow)
}
}