From ce74fb8baff13332bed8c64db6c2a49be815f81f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 1 Dec 2017 13:55:58 +0100 Subject: [PATCH] [SPARK-22669][SQL] Avoid unnecessary function calls in code generation --- .../expressions/nullExpressions.scala | 141 ++++++++++++------ .../sql/catalyst/expressions/predicates.scala | 68 ++++++--- 2 files changed, 140 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 173e171910b69..3b52a0efd404a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -75,23 +75,51 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) ctx.addMutableState(ctx.javaType(dataType), ev.value) + // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => val eval = e.genCode(ctx) s""" - if (${ev.isNull}) { - ${eval.code} - if (!${eval.isNull}) { - ${ev.isNull} = false; - ${ev.value} = ${eval.value}; - } - } - """ + |${eval.code} + |if (!${eval.isNull}) { + | ${ev.isNull} = false; + | ${ev.value} = ${eval.value}; + | continue; + |} + """.stripMargin } + val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + evals.mkString("\n") + } else { + ctx.splitExpressions(evals, "coalesce", + ("InternalRow", ctx.INPUT_ROW) :: Nil, + makeSplitFunction = { + func => + s""" + |do { + | $func + |} while (false); + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map { funcCall => + s""" + |$funcCall; + |if (!${ev.isNull}) { + | continue; + |} + """.stripMargin + }.mkString + }) + } - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ctx.splitExpressions(evals)}""") + ev.copy(code = + s""" + |${ev.isNull} = true; + |${ev.value} = ${ctx.defaultValue(dataType)}; + |do { + | $code + |} while (false); + """.stripMargin) } } @@ -358,53 +386,70 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val nonnull = ctx.freshName("nonnull") + // all evals are meant to be inside a do { ... } while (false); loop val evals = children.map { e => val eval = e.genCode(ctx) e.dataType match { case DoubleType | FloatType => s""" - if ($nonnull < $n) { - ${eval.code} - if (!${eval.isNull} && !Double.isNaN(${eval.value})) { - $nonnull += 1; - } - } - """ + |if ($nonnull < $n) { + | ${eval.code} + | if (!${eval.isNull} && !Double.isNaN(${eval.value})) { + | $nonnull += 1; + | } + |} else { + | continue; + |} + """.stripMargin case _ => s""" - if ($nonnull < $n) { - ${eval.code} - if (!${eval.isNull}) { - $nonnull += 1; - } - } - """ + |if ($nonnull < $n) { + | ${eval.code} + | if (!${eval.isNull}) { + | $nonnull += 1; + | } + |} else { + | continue; + |} + """.stripMargin } } val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - evals.mkString("\n") - } else { - ctx.splitExpressions( - expressions = evals, - funcName = "atLeastNNonNulls", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil, - returnType = "int", - makeSplitFunction = { body => - s""" - $body - return $nonnull; - """ - }, - foldFunctions = { funcCalls => - funcCalls.map(funcCall => s"$nonnull = $funcCall;").mkString("\n") - } - ) - } + evals.mkString("\n") + } else { + ctx.splitExpressions( + expressions = evals, + funcName = "atLeastNNonNulls", + arguments = ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_INT, nonnull) :: Nil, + returnType = ctx.JAVA_INT, + makeSplitFunction = { body => + s""" + |do { + | $body + |} while (false); + |return $nonnull; + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => + s""" + |$nonnull = $funcCall; + |if ($nonnull >= $n) { + | continue; + |} + """.stripMargin).mkString("\n") + } + ) + } - ev.copy(code = s""" - int $nonnull = 0; - $code - boolean ${ev.value} = $nonnull >= $n;""", isNull = "false") + ev.copy(code = + s""" + |${ctx.JAVA_INT} $nonnull = 0; + |do { + | $code + |} while (false); + |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + """.stripMargin, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1aaaaf1db48d1..75cc9b3bd8045 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -234,36 +234,62 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaDataType = ctx.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value) ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val valueArg = ctx.freshName("valueArg") + // All the blocks are meant to be inside a do { ... } while (false); loop. + // The evaluation of variables can be stopped when we find a matching value. val listCode = listGen.map(x => s""" - if (!${ev.value}) { - ${x.code} - if (${x.isNull}) { - ${ev.isNull} = true; - } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - ${ev.isNull} = false; - ${ev.value} = true; + |${x.code} + |if (${x.isNull}) { + | ${ev.isNull} = true; + |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | continue; + |} + """.stripMargin) + val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + listCode.mkString("\n") + } else { + ctx.splitExpressions( + expressions = listCode, + funcName = "valueIn", + arguments = ("InternalRow", ctx.INPUT_ROW) :: (javaDataType, valueArg) :: Nil, + makeSplitFunction = { body => + s""" + |do { + | $body + |} while (false); + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => + s""" + |$funcCall; + |if (${ev.value}) { + | continue; + |} + """.stripMargin).mkString("\n") } - } - """) - val listCodes = ctx.splitExpressions( - expressions = listCode, - funcName = "valueIn", - extraArguments = (ctx.javaType(value.dataType), valueArg) :: Nil) - ev.copy(code = s""" - ${valueGen.code} - ${ev.value} = false; - ${ev.isNull} = ${valueGen.isNull}; - if (!${ev.isNull}) { - ${ctx.javaType(value.dataType)} $valueArg = ${valueGen.value}; - $listCodes + ) } - """) + ev.copy(code = + s""" + |${valueGen.code} + |${ev.value} = false; + |${ev.isNull} = ${valueGen.isNull}; + |if (!${ev.isNull}) { + | $javaDataType $valueArg = ${valueGen.value}; + | do { + | $code + | } while (false); + |} + """.stripMargin) } override def sql: String = {