diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 97dff6ae8829..724decd240fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,9 +104,16 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") + val eval = doGenCode(ctx, ExprCode( JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) + eval.isNull = if (this.nullable) eval.isNull else FalseLiteral + + // Records current input row and variables of this expression. + eval.inputRow = ctx.INPUT_ROW + eval.inputVars = findInputVars(ctx, eval) + reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -117,9 +124,29 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Returns the input variables to this expression. + */ + private def findInputVars(ctx: CodegenContext, eval: ExprCode): Seq[ExprInputVar] = { + if (ctx.currentVars != null) { + this.collect { + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + ExprInputVar(exprCode = ctx.currentVars(ordinal), + dataType = b.dataType, nullable = b.nullable) + } + } else { + Seq.empty + } + } + + /** + * In order to prevent 64kb compile error, reducing the size of generated codes by + * separating it into a function if the size exceeds a threshold. + */ private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { - // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + lazy val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this) + + if (eval.code.trim.length > 1024 && funcParams.isDefined) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull @@ -133,9 +160,12 @@ abstract class Expression extends TreeNode[Expression] { val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) + val callParams = funcParams.map(_._1.mkString(", ")).get + val declParams = funcParams.map(_._2.mkString(", ")).get + val funcFullName = ctx.addNewFunction(funcName, s""" - |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { + |private $javaType $funcName($declParams) { | ${eval.code.trim} | $setIsNull | return ${eval.value}; @@ -143,7 +173,7 @@ abstract class Expression extends TreeNode[Expression] { """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = s"$javaType $newValue = $funcFullName($callParams);" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f6b6775923ac..65f96a03c177 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,8 +55,15 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * to null. 
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. + * @param inputRow A term that holds the input row name when generating this code. + * @param inputVars A list of [[ExprInputVar]] that holds input variables when generating this code. */ -case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) +case class ExprCode( + var code: String, + var isNull: ExprValue, + var value: ExprValue, + var inputRow: String = null, + var inputVars: Seq[ExprInputVar] = Seq.empty) object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { @@ -72,6 +79,15 @@ object ExprCode { } } +/** + * Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]]. + * + * @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable. + * @param dataType The data type of the input variable. + * @param nullable Whether the input variable can be null or not. + */ +case class ExprInputVar(exprCode: ExprCode, dataType: DataType, nullable: Boolean) + /** * State used for subexpression elimination. * @@ -1006,16 +1022,25 @@ class CodegenContext { commonExprs.foreach { e => val expr = e.head val fnName = freshName("subExpr") - val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + val isNull = if (expr.nullable) { + addMutableState(JAVA_BOOLEAN, "subExprIsNull") + } else { + "" + } val value = addMutableState(javaType(expr.dataType), "subExprValue") // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) + val assignIsNull = if (expr.nullable) { + s"$isNull = ${eval.isNull};" + } else { + "" + } val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { | ${eval.code.trim} - | $isNull = ${eval.isNull}; + | $assignIsNull | $value = ${eval.value}; |} """.stripMargin @@ -1035,9 +1060,15 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState( - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, expr.dataType)) + val state = if (expr.nullable) { + SubExprEliminationState( + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, expr.dataType)) + } else { + SubExprEliminationState( + FalseLiteral, + JavaCode.global(value, expr.dataType)) + } subExprEliminationExprs ++= e.map(_ -> state).toMap } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala new file mode 100644 index 000000000000..32c33fffd880 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegen.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DataType + +/** + * Defines util methods used in expression code generation. + */ +object ExpressionCodegen { + + /** + * Given an expression, returns the all necessary parameters to evaluate it, so the generated + * code of this expression can be split in a function. + * The 1st string in returned tuple is the parameter strings used to call the function. + * The 2nd string in returned tuple is the parameter strings used to declare the function. + * + * Returns `None` if it can't produce valid parameters. + * + * Params to include: + * 1. Evaluated columns referred by this, children or deferred expressions. + * 2. Rows referred by this, children or deferred expressions. + * 3. Eliminated subexpressions referred by children expressions. + */ + def getExpressionInputParams( + ctx: CodegenContext, + expr: Expression): Option[(Seq[String], Seq[String])] = { + val subExprs = getSubExprInChildren(ctx, expr) + val subExprCodes = getSubExprCodes(ctx, subExprs) + val subVars = subExprs.zip(subExprCodes).map { case (subExpr, subExprCode) => + ExprInputVar(subExprCode, subExpr.dataType, subExpr.nullable) + } + val paramsFromSubExprs = prepareFunctionParams(ctx, subVars) + + val inputVars = getInputVarsForChildren(ctx, expr) + val paramsFromColumns = prepareFunctionParams(ctx, inputVars) + + val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr) + val paramsFromRows = inputRows.distinct.filter(_ != null).map { row => + (row, s"InternalRow $row") + } + + val paramsLength = getParamLength(ctx, inputVars ++ subVars) + paramsFromRows.length + // Maximum allowed parameter number for Java's method descriptor. + if (paramsLength > 255) { + None + } else { + val allParams = (paramsFromRows ++ paramsFromColumns ++ paramsFromSubExprs).unzip + val callParams = allParams._1.distinct + val declParams = allParams._2.distinct + Some((callParams, declParams)) + } + } + + /** + * Returns the eliminated subexpressions in the children expressions. + */ + def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = { + expr.children.flatMap { child => + child.collect { + case e if ctx.subExprEliminationExprs.contains(e) => e + } + }.distinct + } + + /** + * A small helper function to return `ExprCode`s that represent subexpressions. + */ + def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = { + subExprs.map { subExpr => + val state = ctx.subExprEliminationExprs(subExpr) + ExprCode(code = "", value = state.value, isNull = state.isNull) + } + } + + /** + * Retrieves previous input rows referred by children and deferred expressions. + */ + def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = { + expr.children.flatMap(getInputRows(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previous input rows referred by it or deferred expressions + * which are needed to evaluate it. 
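+ *
+ * For example, a `BoundReference` whose ordinal has no generated variable in `ctx.currentVars`
+ * reads from the current `ctx.INPUT_ROW`, while a deferred (not yet evaluated) variable
+ * contributes the rows recorded in its `ExprCode.inputRow` chain via `trackDownRow`.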
+ */ + def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = { + child.flatMap { + // An expression directly evaluates on current input row. + case BoundReference(ordinal, _, _) if ctx.currentVars == null || + ctx.currentVars(ordinal) == null => + Seq(ctx.INPUT_ROW) + + // An expression which is not evaluated yet. Tracks down to find input rows. + case BoundReference(ordinal, _, _) if !isEvaluated(ctx.currentVars(ordinal)) => + trackDownRow(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down input rows referred by the generated code snippet. + */ + def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = { + val exprCodes = mutable.Queue[ExprCode](exprCode) + val inputRows = mutable.ArrayBuffer.empty[String] + + while (exprCodes.nonEmpty) { + val curExprCode = exprCodes.dequeue() + if (curExprCode.inputRow != null) { + inputRows += curExprCode.inputRow + } + curExprCode.inputVars.foreach { inputVar => + if (!isEvaluated(inputVar.exprCode)) { + exprCodes.enqueue(inputVar.exprCode) + } + } + } + inputRows + } + + /** + * Retrieves previously evaluated columns referred by children and deferred expressions. + * Returned tuple contains the list of expressions and the list of generated codes. + */ + def getInputVarsForChildren( + ctx: CodegenContext, + expr: Expression): Seq[ExprInputVar] = { + expr.children.flatMap(getInputVars(ctx, _)).distinct + } + + /** + * Given a child expression, retrieves previously evaluated columns referred by it or + * deferred expressions which are needed to evaluate it. + */ + def getInputVars(ctx: CodegenContext, child: Expression): Seq[ExprInputVar] = { + if (ctx.currentVars == null) { + return Seq.empty + } + + child.flatMap { + // An evaluated variable. + case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null && + isEvaluated(ctx.currentVars(ordinal)) => + Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable)) + + // An input variable which is not evaluated yet. Tracks down to find any evaluated variables + // in the expression path. + // E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to + // "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so + // to include them into parameters, if not, we track down further. + case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null => + trackDownVar(ctx, ctx.currentVars(ordinal)) + + case _ => Seq.empty + }.distinct + } + + /** + * Tracks down previously evaluated columns referred by the generated code snippet. + */ + def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[ExprInputVar] = { + val exprCodes = mutable.Queue[ExprCode](exprCode) + val inputVars = mutable.ArrayBuffer.empty[ExprInputVar] + + while (exprCodes.nonEmpty) { + exprCodes.dequeue().inputVars.foreach { inputVar => + if (isEvaluated(inputVar.exprCode)) { + inputVars += inputVar + } else { + exprCodes.enqueue(inputVar.exprCode) + } + } + } + inputVars + } + + /** + * Determines the parameter length in a Java method for given parameters. + */ + def getParamLength(ctx: CodegenContext, inputs: Seq[ExprInputVar]): Int = { + // Method parameter length only depends on data type and nullability. Make fake catalyst + // expressions for calculation. 
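+ // The ordinal `1` passed to `BoundReference` below is a placeholder; only the data type and
+ // nullability of each input affect the length computed by `calculateParamLength`.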
+ val exprs = inputs.map(inputVar => BoundReference(1, inputVar.dataType, inputVar.nullable)) + CodeGenerator.calculateParamLength(exprs) + } + + /** + * Given the lists of input attributes and variables to this expression, returns the strings of + * funtion parameters. The first is the variable names used to call the function, the second is + * the parameters used to declare the function in generated code. + */ + def prepareFunctionParams( + ctx: CodegenContext, + inputVars: Seq[ExprInputVar]): Seq[(String, String)] = { + inputVars.flatMap { inputVar => + val params = mutable.ArrayBuffer.empty[(String, String)] + val ev = inputVar.exprCode + + // Only include the expression value if it can't be accessed in a method. + if (!ev.value.canGlobalAccess) { + val argType = CodeGenerator.javaType(inputVar.dataType) + params += ((ev.value, s"$argType ${ev.value}")) + } + + // If it is a nullable expression and `isNull` can't be accessed in a method. + if (inputVar.nullable && !ev.isNull.canGlobalAccess) { + params += ((ev.isNull, s"boolean ${ev.isNull}")) + } + + params + }.distinct + } + + /** + * Only applied to the `ExprCode` in `ctx.currentVars`. + * The code is emptied after evaluation. + */ + def isEvaluated(exprCode: ExprCode): Boolean = exprCode.code == "" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 74ff01848886..e53a85c767ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -120,6 +120,7 @@ object JavaCode { trait ExprValue extends JavaCode { def javaType: Class[_] def isPrimitive: Boolean = javaType.isPrimitive + def canGlobalAccess: Boolean = false } object ExprValue { @@ -146,6 +147,7 @@ case class VariableValue(variableName: String, javaType: Class[_]) extends ExprV */ case class GlobalValue(value: String, javaType: Class[_]) extends ExprValue { override def code: String = value + override def canGlobalAccess: Boolean = true } /** @@ -160,6 +162,8 @@ class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue } override def hashCode(): Int = value.hashCode() * 31 + javaType.hashCode() + + override def canGlobalAccess: Boolean = true } case object TrueLiteral extends LiteralValue("true", JBool.TYPE) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala new file mode 100644 index 000000000000..78aedfad2035 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExpressionCodegenSuite.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.IntegerType + +class ExpressionCodegenSuite extends SparkFunSuite { + + test("Returns eliminated subexpressions for expression") { + val ctx = new CodegenContext() + val subExpr = Add(Literal(1), Literal(2)) + val exprs = Seq(Add(subExpr, Literal(3)), Add(subExpr, Literal(4))) + + ctx.generateExpressions(exprs, doSubexpressionElimination = true) + val subexpressions = ExpressionCodegen.getSubExprInChildren(ctx, exprs(0)) + assert(subexpressions.length == 1 && subexpressions(0) == subExpr) + } + + test("Gets parameters for subexpressions") { + val ctx = new CodegenContext() + val subExprs = Seq( + Add(Literal(1), AttributeReference("a", IntegerType, nullable = false)()), // non-nullable + Add(Literal(2), AttributeReference("b", IntegerType, nullable = true)())) // nullable + + ctx.subExprEliminationExprs += (subExprs(0) -> + SubExprEliminationState(FalseLiteral, JavaCode.variable("value1", IntegerType))) + ctx.subExprEliminationExprs += (subExprs(1) -> + SubExprEliminationState( + JavaCode.isNullVariable("isNull2"), + JavaCode.variable("value2", IntegerType))) + + val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs) + val subVars = subExprs.zip(subExprCodes).map { case (expr, exprCode) => + ExprInputVar(exprCode, expr.dataType, expr.nullable) + } + val params = ExpressionCodegen.prepareFunctionParams(ctx, subVars) + assert(params.length == 3) + assert(params(0) == Tuple2("value1", "int value1")) + assert(params(1) == Tuple2("value2", "int value2")) + assert(params(2) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input variables for expression: current variables") { + val ctx = new CodegenContext() + val currentVars = Seq( + // evaluated + ExprCode("", + isNull = FalseLiteral, + value = JavaCode.variable("value1", IntegerType)), + // evaluated + ExprCode("", + isNull = JavaCode.isNullVariable("isNull2"), + value = JavaCode.variable("value2", IntegerType)), + // not evaluated + ExprCode("fake code;", + isNull = JavaCode.isNullVariable("isNull3"), + value = JavaCode.variable("value3", IntegerType))) + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = If(Literal(false), + Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)), + BoundReference(2, IntegerType, nullable = true)) + + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + // Only two evaluated variables included. 
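+ // value3 is not returned: its code is non-empty (not evaluated yet) and it records no
+ // already-evaluated `inputVars` to fall back on.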
+ assert(inputVars.length == 2) + assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == false) + assert(inputVars(1).dataType == IntegerType && inputVars(1).nullable == true) + assert(inputVars(0).exprCode == currentVars(0)) + assert(inputVars(1).exprCode == currentVars(1)) + + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) + assert(params.length == 3) + assert(params(0) == Tuple2("value1", "int value1")) + assert(params(1) == Tuple2("value2", "int value2")) + assert(params(2) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input variables for expression: deferred variables") { + val ctx = new CodegenContext() + + // The referred column is not evaluated yet. But it depends on an evaluated column from + // other operator. + val currentVars = Seq(ExprCode("fake code;", + isNull = JavaCode.isNullVariable("isNull1"), + value = JavaCode.variable("value1", IntegerType))) + + // currentVars(0) depends on this evaluated column. + currentVars(0).inputVars = Seq( + ExprInputVar(ExprCode("", + isNull = JavaCode.isNullVariable("isNull2"), + value = JavaCode.variable("value2", IntegerType)), + dataType = IntegerType, nullable = true)) + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + assert(inputVars.length == 1) + assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == true) + + val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) + assert(params.length == 2) + assert(params(0) == Tuple2("value2", "int value2")) + assert(params(1) == Tuple2("isNull2", "boolean isNull2")) + } + + test("Returns input rows for expression") { + val ctx = new CodegenContext() + ctx.currentVars = null + ctx.INPUT_ROW = "i" + + val expr = Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)) + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 1) + assert(inputRows(0) == "i") + } + + test("Returns input rows for expression: deferred expression") { + val ctx = new CodegenContext() + + // The referred column is not evaluated yet. But it depends on an input row from + // other operator. + val currentVars = Seq(ExprCode("fake code;", + isNull = JavaCode.isNullVariable("isNull1"), + value = JavaCode.variable("value1", IntegerType))) + currentVars(0).inputRow = "inputadaptor_row1" + ctx.currentVars = currentVars + ctx.INPUT_ROW = null + + val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false)) + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 1) + assert(inputRows(0) == "inputadaptor_row1") + } + + test("Returns both input rows and variables for expression") { + val ctx = new CodegenContext() + // 5 input variables in currentVars: + // 1 evaluated variable (value1). + // 3 not evaluated variables. + // value2 depends on an evaluated column from other operator. + // value3 depends on an input row from other operator. + // value4 depends on a not evaluated yet column from other operator. + // 1 null indicating to use input row "i". 
+ val currentVars = Seq( + ExprCode("", + isNull = FalseLiteral, + value = JavaCode.variable("value1", IntegerType)), + ExprCode("fake code;", + isNull = JavaCode.isNullVariable("isNull2"), + value = JavaCode.variable("value2", IntegerType)), + ExprCode("fake code;", + isNull = JavaCode.isNullVariable("isNull3"), + value = JavaCode.variable("value3", IntegerType)), + ExprCode("fake code;", + isNull = JavaCode.isNullVariable("isNull4"), + value = JavaCode.variable("value4", IntegerType)), + null) + // value2 depends on this evaluated column. + currentVars(1).inputVars = Seq( + ExprInputVar(ExprCode("", + isNull = JavaCode.isNullVariable("isNull5"), + value = JavaCode.variable("value5", IntegerType)), + dataType = IntegerType, nullable = true)) + // value3 depends on an input row "inputadaptor_row1". + currentVars(2).inputRow = "inputadaptor_row1" + // value4 depends on another not evaluated yet column. + currentVars(3).inputVars = Seq( + ExprInputVar(ExprCode("fake code;", + isNull = JavaCode.isNullVariable("isNull6"), + value = JavaCode.variable("value6", IntegerType)), + dataType = IntegerType, nullable = true)) + ctx.currentVars = currentVars + ctx.INPUT_ROW = "i" + + // expr: if (false) { value1 + value2 } else { (value3 + value4) + i[5] } + val expr = If(Literal(false), + Add(BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = true)), + Add(Add(BoundReference(2, IntegerType, nullable = true), + BoundReference(3, IntegerType, nullable = true)), + BoundReference(4, IntegerType, nullable = true))) // this is based on input row "i". + + // input rows: "i", "inputadaptor_row1". + val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr) + assert(inputRows.length == 2) + assert(inputRows(0) == "inputadaptor_row1") + assert(inputRows(1) == "i") + + // input variables: value1 and value5 + val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr) + assert(inputVars.length == 2) + + // value1 has inlined isNull "false", so don't need to include it in the params. + val inputVarParams = ExpressionCodegen.prepareFunctionParams(ctx, inputVars) + assert(inputVarParams.length == 3) + assert(inputVarParams(0) == Tuple2("value1", "int value1")) + assert(inputVarParams(1) == Tuple2("value5", "int value5")) + assert(inputVarParams(2) == Tuple2("isNull5", "boolean isNull5")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index fc3dbc1c5591..26f8bbc77085 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -119,7 +119,9 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { |}""".stripMargin) ctx.currentVars = null - val rowidx = ctx.freshName("rowIdx") + // `rowIdx` isn't in `ctx.currentVars`. If the expressions are split later, we can't track it. + // So making it as global variable. 
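+ // The extra mutable state is a single int slot in the generated class, not a per-row cost.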
+ val rowidx = ctx.addMutableState(CodeGenerator.JAVA_INT, "rowIdx") val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } @@ -139,7 +141,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { | int $numRows = $batch.numRows(); | int $localEnd = $numRows - $idx; | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | int $rowidx = $idx + $localIdx; + | $rowidx = $idx + $localIdx; | ${consume(ctx, columnsBatchInput).trim} | $shouldStop | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 6fa716d9fade..20f8be91d943 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -192,7 +192,8 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType), + inputRow = matched) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9180a22c260f..49865a2393f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution import org.apache.spark.metrics.source.CodegenMetrics -import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.{Column, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec @@ -232,6 +233,26 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-22551: Fix 64kb limit for deeply nested expressions under wholestage codegen") { + import testImplicits._ + withTempPath { dir => + val path = dir.getCanonicalPath + val df = Seq(("abc", 1)).toDF("key", "int") + df.write.parquet(path) + + var strExpr: Expression = col("key").expr + for (_ <- 1 to 150) { + strExpr = Decode(Encode(strExpr, Literal("utf-8")), Literal("utf-8")) + } + val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)) + + val df2 = spark.read.parquet(path).select(expressions.map(Column(_)): _*) + val plan = df2.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) + df2.collect() + } + } + test("Control splitting consume function by operators with config") { import testImplicits._ val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*)
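To make the new splitting path concrete, here is a small sketch (not part of the patch; it reuses only classes and patterns that appear in the tests above) showing how `ExpressionCodegen.getExpressionInputParams` picks the parameters of a split-out function when an expression mixes an already-evaluated column with one that still reads from the input row:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.IntegerType

val ctx = new CodegenContext()
ctx.INPUT_ROW = "i"
ctx.currentVars = Seq(
  // Column 0 is already evaluated into a local variable, so it is passed as `int value1`.
  ExprCode("", isNull = FalseLiteral, value = JavaCode.variable("value1", IntegerType)),
  // Column 1 has no generated variable, so the input row `i` itself is passed instead.
  null)

val expr = Add(BoundReference(0, IntegerType, nullable = false),
  BoundReference(1, IntegerType, nullable = true))

// Expected: call params Seq("i", "value1"), declaration params Seq("InternalRow i", "int value1").
// `reduceCodeSize` would then emit roughly
//   private int someFunc(InternalRow i, int value1) { ... }
//   ... someFunc(i, value1) ...
// instead of the previous row-only signature `private int someFunc(InternalRow i)`.
ExpressionCodegen.getExpressionInputParams(ctx, expr).foreach { case (callParams, declParams) =>
  println(callParams.mkString(", "))  // i, value1
  println(declParams.mkString(", "))  // InternalRow i, int value1
}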