From 34abc2284be485c12720437e969cf41394dfc2b5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Nov 2017 02:16:47 +0000 Subject: [PATCH 01/23] Support wholestage codegen for reducing expression codes to prevent 64k limit. --- .../sql/catalyst/expressions/Expression.scala | 140 +++++++++++++++++- .../execution/WholeStageCodegenSuite.scala | 21 ++- 2 files changed, 156 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 743782a6453e9..d5809d13fa956 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -115,9 +117,119 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Returns the eliminated subexpressions in the children expressions. + */ + private def getSubExprInChildren(ctx: CodegenContext): Seq[Expression] = { + children.flatMap { child => + child.collect { + case e if ctx.subExprEliminationExprs.contains(e) => e + } + } + } + + /** + * Given the list of eliminated subexpressions used in the children expressions, returns the + * strings of funtion parameters. The first is the variable names used to call the function, + * the second is the parameters used to declare the function in generated code. + */ + private def getParamsForSubExprs( + ctx: CodegenContext, + subExprs: Seq[Expression]): (Seq[String], Seq[String]) = { + subExprs.map { subExpr => + val arguType = ctx.javaType(subExpr.dataType) + + val subExprState = ctx.subExprEliminationExprs(subExpr) + (subExprState.value, subExprState.isNull) + + if (!subExpr.nullable || subExprState.isNull == "true" || subExprState.isNull == "false") { + (subExprState.value, s"$arguType ${subExprState.value}") + } else { + (subExprState.value + ", " + subExprState.isNull, + s"$arguType ${subExprState.value}, boolean ${subExprState.isNull}") + } + }.unzip + } + + /** + * Finds the bound attributes and corresponding input variables under wholestage codegen. + * If the input variables are not evaluated yet, don't need to include them into parameters, + */ + private def getInputVars(ctx: CodegenContext): (Seq[Expression], Seq[ExprCode]) = { + if (ctx.currentVars == null) { + (Seq.empty, Seq.empty) + } else { + children.flatMap(_.collect { + case b @ BoundReference(ordinal, dt, nullable) if ctx.currentVars(ordinal) != null && + ctx.currentVars(ordinal).code == "" => + (b, ctx.currentVars(ordinal)) + }).distinct.unzip + } + } + + /** + * Helper function to calculate the size of an expression as function parameter. + */ + private def calculateParamLength(ctx: CodegenContext, input: Expression): Int = { + ctx.javaType(input.dataType) match { + case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input.nullable => 2 + case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 3 + case _ if !input.nullable => 1 + case _ => 2 + } + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length of 255 or less. `this` contributes one unit and a parameter of type long or double + * contributes two units. 
+ */ + private def isValidParamLength( + ctx: CodegenContext, + inputs: Seq[Expression], + subExprs: Seq[Expression]): Boolean = { + // Start value is 1 for `this`. + inputs.foldLeft(1) { case (curLength, input) => + curLength + calculateParamLength(ctx, input) + } + subExprs.foldLeft(0) { case (curLength, subExpr) => + curLength + calculateParamLength(ctx, subExpr) + } <= 255 + } + + /** + * Given the lists of input attributes and variables to this expression, returns the strings of + * funtion parameters. The first is the variable names used to call the function, the second is + * the parameters used to declare the function in generated code. + */ + private def prepareFunctionParams( + ctx: CodegenContext, + inputAttrs: Seq[Expression], + inputVars: Seq[ExprCode]): (Seq[String], Seq[String]) = { + inputAttrs.zip(inputVars).map { case (input, ev) => + val arguType = ctx.javaType(input.dataType) + + if (!input.nullable || ev.isNull == "true" || ev.isNull == "false") { + (ev.value, s"$arguType ${ev.value}") + } else { + (ev.value + ", " + ev.isNull, s"$arguType ${ev.value}, boolean ${ev.isNull}") + } + }.unzip + } + + /** + * In order to prevent 64kb compile error, reducing the size of generated codes by + * separating it into a function if the size exceeds a threshold. + */ private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { - // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + val isNotWholestageCodegen = ctx.INPUT_ROW != null && ctx.currentVars == null + val (inputAttrs, inputVars) = getInputVars(ctx) + val subExprs = getSubExprInChildren(ctx) + val isValidParams = isValidParamLength(ctx, inputAttrs, subExprs) + + // Puts code into a function if the code is big, when: + // 1. Not in wholestage codegen, or + // 2. Parameter number is allowed for Java's method descriptor. 
+ if (eval.code.trim.length > 1024 && (isNotWholestageCodegen || isValidParams)) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { val globalIsNull = ctx.freshName("globalIsNull") ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) @@ -131,10 +243,30 @@ abstract class Expression extends TreeNode[Expression] { val javaType = ctx.javaType(dataType) val newValue = ctx.freshName("value") + val callParams = mutable.ArrayBuffer[String]() + val funcParams = mutable.ArrayBuffer[String]() + + if (ctx.INPUT_ROW != null) { + callParams += ctx.INPUT_ROW + funcParams += s"InternalRow ${ctx.INPUT_ROW}" + } + + if (inputAttrs.length > 0) { + val params = prepareFunctionParams(ctx, inputAttrs, inputVars) + params._1.foreach(callParams += _) + params._2.foreach(funcParams += _) + } + + if (subExprs.length > 0) { + val subExprParams = getParamsForSubExprs(ctx, subExprs) + subExprParams._1.foreach(callParams += _) + subExprParams._2.foreach(funcParams += _) + } + val funcName = ctx.freshName(nodeName) val funcFullName = ctx.addNewFunction(funcName, s""" - |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { + |private $javaType $funcName(${funcParams.mkString(", ")}) { | ${eval.code.trim} | $setIsNull | return ${eval.value}; @@ -142,7 +274,7 @@ abstract class Expression extends TreeNode[Expression] { """.stripMargin) eval.value = newValue - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = s"$javaType $newValue = $funcFullName(${callParams.mkString(", ")});" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c47..7afd704d7f34d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.{Column, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec @@ -236,4 +237,22 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-22551: Fix 64kb limit for deeply nested expressions under wholestage codegen") { + import testImplicits._ + withTempPath { dir => + val path = dir.getCanonicalPath + val df = Seq(("abc", 1)).toDF("key", "int") + df.write.parquet(path) + + var strExpr: Expression = col("key").expr + for (_ <- 1 to 150) { + strExpr = Decode(Encode(strExpr, Literal("utf-8")), Literal("utf-8")) + } + val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)) + + val df2 = spark.read.parquet(path).select(expressions.map(Column(_)): _*) + df2.collect() + } + } } From 65d07d525344e1d00457d2f538b2ef0b1c38a8e8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 25 Nov 2017 08:09:22 +0000 Subject: [PATCH 02/23] Assert the added test is under wholestage codegen. 
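The new SPARK-22551 test only exercises the splitting logic added in the previous patch if the query really goes through WholeStageCodegenExec, so the test now asserts that the executed plan contains a whole-stage codegen node before collecting the result.

When debugging tests like this it can also help to dump the generated Java source. A rough sketch, not part of this patch (it relies on the debug helpers shipped in org.apache.spark.sql.execution.debug):

    import org.apache.spark.sql.execution.debug._

    // Prints the Java source generated for each WholeStageCodegenExec subtree,
    // which makes it easy to check whether oversized expressions were split
    // into separate private methods.
    df2.debugCodegen()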
--- .../apache/spark/sql/execution/WholeStageCodegenExec.scala | 6 ++---- .../apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 ++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 7166b7771e4db..55957b9c9d97a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -135,7 +135,7 @@ trait CodegenSupport extends SparkPlan { } val evaluateInputs = evaluateVariables(outputVars) // generate the code to create a UnsafeRow - ctx.INPUT_ROW = row + ctx.INPUT_ROW = null ctx.currentVars = outputVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) val code = s""" @@ -150,10 +150,8 @@ trait CodegenSupport extends SparkPlan { } // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars` - // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to - // generate code of `rowVar` manually. + // before calling `parent.doConsume`. ctx.currentVars = inputVars - ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 7afd704d7f34d..1281169b607c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -252,6 +252,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)) val df2 = spark.read.parquet(path).select(expressions.map(Column(_)): _*) + val plan = df2.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) df2.collect() } } From 9f848be45dcc294d6f27f2c6eaeed1907d36f004 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 27 Nov 2017 07:37:02 +0000 Subject: [PATCH 03/23] Put input rows and evaluated columns referred by deferred expressions into parameter list. 
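Under wholestage codegen, the children of an expression may refer to columns that were already evaluated earlier in the produce/consume loop, to columns that are deferred (not evaluated yet), or to previously used input rows. When an expression's generated code is moved into a separate private method, all of those values have to be passed in as parameters, so each ExprCode now records the input row and input variables it was generated with, and the splitting logic tracks them down to build the parameter lists.

As a rough illustration only (the names below are made up, not what the generator actually emits), the split generated Java is expected to have this shape:

    private UTF8String project_doExpr(InternalRow inputadapter_row,
        int agg_value1, boolean agg_isNull1) {
      // ... original oversized expression code producing `value` ...
      return value;
    }

    UTF8String project_value5 = project_doExpr(inputadapter_row, agg_value1, agg_isNull1);

The collected parameters are also checked against the JVM limit: a method descriptor may use at most 255 length units, where `this` counts as one unit and a parameter of type long or double counts as two. For example, an instance method taking (long, double, int) uses 1 + 2 + 2 + 1 = 6 units; splitting is skipped when the total would exceed 255.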
--- .../catalyst/expressions/BoundAttribute.scala | 6 +- .../sql/catalyst/expressions/Expression.scala | 168 ++++++++++++++++-- .../expressions/codegen/CodeGenerator.scala | 14 +- 3 files changed, 167 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6a17a397b3ef2..e347ee4ab387c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -61,7 +61,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) { val oev = ctx.currentVars(ordinal) - ev.isNull = oev.isNull + if (nullable) { + ev.isNull = oev.isNull + } else { + ev.isNull = "false" + } ev.value = oev.value ev.copy(code = oev.code) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d5809d13fa956..ac70ed7cfe1e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -107,6 +107,7 @@ abstract class Expression extends TreeNode[Expression] { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val eval = doGenCode(ctx, ExprCode("", isNull, value)) + populateInputs(ctx, eval) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -117,6 +118,27 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Records current input row and variables for this expression into created `ExprCode`. + */ + private def populateInputs(ctx: CodegenContext, eval: ExprCode): Unit = { + if (ctx.INPUT_ROW != null) { + eval.inputRow = ctx.INPUT_ROW + } + if (ctx.currentVars != null) { + val boundRefs = children.flatMap(_.collect { + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => (ordinal, b) + }).toMap + + ctx.currentVars.zipWithIndex.filter(_._1 != null).foreach { case (currentVar, idx) => + if (boundRefs.contains(idx)) { + val inputVar = ExprInputVar(boundRefs(idx), exprCode = currentVar) + eval.inputVars += inputVar + } + } + } + } + /** * Returns the eliminated subexpressions in the children expressions. */ @@ -152,19 +174,113 @@ abstract class Expression extends TreeNode[Expression] { } /** - * Finds the bound attributes and corresponding input variables under wholestage codegen. - * If the input variables are not evaluated yet, don't need to include them into parameters, + * Retrieves previous input rows referred by children and deferred expressions. + */ + private def getInputRowsForChildren(ctx: CodegenContext): Seq[String] = { + children.flatMap(getInputRows(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previous input rows referred by it or deferred expressions + * which are needed to evaluate it. + */ + private def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = { + child.flatMap { + // An expression directly evaluates on current input row. 
+ case BoundReference(ordinal, _, _) if ctx.currentVars == null || + ctx.currentVars(ordinal) == null => + Seq(ctx.INPUT_ROW) + + // An expression which is not evaluated yet. Tracks down to find input rows. + case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal).code != "" => + trackDownRow(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down input rows referred by the generated code snippet. + */ + private def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { + var exprCodes: List[ExprCode] = List(exprCode) + val inputRows = mutable.ArrayBuffer.empty[String] + + while (exprCodes.nonEmpty) { + exprCodes match { + case first :: others => + exprCodes = others + if (first.inputRow != null) { + inputRows += first.inputRow + } + first.inputVars.foreach { inputVar => + if (inputVar.exprCode.code != "") { + exprCodes = inputVar.exprCode :: exprCodes + } + } + case _ => + } + } + inputRows.toSeq + } + + /** + * Retrieves previously evaluated columns referred by children and deferred expressions. + * Returned tuple contains the list of expressions and the list of generated codes. + */ + private def getInputVarsForChildren(ctx: CodegenContext): (Seq[Expression], Seq[ExprCode]) = { + children.flatMap(getInputVars(ctx, _)).distinct.unzip + } + + /** + * Given a child expression, retrieves previously evaluated columns referred by it or + * deferred expressions which are needed to evaluate it. */ - private def getInputVars(ctx: CodegenContext): (Seq[Expression], Seq[ExprCode]) = { + private def getInputVars(ctx: CodegenContext, child: Expression): Seq[(Expression, ExprCode)] = { if (ctx.currentVars == null) { - (Seq.empty, Seq.empty) - } else { - children.flatMap(_.collect { - case b @ BoundReference(ordinal, dt, nullable) if ctx.currentVars(ordinal) != null && - ctx.currentVars(ordinal).code == "" => - (b, ctx.currentVars(ordinal)) - }).distinct.unzip + return Seq.empty + } + + child.flatMap { + // An evaluated variable. + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && + ctx.currentVars(ordinal).code == "" => + Seq((b, ctx.currentVars(ordinal))) + + // An input variable which is not evaluated yet. Tracks down to find any evaluated variables + // in the expression path. + // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to + // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so + // to include them into parameters, if not, we tract down further. + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + trackDownVar(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down previously evaluated columns referred by the generated code snippet. + */ + private def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[(Expression, ExprCode)] = { + var exprCodes: List[ExprCode] = List(exprCode) + val inputVars = mutable.ArrayBuffer.empty[(Expression, ExprCode)] + + while (exprCodes.nonEmpty) { + exprCodes match { + case first :: others => + exprCodes = others + first.inputVars.foreach { inputVar => + if (inputVar.exprCode.code == "") { + inputVars += ((inputVar.expr, inputVar.exprCode)) + } else { + exprCodes = inputVar.exprCode :: exprCodes + } + } + case _ => + } } + inputVars.toSeq } /** @@ -184,16 +300,16 @@ abstract class Expression extends TreeNode[Expression] { * length of 255 or less. 
`this` contributes one unit and a parameter of type long or double * contributes two units. */ - private def isValidParamLength( + private def getValidParamLength( ctx: CodegenContext, inputs: Seq[Expression], - subExprs: Seq[Expression]): Boolean = { + subExprs: Seq[Expression]): Int = { // Start value is 1 for `this`. inputs.foldLeft(1) { case (curLength, input) => curLength + calculateParamLength(ctx, input) } + subExprs.foldLeft(0) { case (curLength, subExpr) => curLength + calculateParamLength(ctx, subExpr) - } <= 255 + } } /** @@ -222,14 +338,20 @@ abstract class Expression extends TreeNode[Expression] { */ private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { val isNotWholestageCodegen = ctx.INPUT_ROW != null && ctx.currentVars == null - val (inputAttrs, inputVars) = getInputVars(ctx) + val (inputAttrs, inputVars) = getInputVarsForChildren(ctx) + val inputRows = getInputRowsForChildren(ctx) val subExprs = getSubExprInChildren(ctx) - val isValidParams = isValidParamLength(ctx, inputAttrs, subExprs) + + // Params to include: + // 1. Evaluated columns referred by this, children or deferred expressions. + // 2. Input rows referred by this, children or deferred expressions. + // 3. Eliminated subexpressions. + val paramsLength = getValidParamLength(ctx, inputAttrs, subExprs) + inputRows.length // Puts code into a function if the code is big, when: // 1. Not in wholestage codegen, or - // 2. Parameter number is allowed for Java's method descriptor. - if (eval.code.trim.length > 1024 && (isNotWholestageCodegen || isValidParams)) { + // 2. Allowed parameter number for Java's method descriptor. + if (eval.code.trim.length > 1024 && (isNotWholestageCodegen || paramsLength <= 255)) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { val globalIsNull = ctx.freshName("globalIsNull") ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) @@ -243,6 +365,7 @@ abstract class Expression extends TreeNode[Expression] { val javaType = ctx.javaType(dataType) val newValue = ctx.freshName("value") + // Prepare function parameters. 
val callParams = mutable.ArrayBuffer[String]() val funcParams = mutable.ArrayBuffer[String]() @@ -251,6 +374,13 @@ abstract class Expression extends TreeNode[Expression] { funcParams += s"InternalRow ${ctx.INPUT_ROW}" } + if (inputRows.length > 0) { + inputRows.foreach { row => + callParams += row + funcParams += s"InternalRow $row" + } + } + if (inputAttrs.length > 0) { val params = prepareFunctionParams(ctx, inputAttrs, inputVars) params._1.foreach(callParams += _) @@ -266,7 +396,7 @@ abstract class Expression extends TreeNode[Expression] { val funcName = ctx.freshName(nodeName) val funcFullName = ctx.addNewFunction(funcName, s""" - |private $javaType $funcName(${funcParams.mkString(", ")}) { + |private $javaType $funcName(${funcParams.distinct.mkString(", ")}) { | ${eval.code.trim} | $setIsNull | return ${eval.value}; @@ -274,7 +404,7 @@ abstract class Expression extends TreeNode[Expression] { """.stripMargin) eval.value = newValue - eval.code = s"$javaType $newValue = $funcFullName(${callParams.mkString(", ")});" + eval.code = s"$javaType $newValue = $funcFullName(${callParams.distinct.mkString(", ")});" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0498e61819f48..98b2303f97445 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,8 +55,20 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * to null. * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. + * @param inputRow A term that holds the input row name when generating this code. + * @param inputVars A list of [[ExprInputVar]] that holds input variables when generating this code. */ -case class ExprCode(var code: String, var isNull: String, var value: String) +case class ExprCode( + var code: String, + var isNull: String, + var value: String, + var inputRow: String = null, + val inputVars: mutable.ArrayBuffer[ExprInputVar] = mutable.ArrayBuffer.empty) + +/** + * Represents an input variable that holds the java type and the [[ExprCode]]. + */ +case class ExprInputVar(val expr: Expression, val exprCode: ExprCode) /** * State used for subexpression elimination. From 57b1add4df4648862c76165f8ae10cc487af1221 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 27 Nov 2017 08:53:42 +0000 Subject: [PATCH 04/23] Revert unnecessary changes. 
--- .../apache/spark/sql/execution/WholeStageCodegenExec.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 55957b9c9d97a..7166b7771e4db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -135,7 +135,7 @@ trait CodegenSupport extends SparkPlan { } val evaluateInputs = evaluateVariables(outputVars) // generate the code to create a UnsafeRow - ctx.INPUT_ROW = null + ctx.INPUT_ROW = row ctx.currentVars = outputVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) val code = s""" @@ -150,8 +150,10 @@ trait CodegenSupport extends SparkPlan { } // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars` - // before calling `parent.doConsume`. + // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to + // generate code of `rowVar` manually. ctx.currentVars = inputVars + ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) s""" From d051f9eef4d03f9027571419857f690c866dbd98 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 28 Nov 2017 02:35:16 +0000 Subject: [PATCH 05/23] Fix subexpression isNull for non nullable case. Fix columnar batch scan's rowIdx. --- .../expressions/codegen/CodeGenerator.scala | 29 +++++++++++++++---- .../sql/execution/ColumnarBatchScan.scala | 4 +++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 98b2303f97445..4305f325bb85b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -991,7 +991,11 @@ class CodegenContext { val expr = e.head // Generate the code for this expression tree. val eval = expr.genCode(this) - val state = SubExprEliminationState(eval.isNull, eval.value) + val state = if (expr.nullable) { + SubExprEliminationState(eval.isNull, eval.value) + } else { + SubExprEliminationState("false", eval.value) + } e.foreach(subExprEliminationExprs.put(_, state)) eval.code.trim } @@ -1013,16 +1017,25 @@ class CodegenContext { commonExprs.foreach { e => val expr = e.head val fnName = freshName("evalExpr") - val isNull = s"${fnName}IsNull" + val isNull = if (expr.nullable) { + s"${fnName}IsNull" + } else { + "" + } val value = s"${fnName}Value" // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) + val nullValue = if (expr.nullable) { + s"$isNull = ${eval.isNull};" + } else { + "" + } val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { | ${eval.code.trim} - | $isNull = ${eval.isNull}; + | $nullValue | $value = ${eval.value}; |} """.stripMargin @@ -1040,12 +1053,18 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. 
- addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") + if (expr.nullable) { + addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") + } addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + val state = if (expr.nullable) { + SubExprEliminationState(isNull, value) + } else { + SubExprEliminationState("false", value) + } e.foreach(subExprEliminationExprs.put(_, state)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index a9bfb634fbdea..5ab3d5000c470 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -112,6 +112,9 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } + // `rowIdx` is a special variable which can't be referred if the parent nodes split expressions. + // So we evaluate column outputs right away. + val evalColumnsBatchInput = evaluateVariables(columnsBatchInput) val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val numRows = ctx.freshName("numRows") @@ -129,6 +132,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | int $localEnd = $numRows - $idx; | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { | int $rowidx = $idx + $localIdx; + | $evalColumnsBatchInput | ${consume(ctx, columnsBatchInput).trim} | $shouldStop | } From 6368702e66948e26c41300da7136dffc5b963cb6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 28 Nov 2017 03:12:22 +0000 Subject: [PATCH 06/23] Let rowidx as global variable instead of early evaluation of column output. --- .../org/apache/spark/sql/execution/ColumnarBatchScan.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 5ab3d5000c470..dd3373e6a4142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -109,12 +109,10 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { ctx.currentVars = null val rowidx = ctx.freshName("rowIdx") + ctx.addMutableState(ctx.JAVA_INT, rowidx) val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } - // `rowIdx` is a special variable which can't be referred if the parent nodes split expressions. - // So we evaluate column outputs right away. 
- val evalColumnsBatchInput = evaluateVariables(columnsBatchInput) val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val numRows = ctx.freshName("numRows") @@ -131,8 +129,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | int $numRows = $batch.numRows(); | int $localEnd = $numRows - $idx; | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | int $rowidx = $idx + $localIdx; - | $evalColumnsBatchInput + | $rowidx = $idx + $localIdx; | ${consume(ctx, columnsBatchInput).trim} | $shouldStop | } From 8c7f7496e610fdf4b512c57efd108ccf0238b126 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 28 Nov 2017 14:55:46 +0000 Subject: [PATCH 07/23] Fix the problematic case. --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ac70ed7cfe1e1..e4adaf0c90a36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -126,9 +126,9 @@ abstract class Expression extends TreeNode[Expression] { eval.inputRow = ctx.INPUT_ROW } if (ctx.currentVars != null) { - val boundRefs = children.flatMap(_.collect { + val boundRefs = this.collect { case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => (ordinal, b) - }).toMap + }.toMap ctx.currentVars.zipWithIndex.filter(_._1 != null).foreach { case (currentVar, idx) => if (boundRefs.contains(idx)) { From 7f005158b7b10fb2dc4db3ed15181e68ae33348f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 29 Nov 2017 15:52:55 +0000 Subject: [PATCH 08/23] Fix duplicate parameters. 
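getParamsForSubExprs and prepareFunctionParams previously emitted a value and its null flag as a single comma-joined entry, so the later .distinct over the collected parameters could not drop a value or null flag that was also collected on its own. Emitting one (call argument, declaration) pair per parameter lets .distinct deduplicate them individually. A simplified illustration (the variable names are only examples):

    // before: one combined entry per nullable input
    Seq(("value2, isNull2", "int value2, boolean isNull2"))

    // after: one entry per parameter, so duplicates can be filtered out
    Seq(("value2", "int value2"), ("isNull2", "boolean isNull2"))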
--- .../sql/catalyst/expressions/Expression.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e4adaf0c90a36..f763e83d07d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -158,17 +158,17 @@ abstract class Expression extends TreeNode[Expression] { private def getParamsForSubExprs( ctx: CodegenContext, subExprs: Seq[Expression]): (Seq[String], Seq[String]) = { - subExprs.map { subExpr => + subExprs.flatMap { subExpr => val arguType = ctx.javaType(subExpr.dataType) val subExprState = ctx.subExprEliminationExprs(subExpr) (subExprState.value, subExprState.isNull) if (!subExpr.nullable || subExprState.isNull == "true" || subExprState.isNull == "false") { - (subExprState.value, s"$arguType ${subExprState.value}") + Seq((subExprState.value, s"$arguType ${subExprState.value}")) } else { - (subExprState.value + ", " + subExprState.isNull, - s"$arguType ${subExprState.value}, boolean ${subExprState.isNull}") + Seq((subExprState.value, s"$arguType ${subExprState.value}"), + (subExprState.isNull, s"boolean ${subExprState.isNull}")) } }.unzip } @@ -321,13 +321,13 @@ abstract class Expression extends TreeNode[Expression] { ctx: CodegenContext, inputAttrs: Seq[Expression], inputVars: Seq[ExprCode]): (Seq[String], Seq[String]) = { - inputAttrs.zip(inputVars).map { case (input, ev) => + inputAttrs.zip(inputVars).flatMap { case (input, ev) => val arguType = ctx.javaType(input.dataType) if (!input.nullable || ev.isNull == "true" || ev.isNull == "false") { - (ev.value, s"$arguType ${ev.value}") + Seq((ev.value, s"$arguType ${ev.value}")) } else { - (ev.value + ", " + ev.isNull, s"$arguType ${ev.value}, boolean ${ev.isNull}") + Seq((ev.value, s"$arguType ${ev.value}"), (ev.isNull, s"boolean ${ev.isNull}")) } }.unzip } From 777eb7a0c4db6695ee993be7b5d3b2d40c161591 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Nov 2017 02:00:24 +0000 Subject: [PATCH 09/23] Address comments. 
--- .../sql/catalyst/expressions/Expression.scala | 14 +++++++------- .../expressions/codegen/CodeGenerator.scala | 2 +- .../spark/sql/execution/ColumnarBatchScan.scala | 2 ++ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f763e83d07d63..aae639e4f97c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -159,15 +159,15 @@ abstract class Expression extends TreeNode[Expression] { ctx: CodegenContext, subExprs: Seq[Expression]): (Seq[String], Seq[String]) = { subExprs.flatMap { subExpr => - val arguType = ctx.javaType(subExpr.dataType) + val argType = ctx.javaType(subExpr.dataType) val subExprState = ctx.subExprEliminationExprs(subExpr) (subExprState.value, subExprState.isNull) if (!subExpr.nullable || subExprState.isNull == "true" || subExprState.isNull == "false") { - Seq((subExprState.value, s"$arguType ${subExprState.value}")) + Seq((subExprState.value, s"$argType ${subExprState.value}")) } else { - Seq((subExprState.value, s"$arguType ${subExprState.value}"), + Seq((subExprState.value, s"$argType ${subExprState.value}"), (subExprState.isNull, s"boolean ${subExprState.isNull}")) } }.unzip @@ -251,7 +251,7 @@ abstract class Expression extends TreeNode[Expression] { // in the expression path. // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so - // to include them into parameters, if not, we tract down further. + // to include them into parameters, if not, we track down further. case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => trackDownVar(ctx, ctx.currentVars(ordinal)) @@ -322,12 +322,12 @@ abstract class Expression extends TreeNode[Expression] { inputAttrs: Seq[Expression], inputVars: Seq[ExprCode]): (Seq[String], Seq[String]) = { inputAttrs.zip(inputVars).flatMap { case (input, ev) => - val arguType = ctx.javaType(input.dataType) + val argType = ctx.javaType(input.dataType) if (!input.nullable || ev.isNull == "true" || ev.isNull == "false") { - Seq((ev.value, s"$arguType ${ev.value}")) + Seq((ev.value, s"$argType ${ev.value}")) } else { - Seq((ev.value, s"$arguType ${ev.value}"), (ev.isNull, s"boolean ${ev.isNull}")) + Seq((ev.value, s"$argType ${ev.value}"), (ev.isNull, s"boolean ${ev.isNull}")) } }.unzip } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4305f325bb85b..6e24bffc6571e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1054,7 +1054,7 @@ class CodegenContext { // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. 
if (expr.nullable) { - addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") + addMutableState(JAVA_BOOLEAN, isNull) } addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index dd3373e6a4142..05186c4472566 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -108,6 +108,8 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { |}""".stripMargin) ctx.currentVars = null + // `rowIdx` isn't in `ctx.currentVars`. If the expressions are split later, we can't track it. + // So making it as global variable. val rowidx = ctx.freshName("rowIdx") ctx.addMutableState(ctx.JAVA_INT, rowidx) val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => From 7230997a54babaf62846ab538bb6756b3938d832 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Nov 2017 04:11:18 +0000 Subject: [PATCH 10/23] Polish the patch. --- .../sql/catalyst/expressions/Expression.scala | 248 +---------------- .../codegen/ExpressionCodegen.scala | 259 ++++++++++++++++++ 2 files changed, 267 insertions(+), 240 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index aae639e4f97c3..7925544c42d0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import scala.collection.mutable - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -139,219 +137,14 @@ abstract class Expression extends TreeNode[Expression] { } } - /** - * Returns the eliminated subexpressions in the children expressions. - */ - private def getSubExprInChildren(ctx: CodegenContext): Seq[Expression] = { - children.flatMap { child => - child.collect { - case e if ctx.subExprEliminationExprs.contains(e) => e - } - } - } - - /** - * Given the list of eliminated subexpressions used in the children expressions, returns the - * strings of funtion parameters. The first is the variable names used to call the function, - * the second is the parameters used to declare the function in generated code. - */ - private def getParamsForSubExprs( - ctx: CodegenContext, - subExprs: Seq[Expression]): (Seq[String], Seq[String]) = { - subExprs.flatMap { subExpr => - val argType = ctx.javaType(subExpr.dataType) - - val subExprState = ctx.subExprEliminationExprs(subExpr) - (subExprState.value, subExprState.isNull) - - if (!subExpr.nullable || subExprState.isNull == "true" || subExprState.isNull == "false") { - Seq((subExprState.value, s"$argType ${subExprState.value}")) - } else { - Seq((subExprState.value, s"$argType ${subExprState.value}"), - (subExprState.isNull, s"boolean ${subExprState.isNull}")) - } - }.unzip - } - - /** - * Retrieves previous input rows referred by children and deferred expressions. 
- */ - private def getInputRowsForChildren(ctx: CodegenContext): Seq[String] = { - children.flatMap(getInputRows(ctx, _)).distinct - } - - /** - * Given a child expression, retrieves previous input rows referred by it or deferred expressions - * which are needed to evaluate it. - */ - private def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = { - child.flatMap { - // An expression directly evaluates on current input row. - case BoundReference(ordinal, _, _) if ctx.currentVars == null || - ctx.currentVars(ordinal) == null => - Seq(ctx.INPUT_ROW) - - // An expression which is not evaluated yet. Tracks down to find input rows. - case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal).code != "" => - trackDownRow(ctx, ctx.currentVars(ordinal)) - - case _ => Seq.empty - }.distinct - } - - /** - * Tracks down input rows referred by the generated code snippet. - */ - private def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { - var exprCodes: List[ExprCode] = List(exprCode) - val inputRows = mutable.ArrayBuffer.empty[String] - - while (exprCodes.nonEmpty) { - exprCodes match { - case first :: others => - exprCodes = others - if (first.inputRow != null) { - inputRows += first.inputRow - } - first.inputVars.foreach { inputVar => - if (inputVar.exprCode.code != "") { - exprCodes = inputVar.exprCode :: exprCodes - } - } - case _ => - } - } - inputRows.toSeq - } - - /** - * Retrieves previously evaluated columns referred by children and deferred expressions. - * Returned tuple contains the list of expressions and the list of generated codes. - */ - private def getInputVarsForChildren(ctx: CodegenContext): (Seq[Expression], Seq[ExprCode]) = { - children.flatMap(getInputVars(ctx, _)).distinct.unzip - } - - /** - * Given a child expression, retrieves previously evaluated columns referred by it or - * deferred expressions which are needed to evaluate it. - */ - private def getInputVars(ctx: CodegenContext, child: Expression): Seq[(Expression, ExprCode)] = { - if (ctx.currentVars == null) { - return Seq.empty - } - - child.flatMap { - // An evaluated variable. - case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && - ctx.currentVars(ordinal).code == "" => - Seq((b, ctx.currentVars(ordinal))) - - // An input variable which is not evaluated yet. Tracks down to find any evaluated variables - // in the expression path. - // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to - // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so - // to include them into parameters, if not, we track down further. - case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => - trackDownVar(ctx, ctx.currentVars(ordinal)) - - case _ => Seq.empty - }.distinct - } - - /** - * Tracks down previously evaluated columns referred by the generated code snippet. 
- */ - private def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[(Expression, ExprCode)] = { - var exprCodes: List[ExprCode] = List(exprCode) - val inputVars = mutable.ArrayBuffer.empty[(Expression, ExprCode)] - - while (exprCodes.nonEmpty) { - exprCodes match { - case first :: others => - exprCodes = others - first.inputVars.foreach { inputVar => - if (inputVar.exprCode.code == "") { - inputVars += ((inputVar.expr, inputVar.exprCode)) - } else { - exprCodes = inputVar.exprCode :: exprCodes - } - } - case _ => - } - } - inputVars.toSeq - } - - /** - * Helper function to calculate the size of an expression as function parameter. - */ - private def calculateParamLength(ctx: CodegenContext, input: Expression): Int = { - ctx.javaType(input.dataType) match { - case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input.nullable => 2 - case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 3 - case _ if !input.nullable => 1 - case _ => 2 - } - } - - /** - * In Java, a method descriptor is valid only if it represents method parameters with a total - * length of 255 or less. `this` contributes one unit and a parameter of type long or double - * contributes two units. - */ - private def getValidParamLength( - ctx: CodegenContext, - inputs: Seq[Expression], - subExprs: Seq[Expression]): Int = { - // Start value is 1 for `this`. - inputs.foldLeft(1) { case (curLength, input) => - curLength + calculateParamLength(ctx, input) - } + subExprs.foldLeft(0) { case (curLength, subExpr) => - curLength + calculateParamLength(ctx, subExpr) - } - } - - /** - * Given the lists of input attributes and variables to this expression, returns the strings of - * funtion parameters. The first is the variable names used to call the function, the second is - * the parameters used to declare the function in generated code. - */ - private def prepareFunctionParams( - ctx: CodegenContext, - inputAttrs: Seq[Expression], - inputVars: Seq[ExprCode]): (Seq[String], Seq[String]) = { - inputAttrs.zip(inputVars).flatMap { case (input, ev) => - val argType = ctx.javaType(input.dataType) - - if (!input.nullable || ev.isNull == "true" || ev.isNull == "false") { - Seq((ev.value, s"$argType ${ev.value}")) - } else { - Seq((ev.value, s"$argType ${ev.value}"), (ev.isNull, s"boolean ${ev.isNull}")) - } - }.unzip - } - /** * In order to prevent 64kb compile error, reducing the size of generated codes by * separating it into a function if the size exceeds a threshold. */ private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { - val isNotWholestageCodegen = ctx.INPUT_ROW != null && ctx.currentVars == null - val (inputAttrs, inputVars) = getInputVarsForChildren(ctx) - val inputRows = getInputRowsForChildren(ctx) - val subExprs = getSubExprInChildren(ctx) - - // Params to include: - // 1. Evaluated columns referred by this, children or deferred expressions. - // 2. Input rows referred by this, children or deferred expressions. - // 3. Eliminated subexpressions. - val paramsLength = getValidParamLength(ctx, inputAttrs, subExprs) + inputRows.length - - // Puts code into a function if the code is big, when: - // 1. Not in wholestage codegen, or - // 2. Allowed parameter number for Java's method descriptor. 
- if (eval.code.trim.length > 1024 && (isNotWholestageCodegen || paramsLength <= 255)) { + val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this) + + if (eval.code.trim.length > 1024 && funcParams.isDefined) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { val globalIsNull = ctx.freshName("globalIsNull") ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) @@ -365,38 +158,13 @@ abstract class Expression extends TreeNode[Expression] { val javaType = ctx.javaType(dataType) val newValue = ctx.freshName("value") - // Prepare function parameters. - val callParams = mutable.ArrayBuffer[String]() - val funcParams = mutable.ArrayBuffer[String]() - - if (ctx.INPUT_ROW != null) { - callParams += ctx.INPUT_ROW - funcParams += s"InternalRow ${ctx.INPUT_ROW}" - } - - if (inputRows.length > 0) { - inputRows.foreach { row => - callParams += row - funcParams += s"InternalRow $row" - } - } - - if (inputAttrs.length > 0) { - val params = prepareFunctionParams(ctx, inputAttrs, inputVars) - params._1.foreach(callParams += _) - params._2.foreach(funcParams += _) - } - - if (subExprs.length > 0) { - val subExprParams = getParamsForSubExprs(ctx, subExprs) - subExprParams._1.foreach(callParams += _) - subExprParams._2.foreach(funcParams += _) - } - val funcName = ctx.freshName(nodeName) + val callParams = funcParams.map(_._1.mkString(", ")).get + val declParams = funcParams.map(_._2.mkString(", ")).get + val funcFullName = ctx.addNewFunction(funcName, s""" - |private $javaType $funcName(${funcParams.distinct.mkString(", ")}) { + |private $javaType $funcName($declParams) { | ${eval.code.trim} | $setIsNull | return ${eval.value}; @@ -404,7 +172,7 @@ abstract class Expression extends TreeNode[Expression] { """.stripMargin) eval.value = newValue - eval.code = s"$javaType $newValue = $funcFullName(${callParams.distinct.mkString(", ")});" + eval.code = s"$javaType $newValue = $funcFullName($callParams);" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala new file mode 100644 index 0000000000000..d132d7a5a5677 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Defines APIs used in expression code generation. 
+ */ +object ExpressionCodegen { + + /** + * Given an expression, returns the all necessary parameters to evaluate it, so the generated + * code of this expression can be split in a function. + * The 1st string in returned tuple is the parameter strings used to call the function. + * The 2nd string in returned tuple is the parameter strings used to declare the function. + * + * Returns `None` if it can't produce valid parameters. + * + * Params to include: + * 1. Evaluated columns referred by this, children or deferred expressions. + * 2. Rows referred by this, children or deferred expressions. + * 3. Eliminated subexpressions referred bu children expressions. + */ + def getExpressionInputParams( + ctx: CodegenContext, + expr: Expression): Option[(Seq[String], Seq[String])] = { + val (inputAttrs, inputVars) = getInputVarsForChildren(ctx, expr) + val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) + val subExprs = getSubExprInChildren(ctx, expr) + + val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => + (row, s"InternalRow $row") + } + val paramsFromColumns = prepareFunctionParams(ctx, inputAttrs, inputVars) + val paramsFromSubExprs = getParamsForSubExprs(ctx, subExprs) + val paramsLength = getParamLength(ctx, inputAttrs, subExprs) + paramsFromRows.length + + // Maximum allowed parameter number for Java's method descriptor. + if (paramsLength > 255) { + None + } else { + val allParams = (paramsFromRows ++ paramsFromColumns ++ paramsFromSubExprs).unzip + val callParams = allParams._1.distinct + val declParams = allParams._2.distinct + Some((callParams, declParams)) + } + } + + /** + * Returns the eliminated subexpressions in the children expressions. + */ + def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = { + expr.children.flatMap { child => + child.collect { + case e if ctx.subExprEliminationExprs.contains(e) => e + } + }.distinct + } + + /** + * Given the list of eliminated subexpressions used in the children expressions, returns the + * strings of funtion parameters. The first is the variable names used to call the function, + * the second is the parameters used to declare the function in generated code. + */ + def getParamsForSubExprs( + ctx: CodegenContext, + subExprs: Seq[Expression]): Seq[(String, String)] = { + subExprs.flatMap { subExpr => + val argType = ctx.javaType(subExpr.dataType) + + val subExprState = ctx.subExprEliminationExprs(subExpr) + (subExprState.value, subExprState.isNull) + + if (!subExpr.nullable || subExprState.isNull == "true" || subExprState.isNull == "false") { + Seq((subExprState.value, s"$argType ${subExprState.value}")) + } else { + Seq((subExprState.value, s"$argType ${subExprState.value}"), + (subExprState.isNull, s"boolean ${subExprState.isNull}")) + } + }.distinct + } + + /** + * Retrieves previous input rows referred by children and deferred expressions. + */ + def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = { + expr.children.flatMap(getInputRows(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previous input rows referred by it or deferred expressions + * which are needed to evaluate it. + */ + def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = { + child.flatMap { + // An expression directly evaluates on current input row. + case BoundReference(ordinal, _, _) if ctx.currentVars == null || + ctx.currentVars(ordinal) == null => + Seq(ctx.INPUT_ROW) + + // An expression which is not evaluated yet. 
Tracks down to find input rows. + case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal).code != "" => + trackDownRow(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down input rows referred by the generated code snippet. + */ + def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { + var exprCodes: List[ExprCode] = List(exprCode) + val inputRows = mutable.ArrayBuffer.empty[String] + + while (exprCodes.nonEmpty) { + exprCodes match { + case first :: others => + exprCodes = others + if (first.inputRow != null) { + inputRows += first.inputRow + } + first.inputVars.foreach { inputVar => + if (inputVar.exprCode.code != "") { + exprCodes = inputVar.exprCode :: exprCodes + } + } + case _ => + } + } + inputRows.toSeq + } + + /** + * Retrieves previously evaluated columns referred by children and deferred expressions. + * Returned tuple contains the list of expressions and the list of generated codes. + */ + def getInputVarsForChildren( + ctx: CodegenContext, + expr: Expression): (Seq[Expression], Seq[ExprCode]) = { + expr.children.flatMap(getInputVars(ctx, _)).distinct.unzip + } + + /** + * Given a child expression, retrieves previously evaluated columns referred by it or + * deferred expressions which are needed to evaluate it. + */ + def getInputVars(ctx: CodegenContext, child: Expression): Seq[(Expression, ExprCode)] = { + if (ctx.currentVars == null) { + return Seq.empty + } + + child.flatMap { + // An evaluated variable. + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && + ctx.currentVars(ordinal).code == "" => + Seq((b, ctx.currentVars(ordinal))) + + // An input variable which is not evaluated yet. Tracks down to find any evaluated variables + // in the expression path. + // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to + // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so + // to include them into parameters, if not, we track down further. + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + trackDownVar(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down previously evaluated columns referred by the generated code snippet. + */ + def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[(Expression, ExprCode)] = { + var exprCodes: List[ExprCode] = List(exprCode) + val inputVars = mutable.ArrayBuffer.empty[(Expression, ExprCode)] + + while (exprCodes.nonEmpty) { + exprCodes match { + case first :: others => + exprCodes = others + first.inputVars.foreach { inputVar => + if (inputVar.exprCode.code == "") { + inputVars += ((inputVar.expr, inputVar.exprCode)) + } else { + exprCodes = inputVar.exprCode :: exprCodes + } + } + case _ => + } + } + inputVars.toSeq + } + + /** + * Helper function to calculate the size of an expression as function parameter. + */ + def calculateParamLength(ctx: CodegenContext, input: Expression): Int = { + ctx.javaType(input.dataType) match { + case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input.nullable => 2 + case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 3 + case _ if !input.nullable => 1 + case _ => 2 + } + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length of 255 or less. `this` contributes one unit and a parameter of type long or double + * contributes two units. 
+ */ + def getParamLength( + ctx: CodegenContext, + inputs: Seq[Expression], + subExprs: Seq[Expression]): Int = { + // Start value is 1 for `this`. + (inputs ++ subExprs).distinct.foldLeft(1) { case (curLength, input) => + curLength + calculateParamLength(ctx, input) + } + } + + /** + * Given the lists of input attributes and variables to this expression, returns the strings of + * funtion parameters. The first is the variable names used to call the function, the second is + * the parameters used to declare the function in generated code. + */ + def prepareFunctionParams( + ctx: CodegenContext, + inputAttrs: Seq[Expression], + inputVars: Seq[ExprCode]): Seq[(String, String)] = { + inputAttrs.zip(inputVars).flatMap { case (input, ev) => + val argType = ctx.javaType(input.dataType) + + if (!input.nullable || ev.isNull == "true" || ev.isNull == "false") { + Seq((ev.value, s"$argType ${ev.value}")) + } else { + Seq((ev.value, s"$argType ${ev.value}"), (ev.isNull, s"boolean ${ev.isNull}")) + } + }.distinct + } +} From fd87e9ba324e0b45685e7873884a4fa7a6feaf17 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Nov 2017 07:48:36 +0000 Subject: [PATCH 11/23] Add test for new APIs. --- .../codegen/ExpressionCodegenSuite.scala | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala new file mode 100644 index 0000000000000..f0f05436a2918 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.IntegerType + +class ExpressionCodegenSuite extends SparkFunSuite { + + test("Returns eliminated subexpressions for expression") { + val ctx = new CodegenContext() + val subExpr = Add(Literal(1), Literal(2)) + val exprs = Seq(Add(subExpr, Literal(3)), Add(subExpr, Literal(4))) + + ctx.generateExpressions(exprs, doSubexpressionElimination = true) + val subexpressions = ExpressionCodegen.getSubExprInChildren(ctx, exprs(0)) + assert(subexpressions.length == 1 && subexpressions(0) == subExpr) + } + + test("Gets parameters for subexpressions") { + val ctx = new CodegenContext() + val subExprs = Seq( + Add(Literal(1), AttributeReference("a", IntegerType, nullable = false)()), // non-nullable + Add(Literal(2), AttributeReference("b", IntegerType, nullable = true)())) // nullable + + ctx.subExprEliminationExprs.put(subExprs(0), SubExprEliminationState("false", "value1")) + ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2")) + + val params = ExpressionCodegen.getParamsForSubExprs(ctx, subExprs) + assert(params.length == 3) + assert(params(0) == Tuple2("value1", "int value1")) + assert(params(1) == Tuple2("value2", "int value2")) + assert(params(2) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input variables for expression: current variables") { + val ctx = new CodegenContext() + val currentVars = Seq( + ExprCode("", isNull = "false", value = "value1"), + ExprCode("", isNull = "isNull2", value = "value2"), + ExprCode("fake code;", isNull = "isNull3", value = "value3")) + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = If(Literal(false), + Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)), + BoundReference(2, IntegerType, nullable = true)) + + val (inputAttrs, inputVars) = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + assert(inputAttrs.length == 2) + assert(inputAttrs(0) == BoundReference(0, IntegerType, nullable = false)) + assert(inputAttrs(1) == BoundReference(1, IntegerType, nullable = true)) + + assert(inputVars.length == 2) + assert(inputVars(0) == currentVars(0)) + assert(inputVars(1) == currentVars(1)) + + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputAttrs, inputVars) + assert(params.length == 3) + assert(params(0) == Tuple2("value1", "int value1")) + assert(params(1) == Tuple2("value2", "int value2")) + assert(params(2) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input variables for expression: deferred variables") { + val ctx = new CodegenContext() + + // The referred column is not evaluated yet. But it depends on an evaluated column from + // other operator. + val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1")) + val fakeExpr = AttributeReference("a", IntegerType, nullable = true)() + + // currentVars(0) depends on this evaluated column. 
+ currentVars(0).inputVars += ExprInputVar( + fakeExpr, + ExprCode("", isNull = "isNull2", value = "value2")) + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) + val (inputAttrs, inputVars) = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + assert(inputAttrs.length == 1) + assert(inputAttrs(0) == fakeExpr) + + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputAttrs, inputVars) + assert(params.length == 2) + assert(params(0) == Tuple2("value2", "int value2")) + assert(params(1) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input rows for expression") { + val ctx = new CodegenContext() + ctx.currentVars = null + ctx.INPUT_ROW = "i" + + val expr = Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)) + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 1) + assert(inputRows(0) == "i") + } + + test("Returns input rows for expression: deferred expression") { + val ctx = new CodegenContext() + + // The referred column is not evaluated yet. But it depends on an input row from + // other operator. + val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1")) + currentVars(0).inputRow = "inputadaptor_row1" + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 1) + assert(inputRows(0) == "inputadaptor_row1") + } +} From 57a9fb77d7628e8a5815b8571ca9c99490419252 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Nov 2017 08:03:09 +0000 Subject: [PATCH 12/23] Generate function parameters if needed. --- .../org/apache/spark/sql/catalyst/expressions/Expression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7925544c42d0e..d5fb6977bd526 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -142,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { * separating it into a function if the size exceeds a threshold. */ private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { - val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this) + lazy val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this) if (eval.code.trim.length > 1024 && funcParams.isDefined) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { From 0d358d635494199582aa6e38fdbeec0f6446c029 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 Dec 2017 03:37:18 +0000 Subject: [PATCH 13/23] Address comments. 
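
This revision also rewrites the deferred-variable tracking (trackDownRow/trackDownVar) as a
queue-based breadth-first walk. For illustration only, a minimal standalone sketch of that
traversal; MiniVar and collectEvaluated are hypothetical stand-ins, not the real
ExprCode/ExprInputVar types:

    import scala.collection.mutable

    object TrackDownSketch extends App {
      // Stand-in for ExprCode/ExprInputVar: `code` is empty once a variable has been
      // evaluated, `inputs` are the variables it was built from.
      case class MiniVar(code: String, value: String, inputs: Seq[MiniVar] = Seq.empty)

      // Breadth-first walk: collect every already-evaluated variable reachable from
      // `start`, following not-yet-evaluated variables through a queue.
      def collectEvaluated(start: MiniVar): Seq[MiniVar] = {
        val queue = mutable.Queue(start)
        val found = mutable.ArrayBuffer.empty[MiniVar]
        while (queue.nonEmpty) {
          queue.dequeue().inputs.foreach { in =>
            if (in.code == "") found += in else queue.enqueue(in)
          }
        }
        found.toSeq
      }

      // "d = c + 1" is deferred, "c = a + b" is deferred too, but a and b are already
      // evaluated, so they are what a split function would need as parameters.
      val a = MiniVar("", "value_a")
      val b = MiniVar("", "value_b")
      val c = MiniVar("int value_c = value_a + value_b;", "value_c", Seq(a, b))
      val d = MiniVar("int value_d = value_c + 1;", "value_d", Seq(c))
      assert(collectEvaluated(d) == Seq(a, b))
    }

The same idea applies to tracking input rows: deferred variables are followed until an
already-evaluated value or a referenced row is reached.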
--- .../sql/catalyst/expressions/Expression.scala | 22 ++++--- .../expressions/codegen/CodeGenerator.scala | 4 +- .../codegen/ExpressionCodegen.scala | 63 ++++++++----------- .../codegen/ExpressionCodegenSuite.scala | 5 +- 4 files changed, 43 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d5fb6977bd526..d55fb27f927ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -105,7 +105,11 @@ abstract class Expression extends TreeNode[Expression] { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val eval = doGenCode(ctx, ExprCode("", isNull, value)) - populateInputs(ctx, eval) + + // Records current input row and variables of this expression. + eval.inputRow = ctx.INPUT_ROW + eval.inputVars = findInputVars(ctx, eval) + reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -117,23 +121,23 @@ abstract class Expression extends TreeNode[Expression] { } /** - * Records current input row and variables for this expression into created `ExprCode`. + * Returns the input variables to this expression. */ - private def populateInputs(ctx: CodegenContext, eval: ExprCode): Unit = { - if (ctx.INPUT_ROW != null) { - eval.inputRow = ctx.INPUT_ROW - } + private def findInputVars(ctx: CodegenContext, eval: ExprCode): Seq[ExprInputVar] = { if (ctx.currentVars != null) { val boundRefs = this.collect { case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => (ordinal, b) }.toMap - ctx.currentVars.zipWithIndex.filter(_._1 != null).foreach { case (currentVar, idx) => + ctx.currentVars.zipWithIndex.filter(_._1 != null).flatMap { case (currentVar, idx) => if (boundRefs.contains(idx)) { - val inputVar = ExprInputVar(boundRefs(idx), exprCode = currentVar) - eval.inputVars += inputVar + Some(ExprInputVar(boundRefs(idx), exprCode = currentVar)) + } else { + None } } + } else { + Seq.empty } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6e24bffc6571e..edbf4e054c74f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -63,10 +63,10 @@ case class ExprCode( var isNull: String, var value: String, var inputRow: String = null, - val inputVars: mutable.ArrayBuffer[ExprInputVar] = mutable.ArrayBuffer.empty) + var inputVars: Seq[ExprInputVar] = Seq.empty) /** - * Represents an input variable that holds the java type and the [[ExprCode]]. + * Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]]. 
*/ case class ExprInputVar(val expr: Expression, val exprCode: ExprCode) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index d132d7a5a5677..57ff7cc3a09a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -43,16 +43,17 @@ object ExpressionCodegen { ctx: CodegenContext, expr: Expression): Option[(Seq[String], Seq[String])] = { val (inputAttrs, inputVars) = getInputVarsForChildren(ctx, expr) - val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) + val paramsFromColumns = prepareFunctionParams(ctx, inputAttrs, inputVars) + val subExprs = getSubExprInChildren(ctx, expr) + val paramsFromSubExprs = getParamsForSubExprs(ctx, subExprs) + val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => (row, s"InternalRow $row") } - val paramsFromColumns = prepareFunctionParams(ctx, inputAttrs, inputVars) - val paramsFromSubExprs = getParamsForSubExprs(ctx, subExprs) - val paramsLength = getParamLength(ctx, inputAttrs, subExprs) + paramsFromRows.length + val paramsLength = getParamLength(ctx, inputAttrs, subExprs) + paramsFromRows.length // Maximum allowed parameter number for Java's method descriptor. if (paramsLength > 255) { None @@ -87,7 +88,6 @@ object ExpressionCodegen { val argType = ctx.javaType(subExpr.dataType) val subExprState = ctx.subExprEliminationExprs(subExpr) - (subExprState.value, subExprState.isNull) if (!subExpr.nullable || subExprState.isNull == "true" || subExprState.isNull == "false") { Seq((subExprState.value, s"$argType ${subExprState.value}")) @@ -128,25 +128,21 @@ object ExpressionCodegen { * Tracks down input rows referred by the generated code snippet. */ def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { - var exprCodes: List[ExprCode] = List(exprCode) + val exprCodes = mutable.Queue[ExprCode](exprCode) val inputRows = mutable.ArrayBuffer.empty[String] while (exprCodes.nonEmpty) { - exprCodes match { - case first :: others => - exprCodes = others - if (first.inputRow != null) { - inputRows += first.inputRow - } - first.inputVars.foreach { inputVar => - if (inputVar.exprCode.code != "") { - exprCodes = inputVar.exprCode :: exprCodes - } - } - case _ => + val curExprCode = exprCodes.dequeue() + if (curExprCode.inputRow != null) { + inputRows += curExprCode.inputRow + } + curExprCode.inputVars.foreach { inputVar => + if (inputVar.exprCode.code != "") { + exprCodes.enqueue(inputVar.exprCode) + } } } - inputRows.toSeq + inputRows } /** @@ -179,7 +175,7 @@ object ExpressionCodegen { // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so // to include them into parameters, if not, we track down further. - case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => trackDownVar(ctx, ctx.currentVars(ordinal)) case _ => Seq.empty @@ -190,24 +186,19 @@ object ExpressionCodegen { * Tracks down previously evaluated columns referred by the generated code snippet. 
*/ def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[(Expression, ExprCode)] = { - var exprCodes: List[ExprCode] = List(exprCode) + val exprCodes = mutable.Queue[ExprCode](exprCode) val inputVars = mutable.ArrayBuffer.empty[(Expression, ExprCode)] while (exprCodes.nonEmpty) { - exprCodes match { - case first :: others => - exprCodes = others - first.inputVars.foreach { inputVar => - if (inputVar.exprCode.code == "") { - inputVars += ((inputVar.expr, inputVar.exprCode)) - } else { - exprCodes = inputVar.exprCode :: exprCodes - } - } - case _ => + exprCodes.dequeue().inputVars.foreach { inputVar => + if (inputVar.exprCode.code == "") { + inputVars += ((inputVar.expr, inputVar.exprCode)) + } else { + exprCodes.enqueue(inputVar.exprCode) + } } } - inputVars.toSeq + inputVars } /** @@ -231,10 +222,8 @@ object ExpressionCodegen { ctx: CodegenContext, inputs: Seq[Expression], subExprs: Seq[Expression]): Int = { - // Start value is 1 for `this`. - (inputs ++ subExprs).distinct.foldLeft(1) { case (curLength, input) => - curLength + calculateParamLength(ctx, input) - } + // Initial value is 1 for `this`. + 1 + (inputs ++ subExprs).distinct.map(calculateParamLength(ctx, _)).sum } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala index f0f05436a2918..dfe8633840a32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala @@ -88,9 +88,8 @@ class ExpressionCodegenSuite extends SparkFunSuite { val fakeExpr = AttributeReference("a", IntegerType, nullable = true)() // currentVars(0) depends on this evaluated column. - currentVars(0).inputVars += ExprInputVar( - fakeExpr, - ExprCode("", isNull = "isNull2", value = "value2")) + currentVars(0).inputVars = Seq(ExprInputVar(fakeExpr, + ExprCode("", isNull = "isNull2", value = "value2"))) ctx.currentVars = currentVars ctx.INPUT_ROW = null From aa3db2edca66ab04ecb8fbd54750cbd46544eb1d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 Dec 2017 04:58:42 +0000 Subject: [PATCH 14/23] Address comments. 
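
The parameter handling for eliminated subexpressions is folded into the same path as ordinary
input variables here. A rough sketch of the (call argument, declaration) pairs involved, using
hypothetical stand-in names rather than the real SubExprEliminationState:

    object SubExprParamSketch extends App {
      // Stand-in for the generated isNull/value variable names of an eliminated
      // subexpression.
      case class MiniState(isNull: String, value: String)

      // Build (call argument, declaration) pairs for one eliminated subexpression.
      // If the subexpression is not nullable, or its isNull is the literal
      // "true"/"false", only the value needs to be passed; otherwise the boolean
      // isNull flag is passed as well.
      def paramsFor(state: MiniState, javaType: String, nullable: Boolean): Seq[(String, String)] = {
        val valueParam = (state.value, s"$javaType ${state.value}")
        if (nullable && state.isNull != "true" && state.isNull != "false") {
          Seq(valueParam, (state.isNull, s"boolean ${state.isNull}"))
        } else {
          Seq(valueParam)
        }
      }

      assert(paramsFor(MiniState("false", "value1"), "int", nullable = false) ==
        Seq(("value1", "int value1")))
      assert(paramsFor(MiniState("isNull2", "value2"), "int", nullable = true) ==
        Seq(("value2", "int value2"), ("isNull2", "boolean isNull2")))
    }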
--- .../expressions/codegen/CodeGenerator.scala | 3 +-- .../codegen/ExpressionCodegen.scala | 27 ++++++------------- .../codegen/ExpressionCodegenSuite.scala | 3 ++- 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index edbf4e054c74f..4cff175519c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1056,8 +1056,7 @@ class CodegenContext { if (expr.nullable) { addMutableState(JAVA_BOOLEAN, isNull) } - addMutableState(javaType(expr.dataType), value, - s"$value = ${defaultValue(expr.dataType)};") + addMutableState(javaType(expr.dataType), value) subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = if (expr.nullable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index 57ff7cc3a09a8..0a402fb0ee9c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -46,7 +46,8 @@ object ExpressionCodegen { val paramsFromColumns = prepareFunctionParams(ctx, inputAttrs, inputVars) val subExprs = getSubExprInChildren(ctx, expr) - val paramsFromSubExprs = getParamsForSubExprs(ctx, subExprs) + val subExprCodes = getSubExprCodes(ctx, subExprs) + val paramsFromSubExprs = prepareFunctionParams(ctx, subExprs, subExprCodes) val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => @@ -77,25 +78,13 @@ object ExpressionCodegen { } /** - * Given the list of eliminated subexpressions used in the children expressions, returns the - * strings of funtion parameters. The first is the variable names used to call the function, - * the second is the parameters used to declare the function in generated code. + * A small helper function to return `ExprCode`s that represent subexpressions. 
*/ - def getParamsForSubExprs( - ctx: CodegenContext, - subExprs: Seq[Expression]): Seq[(String, String)] = { - subExprs.flatMap { subExpr => - val argType = ctx.javaType(subExpr.dataType) - - val subExprState = ctx.subExprEliminationExprs(subExpr) - - if (!subExpr.nullable || subExprState.isNull == "true" || subExprState.isNull == "false") { - Seq((subExprState.value, s"$argType ${subExprState.value}")) - } else { - Seq((subExprState.value, s"$argType ${subExprState.value}"), - (subExprState.isNull, s"boolean ${subExprState.isNull}")) - } - }.distinct + def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = { + subExprs.map { subExpr => + val stat = ctx.subExprEliminationExprs(subExpr) + ExprCode(code = "", value = stat.value, isNull = stat.isNull) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala index dfe8633840a32..606b294f9f8a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala @@ -42,7 +42,8 @@ class ExpressionCodegenSuite extends SparkFunSuite { ctx.subExprEliminationExprs.put(subExprs(0), SubExprEliminationState("false", "value1")) ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2")) - val params = ExpressionCodegen.getParamsForSubExprs(ctx, subExprs) + val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs) + val params = ExpressionCodegen.prepareFunctionParams(ctx, subExprs, subExprCodes) assert(params.length == 3) assert(params(0) == Tuple2("value1", "int value1")) assert(params(1) == Tuple2("value2", "int value2")) From 429afbabef6f718870ca3c6caf0712a1e459681f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 4 Dec 2017 09:02:09 +0000 Subject: [PATCH 15/23] Rename variable. --- .../sql/catalyst/expressions/codegen/ExpressionCodegen.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index 0a402fb0ee9c0..0150f7fd88642 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -82,8 +82,8 @@ object ExpressionCodegen { */ def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = { subExprs.map { subExpr => - val stat = ctx.subExprEliminationExprs(subExpr) - ExprCode(code = "", value = stat.value, isNull = stat.isNull) + val state = ctx.subExprEliminationExprs(subExpr) + ExprCode(code = "", value = state.value, isNull = state.isNull) } } From 48add652f2df45ce6506f9464c10a6425bd92214 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Dec 2017 14:53:11 +0000 Subject: [PATCH 16/23] Address comments. 
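
findInputVars is simplified in this round to a straight collect of the bound references that
have a current variable. A self-contained sketch of that idea over a toy expression tree; the
MiniExpr types below are illustrative stand-ins, not Catalyst classes:

    object InputVarSketch extends App {
      // Tiny stand-in expression tree; the real code collects BoundReference nodes.
      sealed trait MiniExpr { def children: Seq[MiniExpr] = Seq.empty }
      case class MiniBoundRef(ordinal: Int) extends MiniExpr
      case class MiniAdd(left: MiniExpr, right: MiniExpr) extends MiniExpr {
        override def children: Seq[MiniExpr] = Seq(left, right)
      }

      def collectRefs(e: MiniExpr): Seq[MiniBoundRef] = e match {
        case b: MiniBoundRef => Seq(b)
        case other => other.children.flatMap(collectRefs)
      }

      // Only bound references with a non-null slot in currentVars become input
      // variables of the expression; ordinal 1 has no current variable, so it is
      // still read from the input row instead.
      val currentVars: Array[Option[String]] = Array(Some("value1"), None, Some("value3"))
      val expr = MiniAdd(MiniBoundRef(0), MiniAdd(MiniBoundRef(1), MiniBoundRef(2)))
      val inputVars = collectRefs(expr).filter(b => currentVars(b.ordinal).isDefined)
      assert(inputVars.map(_.ordinal) == Seq(0, 2))
    }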
--- .../spark/sql/catalyst/expressions/Expression.scala | 13 +++---------- .../expressions/codegen/CodeGenerator.scala | 5 ++++- .../expressions/codegen/ExpressionCodegen.scala | 2 +- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d55fb27f927ba..7a3b9b16032b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -125,16 +125,9 @@ abstract class Expression extends TreeNode[Expression] { */ private def findInputVars(ctx: CodegenContext, eval: ExprCode): Seq[ExprInputVar] = { if (ctx.currentVars != null) { - val boundRefs = this.collect { - case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => (ordinal, b) - }.toMap - - ctx.currentVars.zipWithIndex.filter(_._1 != null).flatMap { case (currentVar, idx) => - if (boundRefs.contains(idx)) { - Some(ExprInputVar(boundRefs(idx), exprCode = currentVar)) - } else { - None - } + this.collect { + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + ExprInputVar(b, exprCode = ctx.currentVars(ordinal)) } } else { Seq.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4cff175519c2e..35943dd20c4d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -67,8 +67,11 @@ case class ExprCode( /** * Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]]. + * + * @param expr The expression that is evaluated to the input variable. + * @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable. */ -case class ExprInputVar(val expr: Expression, val exprCode: ExprCode) +case class ExprInputVar(expr: Expression, exprCode: ExprCode) /** * State used for subexpression elimination. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index 0150f7fd88642..d994ae4d1ef17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ /** - * Defines APIs used in expression code generation. + * Defines util methods used in expression code generation. */ object ExpressionCodegen { From 9443011978c32c611e950a6193f05aa666437f50 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 8 Dec 2017 03:41:15 +0000 Subject: [PATCH 17/23] Address comments. 
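
One change in this round is skipping literals when building function parameters: a value
string that does not look like a Java variable name never needs to be passed in. A hedged,
standalone sketch of that check (a free function here, not the actual ExprCode method):

    object LiteralCheckSketch extends App {
      // Heuristic used by this series: a generated value string names a Java variable
      // only if its first character is in [a-zA-Z_$]; otherwise ("1", "-1L",
      // "(byte)-1", "null", "true", ...) it is a literal and there is no need to pass
      // it into the split function as a parameter.
      def isLiteralValue(value: String): Boolean = {
        require(value.nonEmpty, "value can't be an empty string")
        value match {
          case "true" | "false" | "null" => true
          case _ => value.head match {
            case c if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_' || c == '$' => false
            case _ => true
          }
        }
      }

      assert(isLiteralValue("-1L") && isLiteralValue("(byte)-1") && isLiteralValue("null"))
      assert(!isLiteralValue("value1") && !isLiteralValue("_var2") && !isLiteralValue("$var3"))
    }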
--- .../sql/catalyst/expressions/Expression.scala | 3 +- .../expressions/codegen/CodeGenerator.scala | 25 +++++++- .../codegen/ExpressionCodegen.scala | 57 +++++++++++-------- .../expressions/codegen/ExprCodeSuite.scala | 55 ++++++++++++++++++ .../codegen/ExpressionCodegenSuite.scala | 21 +++---- 5 files changed, 123 insertions(+), 38 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprCodeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7a3b9b16032b2..5e4641e94c9a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -127,7 +127,8 @@ abstract class Expression extends TreeNode[Expression] { if (ctx.currentVars != null) { this.collect { case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => - ExprInputVar(b, exprCode = ctx.currentVars(ordinal)) + ExprInputVar(exprCode = ctx.currentVars(ordinal), + dataType = b.dataType, nullable = b.nullable) } } else { Seq.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 35943dd20c4d5..2e1b0e6bddaa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -63,15 +63,34 @@ case class ExprCode( var isNull: String, var value: String, var inputRow: String = null, - var inputVars: Seq[ExprInputVar] = Seq.empty) + var inputVars: Seq[ExprInputVar] = Seq.empty) { + + // Returns true if this value is a literal. + def isLiteral(): Boolean = { + assert(value.nonEmpty, "ExprCode.value can't be empty string.") + + if (value == "true" || value == "false" || value == "null") { + true + } else { + // The valid characters for the first character of a Java variable is [a-zA-Z_$]. + value.head match { + case v if v >= 'a' && v <= 'z' => false + case v if v >= 'A' && v <= 'Z' => false + case '_' | '$' => false + case _ => true + } + } + } +} /** * Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]]. * - * @param expr The expression that is evaluated to the input variable. * @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable. + * @param dataType The data type of the input variable. + * @param nullable Whether the input variable can be null or not. */ -case class ExprInputVar(expr: Expression, exprCode: ExprCode) +case class ExprInputVar(exprCode: ExprCode, dataType: DataType, nullable: Boolean) /** * State used for subexpression elimination. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index d994ae4d1ef17..c6fdafb32b965 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -20,12 +20,16 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DataType /** * Defines util methods used in expression code generation. */ object ExpressionCodegen { + // Type alias for a tuple representing the data type and nullable for an expression. + type ExprProperty = (DataType, Boolean) + /** * Given an expression, returns the all necessary parameters to evaluate it, so the generated * code of this expression can be split in a function. @@ -37,7 +41,7 @@ object ExpressionCodegen { * Params to include: * 1. Evaluated columns referred by this, children or deferred expressions. * 2. Rows referred by this, children or deferred expressions. - * 3. Eliminated subexpressions referred bu children expressions. + * 3. Eliminated subexpressions referred by children expressions. */ def getExpressionInputParams( ctx: CodegenContext, @@ -47,14 +51,15 @@ object ExpressionCodegen { val subExprs = getSubExprInChildren(ctx, expr) val subExprCodes = getSubExprCodes(ctx, subExprs) - val paramsFromSubExprs = prepareFunctionParams(ctx, subExprs, subExprCodes) + val subAttrs = subExprs.map(e => (e.dataType, e.nullable)) + val paramsFromSubExprs = prepareFunctionParams(ctx, subAttrs, subExprCodes) val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => (row, s"InternalRow $row") } - val paramsLength = getParamLength(ctx, inputAttrs, subExprs) + paramsFromRows.length + val paramsLength = getParamLength(ctx, inputAttrs ++ subAttrs) + paramsFromRows.length // Maximum allowed parameter number for Java's method descriptor. if (paramsLength > 255) { None @@ -140,7 +145,7 @@ object ExpressionCodegen { */ def getInputVarsForChildren( ctx: CodegenContext, - expr: Expression): (Seq[Expression], Seq[ExprCode]) = { + expr: Expression): (Seq[ExprProperty], Seq[ExprCode]) = { expr.children.flatMap(getInputVars(ctx, _)).distinct.unzip } @@ -148,7 +153,7 @@ object ExpressionCodegen { * Given a child expression, retrieves previously evaluated columns referred by it or * deferred expressions which are needed to evaluate it. */ - def getInputVars(ctx: CodegenContext, child: Expression): Seq[(Expression, ExprCode)] = { + def getInputVars(ctx: CodegenContext, child: Expression): Seq[(ExprProperty, ExprCode)] = { if (ctx.currentVars == null) { return Seq.empty } @@ -157,7 +162,7 @@ object ExpressionCodegen { // An evaluated variable. case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && ctx.currentVars(ordinal).code == "" => - Seq((b, ctx.currentVars(ordinal))) + Seq(((b.dataType, b.nullable), ctx.currentVars(ordinal))) // An input variable which is not evaluated yet. Tracks down to find any evaluated variables // in the expression path. @@ -174,14 +179,14 @@ object ExpressionCodegen { /** * Tracks down previously evaluated columns referred by the generated code snippet. 
*/ - def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[(Expression, ExprCode)] = { + def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[(ExprProperty, ExprCode)] = { val exprCodes = mutable.Queue[ExprCode](exprCode) - val inputVars = mutable.ArrayBuffer.empty[(Expression, ExprCode)] + val inputVars = mutable.ArrayBuffer.empty[(ExprProperty, ExprCode)] while (exprCodes.nonEmpty) { exprCodes.dequeue().inputVars.foreach { inputVar => if (inputVar.exprCode.code == "") { - inputVars += ((inputVar.expr, inputVar.exprCode)) + inputVars += (((inputVar.dataType, inputVar.nullable), inputVar.exprCode)) } else { exprCodes.enqueue(inputVar.exprCode) } @@ -193,11 +198,11 @@ object ExpressionCodegen { /** * Helper function to calculate the size of an expression as function parameter. */ - def calculateParamLength(ctx: CodegenContext, input: Expression): Int = { - ctx.javaType(input.dataType) match { - case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input.nullable => 2 + def calculateParamLength(ctx: CodegenContext, input: ExprProperty): Int = { + ctx.javaType(input._1) match { + case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input._2 => 2 case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 3 - case _ if !input.nullable => 1 + case _ if !input._2 => 1 case _ => 2 } } @@ -207,12 +212,9 @@ object ExpressionCodegen { * length of 255 or less. `this` contributes one unit and a parameter of type long or double * contributes two units. */ - def getParamLength( - ctx: CodegenContext, - inputs: Seq[Expression], - subExprs: Seq[Expression]): Int = { + def getParamLength(ctx: CodegenContext, inputs: Seq[ExprProperty]): Int = { // Initial value is 1 for `this`. - 1 + (inputs ++ subExprs).distinct.map(calculateParamLength(ctx, _)).sum + 1 + inputs.map(calculateParamLength(ctx, _)).sum } /** @@ -222,16 +224,23 @@ object ExpressionCodegen { */ def prepareFunctionParams( ctx: CodegenContext, - inputAttrs: Seq[Expression], + inputAttrs: Seq[ExprProperty], inputVars: Seq[ExprCode]): Seq[(String, String)] = { inputAttrs.zip(inputVars).flatMap { case (input, ev) => - val argType = ctx.javaType(input.dataType) + val params = mutable.ArrayBuffer.empty[(String, String)] - if (!input.nullable || ev.isNull == "true" || ev.isNull == "false") { - Seq((ev.value, s"$argType ${ev.value}")) - } else { - Seq((ev.value, s"$argType ${ev.value}"), (ev.isNull, s"boolean ${ev.isNull}")) + // Only include the expression value if it is not a literal. + if (!ev.isLiteral()) { + val argType = ctx.javaType(input._1) + params += ((ev.value, s"$argType ${ev.value}")) } + + // If it is a nullable expression and `isNull` is not a literal. + if (input._2 && ev.isNull != "true" && ev.isNull != "false") { + params += ((ev.isNull, s"boolean ${ev.isNull}")) + } + + params }.distinct } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprCodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprCodeSuite.scala new file mode 100644 index 0000000000000..ca895352a4964 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprCodeSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite + +class ExprCodeSuite extends SparkFunSuite { + + test("ExprCode.isLiteral: literals") { + val literals = Seq( + ExprCode("", "", "true"), + ExprCode("", "", "false"), + ExprCode("", "", "1"), + ExprCode("", "", "-1"), + ExprCode("", "", "1L"), + ExprCode("", "", "-1L"), + ExprCode("", "", "1.0f"), + ExprCode("", "", "-1.0f"), + ExprCode("", "", "0.1f"), + ExprCode("", "", "-0.1f"), + ExprCode("", "", """"string""""), + ExprCode("", "", "(byte)-1"), + ExprCode("", "", "(short)-1"), + ExprCode("", "", "null")) + + literals.foreach(l => assert(l.isLiteral() == true)) + } + + test("ExprCode.isLiteral: non literals") { + val variables = Seq( + ExprCode("", "", "var1"), + ExprCode("", "", "_var2"), + ExprCode("", "", "$var3"), + ExprCode("", "", "v1a2r3"), + ExprCode("", "", "_1v2a3r"), + ExprCode("", "", "$1v2a3r")) + + variables.foreach(v => assert(v.isLiteral() == false)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala index 606b294f9f8a6..ee037e2d1f3a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala @@ -43,7 +43,8 @@ class ExpressionCodegenSuite extends SparkFunSuite { ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2")) val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs) - val params = ExpressionCodegen.prepareFunctionParams(ctx, subExprs, subExprCodes) + val subAttrs = subExprs.map(e => (e.dataType, e.nullable)) + val params = ExpressionCodegen.prepareFunctionParams(ctx, subAttrs, subExprCodes) assert(params.length == 3) assert(params(0) == Tuple2("value1", "int value1")) assert(params(1) == Tuple2("value2", "int value2")) @@ -53,9 +54,9 @@ class ExpressionCodegenSuite extends SparkFunSuite { test("Returns input variables for expression: current variables") { val ctx = new CodegenContext() val currentVars = Seq( - ExprCode("", isNull = "false", value = "value1"), - ExprCode("", isNull = "isNull2", value = "value2"), - ExprCode("fake code;", isNull = "isNull3", value = "value3")) + ExprCode("", isNull = "false", value = "value1"), // evaluated + ExprCode("", isNull = "isNull2", value = "value2"), // evaluated + ExprCode("fake code;", isNull = "isNull3", value = "value3")) // not evaluated ctx.currentVars = currentVars ctx.INPUT_ROW = null @@ -65,9 +66,10 @@ class ExpressionCodegenSuite extends SparkFunSuite { BoundReference(2, IntegerType, nullable = true)) val (inputAttrs, inputVars) = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + // Only two evaluated variables included. 
assert(inputAttrs.length == 2) - assert(inputAttrs(0) == BoundReference(0, IntegerType, nullable = false)) - assert(inputAttrs(1) == BoundReference(1, IntegerType, nullable = true)) + assert(inputAttrs(0) == Tuple2(IntegerType, false)) + assert(inputAttrs(1) == Tuple2(IntegerType, true)) assert(inputVars.length == 2) assert(inputVars(0) == currentVars(0)) @@ -86,18 +88,17 @@ class ExpressionCodegenSuite extends SparkFunSuite { // The referred column is not evaluated yet. But it depends on an evaluated column from // other operator. val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1")) - val fakeExpr = AttributeReference("a", IntegerType, nullable = true)() // currentVars(0) depends on this evaluated column. - currentVars(0).inputVars = Seq(ExprInputVar(fakeExpr, - ExprCode("", isNull = "isNull2", value = "value2"))) + currentVars(0).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull2", value = "value2"), + dataType = IntegerType, nullable = true)) ctx.currentVars = currentVars ctx.INPUT_ROW = null val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) val (inputAttrs, inputVars) = ExpressionCodegen.getInputVarsForChildren(ctx, expr) assert(inputAttrs.length == 1) - assert(inputAttrs(0) == fakeExpr) + assert(inputAttrs(0) == Tuple2(IntegerType, true)) val params = ExpressionCodegen.prepareFunctionParams(ctx, inputAttrs, inputVars) assert(params.length == 2) From 2f4014fe7de0ae634231a5aae36e7272defa3d9e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 11 Dec 2017 14:53:06 +0000 Subject: [PATCH 18/23] Address comments again. --- .../sql/catalyst/expressions/Expression.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 7 ++- .../codegen/ExpressionCodegen.scala | 55 +++++++++---------- .../codegen/ExpressionCodegenSuite.scala | 28 +++++----- 4 files changed, 47 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 5e4641e94c9a4..329ea5d421509 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -105,6 +105,7 @@ abstract class Expression extends TreeNode[Expression] { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val eval = doGenCode(ctx, ExprCode("", isNull, value)) + eval.isNull = if (this.nullable) eval.isNull else "false" // Records current input row and variables of this expression. eval.inputRow = ctx.INPUT_ROW diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2e1b0e6bddaa5..07296fcd2737a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -81,6 +81,9 @@ case class ExprCode( } } } + + // The code is emptied after evaluation. + def isEvaluated(): Boolean = code == "" } /** @@ -1048,7 +1051,7 @@ class CodegenContext { // Generate the code for this expression tree and wrap it in a function. 
val eval = expr.genCode(this) - val nullValue = if (expr.nullable) { + val assignIsNull = if (expr.nullable) { s"$isNull = ${eval.isNull};" } else { "" @@ -1057,7 +1060,7 @@ class CodegenContext { s""" |private void $fnName(InternalRow $INPUT_ROW) { | ${eval.code.trim} - | $nullValue + | $assignIsNull | $value = ${eval.value}; |} """.stripMargin diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index c6fdafb32b965..50e907e671c05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -27,9 +27,6 @@ import org.apache.spark.sql.types.DataType */ object ExpressionCodegen { - // Type alias for a tuple representing the data type and nullable for an expression. - type ExprProperty = (DataType, Boolean) - /** * Given an expression, returns the all necessary parameters to evaluate it, so the generated * code of this expression can be split in a function. @@ -46,20 +43,22 @@ object ExpressionCodegen { def getExpressionInputParams( ctx: CodegenContext, expr: Expression): Option[(Seq[String], Seq[String])] = { - val (inputAttrs, inputVars) = getInputVarsForChildren(ctx, expr) - val paramsFromColumns = prepareFunctionParams(ctx, inputAttrs, inputVars) - val subExprs = getSubExprInChildren(ctx, expr) val subExprCodes = getSubExprCodes(ctx, subExprs) - val subAttrs = subExprs.map(e => (e.dataType, e.nullable)) - val paramsFromSubExprs = prepareFunctionParams(ctx, subAttrs, subExprCodes) + val subVars = subExprs.zip(subExprCodes).map { case (subExpr, subExprCode) => + ExprInputVar(subExprCode, subExpr.dataType, subExpr.nullable) + } + val paramsFromSubExprs = prepareFunctionParams(ctx, subVars) + + val inputVars = getInputVarsForChildren(ctx, expr) + val paramsFromColumns = prepareFunctionParams(ctx, inputVars) val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => (row, s"InternalRow $row") } - val paramsLength = getParamLength(ctx, inputAttrs ++ subAttrs) + paramsFromRows.length + val paramsLength = getParamLength(ctx, inputVars ++ subVars) + paramsFromRows.length // Maximum allowed parameter number for Java's method descriptor. if (paramsLength > 255) { None @@ -145,15 +144,15 @@ object ExpressionCodegen { */ def getInputVarsForChildren( ctx: CodegenContext, - expr: Expression): (Seq[ExprProperty], Seq[ExprCode]) = { - expr.children.flatMap(getInputVars(ctx, _)).distinct.unzip + expr: Expression): Seq[ExprInputVar] = { + expr.children.flatMap(getInputVars(ctx, _)).distinct } /** * Given a child expression, retrieves previously evaluated columns referred by it or * deferred expressions which are needed to evaluate it. */ - def getInputVars(ctx: CodegenContext, child: Expression): Seq[(ExprProperty, ExprCode)] = { + def getInputVars(ctx: CodegenContext, child: Expression): Seq[ExprInputVar] = { if (ctx.currentVars == null) { return Seq.empty } @@ -161,8 +160,8 @@ object ExpressionCodegen { child.flatMap { // An evaluated variable. 
case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && - ctx.currentVars(ordinal).code == "" => - Seq(((b.dataType, b.nullable), ctx.currentVars(ordinal))) + ctx.currentVars(ordinal).isEvaluated() => + Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable)) // An input variable which is not evaluated yet. Tracks down to find any evaluated variables // in the expression path. @@ -179,14 +178,14 @@ object ExpressionCodegen { /** * Tracks down previously evaluated columns referred by the generated code snippet. */ - def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[(ExprProperty, ExprCode)] = { + def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[ExprInputVar] = { val exprCodes = mutable.Queue[ExprCode](exprCode) - val inputVars = mutable.ArrayBuffer.empty[(ExprProperty, ExprCode)] + val inputVars = mutable.ArrayBuffer.empty[ExprInputVar] while (exprCodes.nonEmpty) { exprCodes.dequeue().inputVars.foreach { inputVar => - if (inputVar.exprCode.code == "") { - inputVars += (((inputVar.dataType, inputVar.nullable), inputVar.exprCode)) + if (inputVar.exprCode.isEvaluated()) { + inputVars += inputVar } else { exprCodes.enqueue(inputVar.exprCode) } @@ -198,11 +197,11 @@ object ExpressionCodegen { /** * Helper function to calculate the size of an expression as function parameter. */ - def calculateParamLength(ctx: CodegenContext, input: ExprProperty): Int = { - ctx.javaType(input._1) match { - case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input._2 => 2 + def calculateParamLength(ctx: CodegenContext, input: ExprInputVar): Int = { + ctx.javaType(input.dataType) match { + case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input.nullable => 2 case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 3 - case _ if !input._2 => 1 + case _ if !input.nullable => 1 case _ => 2 } } @@ -212,7 +211,7 @@ object ExpressionCodegen { * length of 255 or less. `this` contributes one unit and a parameter of type long or double * contributes two units. */ - def getParamLength(ctx: CodegenContext, inputs: Seq[ExprProperty]): Int = { + def getParamLength(ctx: CodegenContext, inputs: Seq[ExprInputVar]): Int = { // Initial value is 1 for `this`. 1 + inputs.map(calculateParamLength(ctx, _)).sum } @@ -224,19 +223,19 @@ object ExpressionCodegen { */ def prepareFunctionParams( ctx: CodegenContext, - inputAttrs: Seq[ExprProperty], - inputVars: Seq[ExprCode]): Seq[(String, String)] = { - inputAttrs.zip(inputVars).flatMap { case (input, ev) => + inputVars: Seq[ExprInputVar]): Seq[(String, String)] = { + inputVars.flatMap { inputVar => val params = mutable.ArrayBuffer.empty[(String, String)] + val ev = inputVar.exprCode // Only include the expression value if it is not a literal. if (!ev.isLiteral()) { - val argType = ctx.javaType(input._1) + val argType = ctx.javaType(inputVar.dataType) params += ((ev.value, s"$argType ${ev.value}")) } // If it is a nullable expression and `isNull` is not a literal. 
- if (input._2 && ev.isNull != "true" && ev.isNull != "false") { + if (inputVar.nullable && ev.isNull != "true" && ev.isNull != "false") { params += ((ev.isNull, s"boolean ${ev.isNull}")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala index ee037e2d1f3a7..f2a08daafdece 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala @@ -43,8 +43,10 @@ class ExpressionCodegenSuite extends SparkFunSuite { ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2")) val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs) - val subAttrs = subExprs.map(e => (e.dataType, e.nullable)) - val params = ExpressionCodegen.prepareFunctionParams(ctx, subAttrs, subExprCodes) + val subVars = subExprs.zip(subExprCodes).map { case (expr, exprCode) => + ExprInputVar(exprCode, expr.dataType, expr.nullable) + } + val params = ExpressionCodegen.prepareFunctionParams(ctx, subVars) assert(params.length == 3) assert(params(0) == Tuple2("value1", "int value1")) assert(params(1) == Tuple2("value2", "int value2")) @@ -65,17 +67,15 @@ class ExpressionCodegenSuite extends SparkFunSuite { BoundReference(1, IntegerType, nullable = true)), BoundReference(2, IntegerType, nullable = true)) - val (inputAttrs, inputVars) = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) // Only two evaluated variables included. - assert(inputAttrs.length == 2) - assert(inputAttrs(0) == Tuple2(IntegerType, false)) - assert(inputAttrs(1) == Tuple2(IntegerType, true)) - assert(inputVars.length == 2) - assert(inputVars(0) == currentVars(0)) - assert(inputVars(1) == currentVars(1)) + assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == false) + assert(inputVars(1).dataType == IntegerType && inputVars(1).nullable == true) + assert(inputVars(0).exprCode == currentVars(0)) + assert(inputVars(1).exprCode == currentVars(1)) - val params = ExpressionCodegen.prepareFunctionParams(ctx, inputAttrs, inputVars) + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) assert(params.length == 3) assert(params(0) == Tuple2("value1", "int value1")) assert(params(1) == Tuple2("value2", "int value2")) @@ -96,11 +96,11 @@ class ExpressionCodegenSuite extends SparkFunSuite { ctx.INPUT_ROW = null val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) - val (inputAttrs, inputVars) = ExpressionCodegen.getInputVarsForChildren(ctx, expr) - assert(inputAttrs.length == 1) - assert(inputAttrs(0) == Tuple2(IntegerType, true)) + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + assert(inputVars.length == 1) + assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == true) - val params = ExpressionCodegen.prepareFunctionParams(ctx, inputAttrs, inputVars) + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) assert(params.length == 2) assert(params(0) == Tuple2("value2", "int value2")) assert(params(1) == Tuple2("isNull2", "boolean isNull2")) From 655917cadf86ab17b8a730f282db544cb348d63f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Dec 2017 00:21:14 +0000 Subject: [PATCH 19/23] Remove redundant 
optimization. --- .../spark/sql/catalyst/expressions/BoundAttribute.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index e347ee4ab387c..6a17a397b3ef2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -61,11 +61,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) { val oev = ctx.currentVars(ordinal) - if (nullable) { - ev.isNull = oev.isNull - } else { - ev.isNull = "false" - } + ev.isNull = oev.isNull ev.value = oev.value ev.copy(code = oev.code) } else { From c083a7955cd6fb54e0448176d9684496fae48e6f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Dec 2017 00:41:37 +0000 Subject: [PATCH 20/23] Use utility method. --- .../sql/catalyst/expressions/codegen/ExpressionCodegen.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index 50e907e671c05..aa9e1ae144be6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -110,7 +110,7 @@ object ExpressionCodegen { Seq(ctx.INPUT_ROW) // An expression which is not evaluated yet. Tracks down to find input rows. - case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal).code != "" => + case BoundReference(ordinal, _, _) if !ctx.currentVars(ordinal).isEvaluated() => trackDownRow(ctx, ctx.currentVars(ordinal)) case _ => Seq.empty @@ -130,7 +130,7 @@ object ExpressionCodegen { inputRows += curExprCode.inputRow } curExprCode.inputVars.foreach { inputVar => - if (inputVar.exprCode.code != "") { + if (!inputVar.exprCode.isEvaluated()) { exprCodes.enqueue(inputVar.exprCode) } } From 1251dfa305f4f1f8e34d7deb235bfa500d057fb4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Dec 2017 07:56:49 +0000 Subject: [PATCH 21/23] Address comments. --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 6 +----- .../catalyst/expressions/codegen/ExpressionCodegen.scala | 8 +++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 07296fcd2737a..517d01b6bb4da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1016,11 +1016,7 @@ class CodegenContext { val expr = e.head // Generate the code for this expression tree. 
val eval = expr.genCode(this) - val state = if (expr.nullable) { - SubExprEliminationState(eval.isNull, eval.value) - } else { - SubExprEliminationState("false", eval.value) - } + val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(subExprEliminationExprs.put(_, state)) eval.code.trim } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index aa9e1ae144be6..a89f60e5a4df3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -198,11 +198,9 @@ object ExpressionCodegen { * Helper function to calculate the size of an expression as function parameter. */ def calculateParamLength(ctx: CodegenContext, input: ExprInputVar): Int = { - ctx.javaType(input.dataType) match { - case (ctx.JAVA_LONG | ctx.JAVA_DOUBLE) if !input.nullable => 2 - case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 3 - case _ if !input.nullable => 1 - case _ => 2 + (if (input.nullable) 1 else 0) + ctx.javaType(input.dataType) match { + case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 2 + case _ => 1 } } From c4f15f79f42350ae62ef7452a880cd4ada9ab275 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Dec 2017 08:09:22 +0000 Subject: [PATCH 22/23] Move isLiteral and isEvaluated into ExpressionCodegen. --- .../codegen/ExpressionCodegen.scala | 36 ++++++++++-- .../expressions/codegen/ExprCodeSuite.scala | 55 ------------------- .../codegen/ExpressionCodegenSuite.scala | 32 +++++++++++ 3 files changed, 63 insertions(+), 60 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprCodeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala index a89f60e5a4df3..a2dda48e951d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -110,7 +110,7 @@ object ExpressionCodegen { Seq(ctx.INPUT_ROW) // An expression which is not evaluated yet. Tracks down to find input rows. - case BoundReference(ordinal, _, _) if !ctx.currentVars(ordinal).isEvaluated() => + case BoundReference(ordinal, _, _) if !isEvaluated(ctx.currentVars(ordinal)) => trackDownRow(ctx, ctx.currentVars(ordinal)) case _ => Seq.empty @@ -130,7 +130,7 @@ object ExpressionCodegen { inputRows += curExprCode.inputRow } curExprCode.inputVars.foreach { inputVar => - if (!inputVar.exprCode.isEvaluated()) { + if (!isEvaluated(inputVar.exprCode)) { exprCodes.enqueue(inputVar.exprCode) } } @@ -160,7 +160,7 @@ object ExpressionCodegen { child.flatMap { // An evaluated variable. case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && - ctx.currentVars(ordinal).isEvaluated() => + isEvaluated(ctx.currentVars(ordinal)) => Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable)) // An input variable which is not evaluated yet. 
From c4f15f79f42350ae62ef7452a880cd4ada9ab275 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 12 Dec 2017 08:09:22 +0000
Subject: [PATCH 22/23] Move isLiteral and isEvaluated into ExpressionCodegen.

---
 .../codegen/ExpressionCodegen.scala         | 36 ++++++++++--
 .../expressions/codegen/ExprCodeSuite.scala | 55 -------------------
 .../codegen/ExpressionCodegenSuite.scala    | 32 +++++++++++
 3 files changed, 63 insertions(+), 60 deletions(-)
 delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprCodeSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala
index a89f60e5a4df3..a2dda48e951d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala
@@ -110,7 +110,7 @@ object ExpressionCodegen {
       Seq(ctx.INPUT_ROW)

     // An expression which is not evaluated yet. Tracks down to find input rows.
-    case BoundReference(ordinal, _, _) if !ctx.currentVars(ordinal).isEvaluated() =>
+    case BoundReference(ordinal, _, _) if !isEvaluated(ctx.currentVars(ordinal)) =>
       trackDownRow(ctx, ctx.currentVars(ordinal))

     case _ => Seq.empty
@@ -130,7 +130,7 @@ object ExpressionCodegen {
         inputRows += curExprCode.inputRow
       }
       curExprCode.inputVars.foreach { inputVar =>
-        if (!inputVar.exprCode.isEvaluated()) {
+        if (!isEvaluated(inputVar.exprCode)) {
           exprCodes.enqueue(inputVar.exprCode)
         }
       }
@@ -160,7 +160,7 @@ object ExpressionCodegen {
     child.flatMap {
       // An evaluated variable.
       case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null &&
-          ctx.currentVars(ordinal).isEvaluated() =>
+          isEvaluated(ctx.currentVars(ordinal)) =>
         Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable))

       // An input variable which is not evaluated yet. Tracks down to find any evaluated variables
@@ -184,7 +184,7 @@ object ExpressionCodegen {

     while (exprCodes.nonEmpty) {
       exprCodes.dequeue().inputVars.foreach { inputVar =>
-        if (inputVar.exprCode.isEvaluated()) {
+        if (isEvaluated(inputVar.exprCode)) {
           inputVars += inputVar
         } else {
           exprCodes.enqueue(inputVar.exprCode)
@@ -227,7 +227,7 @@ object ExpressionCodegen {
       val ev = inputVar.exprCode

       // Only include the expression value if it is not a literal.
-      if (!ev.isLiteral()) {
+      if (!isLiteral(ev)) {
         val argType = ctx.javaType(inputVar.dataType)
         params += ((ev.value, s"$argType ${ev.value}"))
       }
@@ -240,4 +240,30 @@ object ExpressionCodegen {
       params
     }.distinct
   }
+
+  /**
+   * Only applied to the `ExprCode` in `ctx.currentVars`.
+   * Returns true if this value is a literal.
+   */
+  def isLiteral(exprCode: ExprCode): Boolean = {
+    assert(exprCode.value.nonEmpty, "ExprCode.value can't be empty string.")
+
+    if (exprCode.value == "true" || exprCode.value == "false" || exprCode.value == "null") {
+      true
+    } else {
+      // The valid characters for the first character of a Java variable is [a-zA-Z_$].
+      exprCode.value.head match {
+        case v if v >= 'a' && v <= 'z' => false
+        case v if v >= 'A' && v <= 'Z' => false
+        case '_' | '$' => false
+        case _ => true
+      }
+    }
+  }
+
+  /**
+   * Only applied to the `ExprCode` in `ctx.currentVars`.
+   * The code is emptied after evaluation.
+   */
+  def isEvaluated(exprCode: ExprCode): Boolean = exprCode.code == ""
 }
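The isLiteral check added above leans on how codegen names things: a value in ctx.currentVars is either a Java variable name, which starts with [a-zA-Z_$], or an inlined literal such as -1L, (byte)-1 or "string", with true/false/null special-cased. A standalone sketch of the heuristic (not part of the patch; LiteralCheckSketch is a name local to this sketch):

object LiteralCheckSketch {
  def isLiteral(value: String): Boolean = {
    require(value.nonEmpty, "value can't be an empty string")
    if (value == "true" || value == "false" || value == "null") {
      true
    } else {
      // Codegen variable names start with [a-zA-Z_$]; anything else is an inlined literal.
      value.head match {
        case c if c >= 'a' && c <= 'z' => false
        case c if c >= 'A' && c <= 'Z' => false
        case '_' | '$' => false
        case _ => true
      }
    }
  }

  def main(args: Array[String]): Unit = {
    assert(Seq("-1L", "(byte)-1", "null", "\"string\"").forall(isLiteral))
    assert(Seq("value1", "_var2", "$var3").forall(v => !isLiteral(v)))
  }
}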
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprCodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprCodeSuite.scala
deleted file mode 100644
index ca895352a4964..0000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprCodeSuite.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.codegen
-
-import org.apache.spark.SparkFunSuite
-
-class ExprCodeSuite extends SparkFunSuite {
-
-  test("ExprCode.isLiteral: literals") {
-    val literals = Seq(
-      ExprCode("", "", "true"),
-      ExprCode("", "", "false"),
-      ExprCode("", "", "1"),
-      ExprCode("", "", "-1"),
-      ExprCode("", "", "1L"),
-      ExprCode("", "", "-1L"),
-      ExprCode("", "", "1.0f"),
-      ExprCode("", "", "-1.0f"),
-      ExprCode("", "", "0.1f"),
-      ExprCode("", "", "-0.1f"),
-      ExprCode("", "", """"string""""),
-      ExprCode("", "", "(byte)-1"),
-      ExprCode("", "", "(short)-1"),
-      ExprCode("", "", "null"))
-
-    literals.foreach(l => assert(l.isLiteral() == true))
-  }
-
-  test("ExprCode.isLiteral: non literals") {
-    val variables = Seq(
-      ExprCode("", "", "var1"),
-      ExprCode("", "", "_var2"),
-      ExprCode("", "", "$var3"),
-      ExprCode("", "", "v1a2r3"),
-      ExprCode("", "", "_1v2a3r"),
-      ExprCode("", "", "$1v2a3r"))
-
-    variables.foreach(v => assert(v.isLiteral() == false))
-  }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala
index f2a08daafdece..f139c740f798a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala
@@ -133,4 +133,36 @@ class ExpressionCodegenSuite extends SparkFunSuite {
     assert(inputRows.length == 1)
     assert(inputRows(0) == "inputadaptor_row1")
   }
+
+  test("isLiteral: literals") {
+    val literals = Seq(
+      ExprCode("", "", "true"),
+      ExprCode("", "", "false"),
+      ExprCode("", "", "1"),
+      ExprCode("", "", "-1"),
+      ExprCode("", "", "1L"),
+      ExprCode("", "", "-1L"),
+      ExprCode("", "", "1.0f"),
+      ExprCode("", "", "-1.0f"),
+      ExprCode("", "", "0.1f"),
+      ExprCode("", "", "-0.1f"),
+      ExprCode("", "", """"string""""),
+      ExprCode("", "", "(byte)-1"),
+      ExprCode("", "", "(short)-1"),
+      ExprCode("", "", "null"))
+
+    literals.foreach(l => assert(ExpressionCodegen.isLiteral(l) == true))
+  }
+
+  test("isLiteral: non literals") {
+    val variables = Seq(
+      ExprCode("", "", "var1"),
+      ExprCode("", "", "_var2"),
+      ExprCode("", "", "$var3"),
+      ExprCode("", "", "v1a2r3"),
+      ExprCode("", "", "_1v2a3r"),
+      ExprCode("", "", "$1v2a3r"))
+
+    variables.foreach(v => assert(ExpressionCodegen.isLiteral(v) == false))
+  }
 }
From f35974e1dfb47387dc952d30a55eee0354bdea63 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 12 Dec 2017 14:07:22 +0000
Subject: [PATCH 23/23] Remove useless isLiteral and isEvaluted. Add one more test.

---
 .../expressions/codegen/CodeGenerator.scala | 23 +-------
 .../codegen/ExpressionCodegenSuite.scala    | 52 +++++++++++++++++++
 2 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 517d01b6bb4da..9ee17774b7424 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -63,28 +63,7 @@ case class ExprCode(
     var isNull: String,
     var value: String,
     var inputRow: String = null,
-    var inputVars: Seq[ExprInputVar] = Seq.empty) {
-
-  // Returns true if this value is a literal.
-  def isLiteral(): Boolean = {
-    assert(value.nonEmpty, "ExprCode.value can't be empty string.")
-
-    if (value == "true" || value == "false" || value == "null") {
-      true
-    } else {
-      // The valid characters for the first character of a Java variable is [a-zA-Z_$].
-      value.head match {
-        case v if v >= 'a' && v <= 'z' => false
-        case v if v >= 'A' && v <= 'Z' => false
-        case '_' | '$' => false
-        case _ => true
-      }
-    }
-  }
-
-  // The code is emptied after evaluation.
-  def isEvaluated(): Boolean = code == ""
-}
+    var inputVars: Seq[ExprInputVar] = Seq.empty)

 /**
  * Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]].
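The new test in the diff below asserts how collected input variables become parameters of the split function: every non-literal value becomes a typed parameter, and a nullable value whose isNull is a real variable (not an inlined "true"/"false") also carries a boolean isNull parameter. A standalone sketch of that expectation (not part of the patch; ParamSketch and its tuple encoding are assumptions of this sketch):

object ParamSketch {
  // (java type, value name, isNull expression, nullable) =>
  // (name used at the call site, declaration in the split function's signature)
  def prepareParams(vars: Seq[(String, String, String, Boolean)]): Seq[(String, String)] =
    vars.flatMap { case (javaType, value, isNull, nullable) =>
      val valueParam = Seq((value, s"$javaType $value"))
      val isNullParam =
        if (nullable && isNull != "true" && isNull != "false") {
          Seq((isNull, s"boolean $isNull"))
        } else {
          Seq.empty
        }
      valueParam ++ isNullParam
    }

  def main(args: Array[String]): Unit = {
    val params = prepareParams(Seq(
      ("int", "value1", "false", false),    // isNull is an inlined constant, no extra parameter
      ("int", "value5", "isNull5", true)))  // nullable, so isNull5 travels with value5
    assert(params == Seq(
      ("value1", "int value1"),
      ("value5", "int value5"),
      ("isNull5", "boolean isNull5")))
  }
}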
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala
index f139c740f798a..39d58cabff228 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala
@@ -134,6 +134,58 @@ class ExpressionCodegenSuite extends SparkFunSuite {
     assert(inputRows(0) == "inputadaptor_row1")
   }

+  test("Returns both input rows and variables for expression") {
+    val ctx = new CodegenContext()
+    // 5 input variables in currentVars:
+    // 1 evaluated variable (value1).
+    // 3 not evaluated variables.
+    // value2 depends on an evaluated column from other operator.
+    // value3 depends on an input row from other operator.
+    // value4 depends on a not evaluated yet column from other operator.
+    // 1 null indicating to use input row "i".
+    val currentVars = Seq(
+      ExprCode("", isNull = "false", value = "value1"),
+      ExprCode("fake code;", isNull = "isNull2", value = "value2"),
+      ExprCode("fake code;", isNull = "isNull3", value = "value3"),
+      ExprCode("fake code;", isNull = "isNull4", value = "value4"),
+      null)
+    // value2 depends on this evaluated column.
+    currentVars(1).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull5", value = "value5"),
+      dataType = IntegerType, nullable = true))
+    // value3 depends on an input row "inputadaptor_row1".
+    currentVars(2).inputRow = "inputadaptor_row1"
+    // value4 depends on another not evaluated yet column.
+    currentVars(3).inputVars = Seq(ExprInputVar(ExprCode("fake code;",
+      isNull = "isNull6", value = "value6"), dataType = IntegerType, nullable = true))
+    ctx.currentVars = currentVars
+    ctx.INPUT_ROW = "i"
+
+    // expr: if (false) { value1 + value2 } else { (value3 + value4) + i[5] }
+    val expr = If(Literal(false),
+      Add(BoundReference(0, IntegerType, nullable = false),
+        BoundReference(1, IntegerType, nullable = true)),
+      Add(Add(BoundReference(2, IntegerType, nullable = true),
+        BoundReference(3, IntegerType, nullable = true)),
+        BoundReference(4, IntegerType, nullable = true))) // this is based on input row "i".
+
+    // input rows: "i", "inputadaptor_row1".
+    val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
+    assert(inputRows.length == 2)
+    assert(inputRows(0) == "inputadaptor_row1")
+    assert(inputRows(1) == "i")
+
+    // input variables: value1 and value5
+    val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
+    assert(inputVars.length == 2)
+
+    // value1 has inlined isNull "false", so don't need to include it in the params.
+    val inputVarParams = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
+    assert(inputVarParams.length == 3)
+    assert(inputVarParams(0) == Tuple2("value1", "int value1"))
+    assert(inputVarParams(1) == Tuple2("value5", "int value5"))
+    assert(inputVarParams(2) == Tuple2("isNull5", "boolean isNull5"))
+  }
+
   test("isLiteral: literals") {
     val literals = Seq(
       ExprCode("", "", "true"),
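The test above exercises the dependency walk used when gathering inputs: a variable that is not evaluated yet contributes the input rows and evaluated variables it transitively depends on, which is why value3 pulls in "inputadaptor_row1" while the null slot falls back to the input row "i". A standalone sketch of such a walk (not part of the patch; Var and DependencyWalkSketch are names local to this sketch):

import scala.collection.mutable

// A pending variable (code not yet emitted) drags in the rows its dependencies still need;
// an already evaluated variable needs no row at all.
case class Var(
    code: String,
    value: String,
    deps: Seq[Var] = Seq.empty,
    inputRow: String = null)

object DependencyWalkSketch {
  def isEvaluated(v: Var): Boolean = v.code == ""

  // Breadth-first walk collecting the input rows a pending variable still depends on.
  def trackDownRows(start: Var): Seq[String] = {
    val rows = mutable.ArrayBuffer.empty[String]
    val queue = mutable.Queue(start)
    while (queue.nonEmpty) {
      val cur = queue.dequeue()
      if (cur.inputRow != null) rows += cur.inputRow
      cur.deps.filterNot(isEvaluated).foreach(queue.enqueue(_))
    }
    rows.toSeq
  }

  def main(args: Array[String]): Unit = {
    val value5 = Var(code = "", value = "value5")                            // already evaluated
    val value2 = Var("fake code;", "value2", deps = Seq(value5))             // needs only value5
    val value3 = Var("fake code;", "value3", inputRow = "inputadaptor_row1")
    assert(trackDownRows(value2).isEmpty)
    assert(trackDownRows(value3) == Seq("inputadaptor_row1"))
  }
}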