diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 743782a6453e..329ea5d42150 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -105,6 +105,12 @@ abstract class Expression extends TreeNode[Expression] { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val eval = doGenCode(ctx, ExprCode("", isNull, value)) + eval.isNull = if (this.nullable) eval.isNull else "false" + + // Records current input row and variables of this expression. + eval.inputRow = ctx.INPUT_ROW + eval.inputVars = findInputVars(ctx, eval) + reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -115,9 +121,29 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Returns the input variables to this expression. + */ + private def findInputVars(ctx: CodegenContext, eval: ExprCode): Seq[ExprInputVar] = { + if (ctx.currentVars != null) { + this.collect { + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + ExprInputVar(exprCode = ctx.currentVars(ordinal), + dataType = b.dataType, nullable = b.nullable) + } + } else { + Seq.empty + } + } + + /** + * In order to prevent 64kb compile error, reducing the size of generated codes by + * separating it into a function if the size exceeds a threshold. 
+ */ private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { - // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + lazy val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this) + + if (eval.code.trim.length > 1024 && funcParams.isDefined) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { val globalIsNull = ctx.freshName("globalIsNull") ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull) @@ -132,9 +158,12 @@ abstract class Expression extends TreeNode[Expression] { val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) + val callParams = funcParams.map(_._1.mkString(", ")).get + val declParams = funcParams.map(_._2.mkString(", ")).get + val funcFullName = ctx.addNewFunction(funcName, s""" - |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { + |private $javaType $funcName($declParams) { | ${eval.code.trim} | $setIsNull | return ${eval.value}; @@ -142,7 +171,7 @@ abstract class Expression extends TreeNode[Expression] { """.stripMargin) eval.value = newValue - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = s"$javaType $newValue = $funcFullName($callParams);" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0498e61819f4..9ee17774b742 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,8 +55,24 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * to null. * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. 
+ * @param inputRow A term that holds the input row name when generating this code. + * @param inputVars A list of [[ExprInputVar]] that holds input variables when generating this code. */ -case class ExprCode(var code: String, var isNull: String, var value: String) +case class ExprCode( + var code: String, + var isNull: String, + var value: String, + var inputRow: String = null, + var inputVars: Seq[ExprInputVar] = Seq.empty) + +/** + * Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]]. + * + * @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable. + * @param dataType The data type of the input variable. + * @param nullable Whether the input variable can be null or not. + */ +case class ExprInputVar(exprCode: ExprCode, dataType: DataType, nullable: Boolean) /** * State used for subexpression elimination. @@ -1001,16 +1017,25 @@ class CodegenContext { commonExprs.foreach { e => val expr = e.head val fnName = freshName("evalExpr") - val isNull = s"${fnName}IsNull" + val isNull = if (expr.nullable) { + s"${fnName}IsNull" + } else { + "" + } val value = s"${fnName}Value" // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) + val assignIsNull = if (expr.nullable) { + s"$isNull = ${eval.isNull};" + } else { + "" + } val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { | ${eval.code.trim} - | $isNull = ${eval.isNull}; + | $assignIsNull | $value = ${eval.value}; |} """.stripMargin @@ -1028,12 +1053,17 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. 
- addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;") - addMutableState(javaType(expr.dataType), value, - s"$value = ${defaultValue(expr.dataType)};") + if (expr.nullable) { + addMutableState(JAVA_BOOLEAN, isNull) + } + addMutableState(javaType(expr.dataType), value) subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + val state = if (expr.nullable) { + SubExprEliminationState(isNull, value) + } else { + SubExprEliminationState("false", value) + } e.foreach(subExprEliminationExprs.put(_, state)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala new file mode 100644 index 000000000000..a2dda48e951d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DataType + +/** + * Defines util methods used in expression code generation. + */ +object ExpressionCodegen { + + /** + * Given an expression, returns all the necessary parameters to evaluate it, so the generated + * code of this expression can be split in a function. + * The 1st element of the returned tuple is the argument names used to call the function. + * The 2nd element of the returned tuple is the parameter declarations of the function. + * + * Returns `None` if it can't produce valid parameters. + * + * Params to include: + * 1. Evaluated columns referred by this, children or deferred expressions. + * 2. Rows referred by this, children or deferred expressions. + * 3. Eliminated subexpressions referred by children expressions. + */ + def getExpressionInputParams( + ctx: CodegenContext, + expr: Expression): Option[(Seq[String], Seq[String])] = { + val subExprs = getSubExprInChildren(ctx, expr) + val subExprCodes = getSubExprCodes(ctx, subExprs) + val subVars = subExprs.zip(subExprCodes).map { case (subExpr, subExprCode) => + ExprInputVar(subExprCode, subExpr.dataType, subExpr.nullable) + } + val paramsFromSubExprs = prepareFunctionParams(ctx, subVars) + + val inputVars = getInputVarsForChildren(ctx, expr) + val paramsFromColumns = prepareFunctionParams(ctx, inputVars) + + val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) + val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => + (row, s"InternalRow $row") + } + + val paramsLength = getParamLength(ctx, inputVars ++ subVars) + paramsFromRows.length + // Maximum allowed parameter number for Java's method descriptor. 
+ if (paramsLength > 255) { + None + } else { + val allParams = (paramsFromRows ++ paramsFromColumns ++ paramsFromSubExprs).unzip + val callParams = allParams._1.distinct + val declParams = allParams._2.distinct + Some((callParams, declParams)) + } + } + + /** + * Returns the eliminated subexpressions in the children expressions. + */ + def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = { + expr.children.flatMap { child => + child.collect { + case e if ctx.subExprEliminationExprs.contains(e) => e + } + }.distinct + } + + /** + * A small helper function to return `ExprCode`s that represent subexpressions. + */ + def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = { + subExprs.map { subExpr => + val state = ctx.subExprEliminationExprs(subExpr) + ExprCode(code = "", value = state.value, isNull = state.isNull) + } + } + + /** + * Retrieves previous input rows referred by children and deferred expressions. + */ + def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = { + expr.children.flatMap(getInputRows(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previous input rows referred by it or deferred expressions + * which are needed to evaluate it. + */ + def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = { + child.flatMap { + // An expression directly evaluates on current input row. + case BoundReference(ordinal, _, _) if ctx.currentVars == null || + ctx.currentVars(ordinal) == null => + Seq(ctx.INPUT_ROW) + + // An expression which is not evaluated yet. Tracks down to find input rows. + case BoundReference(ordinal, _, _) if !isEvaluated(ctx.currentVars(ordinal)) => + trackDownRow(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down input rows referred by the generated code snippet. 
+ */ + def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { + val exprCodes = mutable.Queue[ExprCode](exprCode) + val inputRows = mutable.ArrayBuffer.empty[String] + + while (exprCodes.nonEmpty) { + val curExprCode = exprCodes.dequeue() + if (curExprCode.inputRow != null) { + inputRows += curExprCode.inputRow + } + curExprCode.inputVars.foreach { inputVar => + if (!isEvaluated(inputVar.exprCode)) { + exprCodes.enqueue(inputVar.exprCode) + } + } + } + inputRows + } + + /** + * Retrieves previously evaluated columns referred by children and deferred expressions. + * Each returned [[ExprInputVar]] carries the column's code, data type and nullability. + */ + def getInputVarsForChildren( + ctx: CodegenContext, + expr: Expression): Seq[ExprInputVar] = { + expr.children.flatMap(getInputVars(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previously evaluated columns referred by it or + * deferred expressions which are needed to evaluate it. + */ + def getInputVars(ctx: CodegenContext, child: Expression): Seq[ExprInputVar] = { + if (ctx.currentVars == null) { + return Seq.empty + } + + child.flatMap { + // An evaluated variable. + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && + isEvaluated(ctx.currentVars(ordinal)) => + Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable)) + + // An input variable which is not evaluated yet. Tracks down to find any evaluated variables + // in the expression path. + // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to + // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so + // to include them into parameters, if not, we track down further. 
+ case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + trackDownVar(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down previously evaluated columns referred by the generated code snippet. + */ + def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[ExprInputVar] = { + val exprCodes = mutable.Queue[ExprCode](exprCode) + val inputVars = mutable.ArrayBuffer.empty[ExprInputVar] + + while (exprCodes.nonEmpty) { + exprCodes.dequeue().inputVars.foreach { inputVar => + if (isEvaluated(inputVar.exprCode)) { + inputVars += inputVar + } else { + exprCodes.enqueue(inputVar.exprCode) + } + } + } + inputVars + } + + /** + * Counts parameter units for an input: 2 for a long/double value (else 1), plus 1 if nullable. + */ + def calculateParamLength(ctx: CodegenContext, input: ExprInputVar): Int = { + (if (input.nullable) 1 else 0) + (ctx.javaType(input.dataType) match { + case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 2 + case _ => 1 + }) + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length of 255 or less. `this` contributes one unit and a parameter of type long or double + * contributes two units. + */ + def getParamLength(ctx: CodegenContext, inputs: Seq[ExprInputVar]): Int = { + // Initial value is 1 for `this`. + 1 + inputs.map(calculateParamLength(ctx, _)).sum + } + + /** + * Given the lists of input attributes and variables to this expression, returns the strings of + * function parameters. The first is the variable names used to call the function, the second is + * the parameters used to declare the function in generated code. + */ + def prepareFunctionParams( + ctx: CodegenContext, + inputVars: Seq[ExprInputVar]): Seq[(String, String)] = { + inputVars.flatMap { inputVar => + val params = mutable.ArrayBuffer.empty[(String, String)] + val ev = inputVar.exprCode + + // Only include the expression value if it is not a literal. 
+ if (!isLiteral(ev)) { + val argType = ctx.javaType(inputVar.dataType) + params += ((ev.value, s"$argType ${ev.value}")) + } + + // If it is a nullable expression and `isNull` is not a literal. + if (inputVar.nullable && ev.isNull != "true" && ev.isNull != "false") { + params += ((ev.isNull, s"boolean ${ev.isNull}")) + } + + params + }.distinct + } + + /** + * Only applied to the `ExprCode` in `ctx.currentVars`. + * Returns true if this value is a literal. + */ + def isLiteral(exprCode: ExprCode): Boolean = { + assert(exprCode.value.nonEmpty, "ExprCode.value can't be empty string.") + + if (exprCode.value == "true" || exprCode.value == "false" || exprCode.value == "null") { + true + } else { + // The valid characters for the first character of a Java variable is [a-zA-Z_$]. + exprCode.value.head match { + case v if v >= 'a' && v <= 'z' => false + case v if v >= 'A' && v <= 'Z' => false + case '_' | '$' => false + case _ => true + } + } + } + + /** + * Only applied to the `ExprCode` in `ctx.currentVars`. + * The code is emptied after evaluation. + */ + def isEvaluated(exprCode: ExprCode): Boolean = exprCode.code == "" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala new file mode 100644 index 000000000000..39d58cabff22 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.IntegerType + +class ExpressionCodegenSuite extends SparkFunSuite { + + test("Returns eliminated subexpressions for expression") { + val ctx = new CodegenContext() + val subExpr = Add(Literal(1), Literal(2)) + val exprs = Seq(Add(subExpr, Literal(3)), Add(subExpr, Literal(4))) + + ctx.generateExpressions(exprs, doSubexpressionElimination = true) + val subexpressions = ExpressionCodegen.getSubExprInChildren(ctx, exprs(0)) + assert(subexpressions.length == 1 && subexpressions(0) == subExpr) + } + + test("Gets parameters for subexpressions") { + val ctx = new CodegenContext() + val subExprs = Seq( + Add(Literal(1), AttributeReference("a", IntegerType, nullable = false)()), // non-nullable + Add(Literal(2), AttributeReference("b", IntegerType, nullable = true)())) // nullable + + ctx.subExprEliminationExprs.put(subExprs(0), SubExprEliminationState("false", "value1")) + ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2")) + + val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs) + val subVars = subExprs.zip(subExprCodes).map { case (expr, exprCode) => + ExprInputVar(exprCode, expr.dataType, expr.nullable) + } + val params = ExpressionCodegen.prepareFunctionParams(ctx, subVars) + assert(params.length == 3) + assert(params(0) == Tuple2("value1", "int value1")) + assert(params(1) == Tuple2("value2", "int value2")) + 
assert(params(2) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input variables for expression: current variables") { + val ctx = new CodegenContext() + val currentVars = Seq( + ExprCode("", isNull = "false", value = "value1"), // evaluated + ExprCode("", isNull = "isNull2", value = "value2"), // evaluated + ExprCode("fake code;", isNull = "isNull3", value = "value3")) // not evaluated + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = If(Literal(false), + Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)), + BoundReference(2, IntegerType, nullable = true)) + + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + // Only two evaluated variables included. + assert(inputVars.length == 2) + assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == false) + assert(inputVars(1).dataType == IntegerType && inputVars(1).nullable == true) + assert(inputVars(0).exprCode == currentVars(0)) + assert(inputVars(1).exprCode == currentVars(1)) + + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) + assert(params.length == 3) + assert(params(0) == Tuple2("value1", "int value1")) + assert(params(1) == Tuple2("value2", "int value2")) + assert(params(2) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input variables for expression: deferred variables") { + val ctx = new CodegenContext() + + // The referred column is not evaluated yet. But it depends on an evaluated column from + // other operator. + val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1")) + + // currentVars(0) depends on this evaluated column. 
+ currentVars(0).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull2", value = "value2"), + dataType = IntegerType, nullable = true)) + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + assert(inputVars.length == 1) + assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == true) + + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) + assert(params.length == 2) + assert(params(0) == Tuple2("value2", "int value2")) + assert(params(1) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input rows for expression") { + val ctx = new CodegenContext() + ctx.currentVars = null + ctx.INPUT_ROW = "i" + + val expr = Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)) + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 1) + assert(inputRows(0) == "i") + } + + test("Returns input rows for expression: deferred expression") { + val ctx = new CodegenContext() + + // The referred column is not evaluated yet. But it depends on an input row from + // other operator. + val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1")) + currentVars(0).inputRow = "inputadaptor_row1" + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 1) + assert(inputRows(0) == "inputadaptor_row1") + } + + test("Returns both input rows and variables for expression") { + val ctx = new CodegenContext() + // 5 input variables in currentVars: + // 1 evaluated variable (value1). + // 3 not evaluated variables. + // value2 depends on an evaluated column from other operator. 
+ // value3 depends on an input row from other operator. + // value4 depends on a not evaluated yet column from other operator. + // 1 null indicating to use input row "i". + val currentVars = Seq( + ExprCode("", isNull = "false", value = "value1"), + ExprCode("fake code;", isNull = "isNull2", value = "value2"), + ExprCode("fake code;", isNull = "isNull3", value = "value3"), + ExprCode("fake code;", isNull = "isNull4", value = "value4"), + null) + // value2 depends on this evaluated column. + currentVars(1).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull5", value = "value5"), + dataType = IntegerType, nullable = true)) + // value3 depends on an input row "inputadaptor_row1". + currentVars(2).inputRow = "inputadaptor_row1" + // value4 depends on another not evaluated yet column. + currentVars(3).inputVars = Seq(ExprInputVar(ExprCode("fake code;", + isNull = "isNull6", value = "value6"), dataType = IntegerType, nullable = true)) + ctx.currentVars = currentVars + ctx.INPUT_ROW = "i" + + // expr: if (false) { value1 + value2 } else { (value3 + value4) + i[5] } + val expr = If(Literal(false), + Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)), + Add(Add(BoundReference(2, IntegerType, nullable = true), + BoundReference(3, IntegerType, nullable = true)), + BoundReference(4, IntegerType, nullable = true))) // this is based on input row "i". + + // input rows: "i", "inputadaptor_row1". + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 2) + assert(inputRows(0) == "inputadaptor_row1") + assert(inputRows(1) == "i") + + // input variables: value1 and value5 + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + assert(inputVars.length == 2) + + // value1 has inlined isNull "false", so don't need to include it in the params. 
+ val inputVarParams = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) + assert(inputVarParams.length == 3) + assert(inputVarParams(0) == Tuple2("value1", "int value1")) + assert(inputVarParams(1) == Tuple2("value5", "int value5")) + assert(inputVarParams(2) == Tuple2("isNull5", "boolean isNull5")) + } + + test("isLiteral: literals") { + val literals = Seq( + ExprCode("", "", "true"), + ExprCode("", "", "false"), + ExprCode("", "", "1"), + ExprCode("", "", "-1"), + ExprCode("", "", "1L"), + ExprCode("", "", "-1L"), + ExprCode("", "", "1.0f"), + ExprCode("", "", "-1.0f"), + ExprCode("", "", "0.1f"), + ExprCode("", "", "-0.1f"), + ExprCode("", "", """"string""""), + ExprCode("", "", "(byte)-1"), + ExprCode("", "", "(short)-1"), + ExprCode("", "", "null")) + + literals.foreach(l => assert(ExpressionCodegen.isLiteral(l) == true)) + } + + test("isLiteral: non literals") { + val variables = Seq( + ExprCode("", "", "var1"), + ExprCode("", "", "_var2"), + ExprCode("", "", "$var3"), + ExprCode("", "", "v1a2r3"), + ExprCode("", "", "_1v2a3r"), + ExprCode("", "", "$1v2a3r")) + + variables.foreach(v => assert(ExpressionCodegen.isLiteral(v) == false)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index a9bfb634fbde..05186c447256 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -108,7 +108,10 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { |}""".stripMargin) ctx.currentVars = null + // `rowIdx` isn't in `ctx.currentVars`. If the expressions are split later, we can't track it. + // So making it as global variable. 
val rowidx = ctx.freshName("rowIdx") + ctx.addMutableState(ctx.JAVA_INT, rowidx) val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } @@ -128,7 +131,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | int $numRows = $batch.numRows(); | int $localEnd = $numRows - $idx; | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | int $rowidx = $idx + $localIdx; + | $rowidx = $idx + $localIdx; | ${consume(ctx, columnsBatchInput).trim} | $shouldStop | } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c4..1281169b607c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.{Column, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec @@ -236,4 +237,24 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-22551: Fix 64kb limit for deeply nested expressions under wholestage codegen") { + import testImplicits._ + withTempPath { dir => + val path = dir.getCanonicalPath + val df = Seq(("abc", 1)).toDF("key", "int") + df.write.parquet(path) + + var strExpr: Expression = col("key").expr + for (_ <- 1 to 150) { + strExpr = Decode(Encode(strExpr, Literal("utf-8")), Literal("utf-8")) + } + val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)) + + 
val df2 = spark.read.parquet(path).select(expressions.map(Column(_)): _*) + val plan = df2.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) + df2.collect() + } + } }