From dc49b6e1c884bce164e08bb3f63cbdec86541c75 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 18 Nov 2017 15:11:05 +0000 Subject: [PATCH 1/3] Put large generated codes of children expressions into functions. --- .../sql/catalyst/expressions/Expression.scala | 158 ++++++++++++------ 1 file changed, 104 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a3b722a47d68..8ddaed6875dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -462,34 +462,9 @@ abstract class BinaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: (String, String) => String): ExprCode = { - val leftGen = left.genCode(ctx) - val rightGen = right.genCode(ctx) - val resultCode = f(leftGen.value, rightGen.value) - - if (nullable) { - val nullSafeEval = - leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) { - rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) { - s""" - ${ev.isNull} = false; // resultCode could change nullability. 
- $resultCode - """ - } - } - - ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $nullSafeEval - """) - } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; - ${leftGen.code} - ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode""", isNull = "false") - } + val (childrenEval, childrenGen) = ExprCodegen.genCodeWithChildren(ctx, this) + val resultCode = f(childrenGen(0).value, childrenGen(1).value) + ExprCodegen.nullSafeCodeGen(ctx, ev, this, childrenEval, resultCode) } } @@ -583,9 +558,9 @@ abstract class TernaryExpression extends Expression { * @param f accepts three variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodegenContext, - ev: ExprCode, - f: (String, String, String) => String): ExprCode = { + ctx: CodegenContext, + ev: ExprCode, + f: (String, String, String) => String): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { s"${ev.value} = ${f(eval1, eval2, eval3)};" }) @@ -600,38 +575,113 @@ abstract class TernaryExpression extends Expression { * and returns Java code to compute the output. 
*/ protected def nullSafeCodeGen( - ctx: CodegenContext, - ev: ExprCode, - f: (String, String, String) => String): ExprCode = { - val leftGen = children(0).genCode(ctx) - val midGen = children(1).genCode(ctx) - val rightGen = children(2).genCode(ctx) - val resultCode = f(leftGen.value, midGen.value, rightGen.value) + ctx: CodegenContext, + ev: ExprCode, + f: (String, String, String) => String): ExprCode = { + val (childrenEval, childrenGen) = ExprCodegen.genCodeWithChildren(ctx, this) + val resultCode = f(childrenGen(0).value, childrenGen(1).value, childrenGen(2).value) + ExprCodegen.nullSafeCodeGen(ctx, ev, this, childrenEval, resultCode) + } +} - if (nullable) { - val nullSafeEval = - leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) { - midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) { - rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) { - s""" - ${ev.isNull} = false; // resultCode could change nullability. - $resultCode - """ - } +object ExprCodegen { + + val placeHolderForResultCode = "__PLACEHOLDER__" + + def isNotWholeStageCodegen(ctx: CodegenContext): Boolean = + ctx.INPUT_ROW != null && ctx.currentVars == null + + // Moves generated codes for child expression into a function. + // Only supports non whole stage codegen case. 
+ def genChildCodeInFunction( + ctx: CodegenContext, + child: Expression, + childExprCode: ExprCode): Unit = { + if (isNotWholeStageCodegen(ctx)) { + val setIsNull = if (childExprCode.isNull != "false" && childExprCode.isNull != "true") { + val globalIsNull = ctx.freshName("globalIsNull") + ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + val localIsNull = childExprCode.isNull + childExprCode.isNull = globalIsNull + s"$globalIsNull = $localIsNull;" + } else { + "" + } + + val javaType = ctx.javaType(child.dataType) + val newValue = ctx.freshName("value") + + val funcName = ctx.freshName(child.nodeName) + val funcFullName = ctx.addNewFunction(funcName, + s""" + |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { + | ${childExprCode.code.trim} + | $setIsNull + | return ${childExprCode.value}; + |} + """.stripMargin) + + childExprCode.value = newValue + childExprCode.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + } + } + + // Generates codes safe from 64k limit for children expressions. 
+ def genCodeWithChildren( + ctx: CodegenContext, + expr: Expression): (String, Seq[ExprCode]) = { + val rawGenCode = expr.children.map(_.genCode(ctx)) + + val childrenEval = if (expr.nullable) { + var childIdx = expr.children.length - 1 + rawGenCode.foldRight(placeHolderForResultCode) { case (childCode, curCode) => + if (curCode.length + childCode.code.trim.length > 1024) { + genChildCodeInFunction(ctx, expr.children(childIdx), childCode) + } + val code = childCode.code + + ctx.nullSafeExec(expr.children(childIdx).nullable, childCode.isNull) { + curCode } + childIdx -= 1 + code } + } else { + var childIdx = expr.children.length - 1 + rawGenCode.foldRight("") { case (childCode, curCode) => + if (curCode.length + childCode.code.trim.length > 1024) { + genChildCodeInFunction(ctx, expr.children(childIdx), childCode) + } + val code = childCode.code + "\n" + curCode + childIdx -= 1 + code + } + } + (childrenEval, rawGenCode) + } + + def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + expr: Expression, + childrenEval: String, + resultCode: String): ExprCode = { + if (expr.nullable) { + val insertCode = s""" + ${ev.isNull} = false; // resultCode could change nullability. 
+ $resultCode + """ + val nullSafeEval = childrenEval.replace(ExprCodegen.placeHolderForResultCode, insertCode) ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $nullSafeEval""") + ${ctx.javaType(expr.dataType)} ${ev.value} = ${ctx.defaultValue(expr.dataType)}; + $nullSafeEval + """) } else { ev.copy(code = s""" boolean ${ev.isNull} = false; - ${leftGen.code} - ${midGen.code} - ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $childrenEval + ${ctx.javaType(expr.dataType)} ${ev.value} = ${ctx.defaultValue(expr.dataType)}; $resultCode""", isNull = "false") } } From 250e8a685af8682defb8066823529c75f39e90df Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 19 Nov 2017 10:07:09 +0000 Subject: [PATCH 2/3] Add test. --- .../catalyst/expressions/CodeGenerationSuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8f6289f00571..0a4ccf7cb46d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -380,4 +380,16 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd") } } + + test("SPARK-22551: Prevent possible 64kb compile error for common expression types") { + val N = 1800 + var addedExpr: Expression = Literal(1) + var expected = 1 + for (i <- 0 until N) { + addedExpr = Add(Literal(i), addedExpr) + expected += i + } + + checkEvaluation(Add(addedExpr, addedExpr), expected * 2, EmptyRow) + } } From 64e93ec47a4516a341ad9f39dab02266ce7b6e4d Mon Sep 17 00:00:00 2001 
From: Liang-Chi Hsieh Date: Mon, 20 Nov 2017 02:40:13 +0000 Subject: [PATCH 3/3] Reduce duplication. --- .../sql/catalyst/expressions/Expression.scala | 60 +++++++++++-------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 8ddaed6875dc..511471b8237d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -586,7 +586,7 @@ abstract class TernaryExpression extends Expression { object ExprCodegen { - val placeHolderForResultCode = "__PLACEHOLDER__" + private val placeHolderForResultCode = "__PLACEHOLDER__" def isNotWholeStageCodegen(ctx: CodegenContext): Boolean = ctx.INPUT_ROW != null && ctx.currentVars == null @@ -626,39 +626,51 @@ object ExprCodegen { } } - // Generates codes safe from 64k limit for children expressions. + // Generates evaluation code for the children expressions, folding each child's code from + // right to left. Whenever adding a child's code would push the accumulated length over the + // threshold, that child's code is moved into a separate method to avoid the 64KB limit. def genCodeWithChildren( ctx: CodegenContext, expr: Expression): (String, Seq[ExprCode]) = { - val rawGenCode = expr.children.map(_.genCode(ctx)) + val genCodeForChildren = expr.children.map(_.genCode(ctx)) - val childrenEval = if (expr.nullable) { - var childIdx = expr.children.length - 1 - rawGenCode.foldRight(placeHolderForResultCode) { case (childCode, curCode) => - if (curCode.length + childCode.code.trim.length > 1024) { - genChildCodeInFunction(ctx, expr.children(childIdx), childCode) - } - val code = childCode.code + + // For nullable expression, the code of children expression is wrapped in "if" blocks + // for null check.
We leave a special placeholder string which will be replaced with + // evaluation code of this expression later. + val initCode = if (expr.nullable) { + placeHolderForResultCode + } else { + "" + } + + var childIdx = expr.children.length - 1 + val childrenEval = genCodeForChildren.foldRight(initCode) { case (childCode, curCode) => + if (curCode.length + childCode.code.trim.length > 1024) { + genChildCodeInFunction(ctx, expr.children(childIdx), childCode) + } + val code = if (expr.nullable) { + childCode.code + ctx.nullSafeExec(expr.children(childIdx).nullable, childCode.isNull) { curCode } - childIdx -= 1 - code - } - } else { - var childIdx = expr.children.length - 1 - rawGenCode.foldRight("") { case (childCode, curCode) => - if (curCode.length + childCode.code.trim.length > 1024) { - genChildCodeInFunction(ctx, expr.children(childIdx), childCode) - } - val code = childCode.code + "\n" + curCode - childIdx -= 1 - code + } else { + childCode.code + "\n" + curCode } + childIdx -= 1 + code } - (childrenEval, rawGenCode) + (childrenEval, genCodeForChildren) } + /** + * Generates the evaluation code. Currently used by binary and ternary expressions. + * If any of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param childrenEval Evaluation code for all sub-expressions. + * @param resultCode Evaluation code for the current expression. The evaluation is based on + * the values of `childrenEval`. + */ def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, @@ -670,7 +682,7 @@ object ExprCodegen { ${ev.isNull} = false; // resultCode could change nullability. $resultCode """ - val nullSafeEval = childrenEval.replace(ExprCodegen.placeHolderForResultCode, insertCode) + val nullSafeEval = childrenEval.replace(placeHolderForResultCode, insertCode) ev.copy(code = s""" boolean ${ev.isNull} = true;