diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4c1bfcfdf7f17..1309cd30e391c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -112,7 +112,7 @@ private[codegen] case class NewFunctionSpec(
  * A context for codegen, tracking a list of objects that could be passed into generated Java
  * function.
  */
-class CodegenContext {
+class CodegenContext extends Logging {

   import CodeGenerator._

@@ -1028,13 +1028,67 @@ class CodegenContext {
     // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
     val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
-    val codes = commonExprs.map { e =>
-      val expr = e.head
-      // Generate the code for this expression tree.
-      val eval = expr.genCode(this)
-      val state = SubExprEliminationState(eval.isNull, eval.value)
-      e.foreach(localSubExprEliminationExprs.put(_, state))
-      eval.code.toString
+    val commonExprVals = commonExprs.map(_.head.genCode(this))
+
+    lazy val nonSplitExprCode = {
+      commonExprs.zip(commonExprVals).map { case (exprs, eval) =>
+        // Generate the code for this expression tree.
+        val state = SubExprEliminationState(eval.isNull, eval.value)
+        exprs.foreach(localSubExprEliminationExprs.put(_, state))
+        eval.code.toString
+      }
+    }
+
+    val codes = if (commonExprVals.map(_.code.length).sum > SQLConf.get.methodSplitThreshold) {
+      if (commonExprs.map(calculateParamLength).forall(isValidParamLength)) {
+        commonExprs.zipWithIndex.map { case (exprs, i) =>
+          val expr = exprs.head
+          val eval = commonExprVals(i)
+
+          val isNullLiteral = eval.isNull match {
+            case TrueLiteral | FalseLiteral => true
+            case _ => false
+          }
+          val (isNull, isNullEvalCode) = if (!isNullLiteral) {
+            val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
+            (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};")
+          } else {
+            (eval.isNull, "")
+          }
+
+          // Generate the code for this expression tree and wrap it in a function.
+          val fnName = freshName("subExpr")
+          val inputVars = getLocalInputVariableValues(this, expr).toSeq
+          val argList = inputVars.map(v => s"${v.javaType.getName} ${v.variableName}")
+          val returnType = javaType(expr.dataType)
+          val fn =
+            s"""
+               |private $returnType $fnName(${argList.mkString(", ")}) {
+               |  ${eval.code}
+               |  $isNullEvalCode
+               |  return ${eval.value};
+               |}
+               """.stripMargin
+
+          val value = freshName("subExprValue")
+          val state = SubExprEliminationState(isNull, JavaCode.variable(value, expr.dataType))
+          exprs.foreach(localSubExprEliminationExprs.put(_, state))
+          val inputVariables = inputVars.map(_.variableName).mkString(", ")
+          s"$returnType $value = ${addNewFunction(fnName, fn)}($inputVariables);"
+        }
+      } else {
+        val errMsg = "Failed to split subexpression code into small functions because the " +
+          "parameter length of at least one split function went over the JVM limit: " +
+          MAX_JVM_METHOD_PARAMS_LENGTH
+        if (Utils.isTesting) {
+          throw new IllegalStateException(errMsg)
+        } else {
+          logInfo(errMsg)
+          nonSplitExprCode
+        }
+      }
+    } else {
+      nonSplitExprCode
     }
     SubExprCodes(codes, localSubExprEliminationExprs.toMap)
   }
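Reviewer note: the control flow added above reduces to a three-way decision. If the combined code of the common subexpressions is under the method split threshold, emit it inline; if it is over the threshold and every would-be function's parameter list fits the JVM limit, hoist each subexpression into its own private function and emit a call; otherwise fall back to inline code (throwing under `Utils.isTesting`, logging in production). A minimal, self-contained sketch of that decision, with hypothetical names (`SubExpr`, `splitThreshold`, `maxParams`, `subExprFn`) standing in for the Spark internals:

```scala
// Illustrative model only, not Spark API: each SubExpr carries its generated
// code and the number of parameters a split-out function would need.
case class SubExpr(code: String, paramCount: Int)

def generate(exprs: Seq[SubExpr], splitThreshold: Int, maxParams: Int): Seq[String] = {
  // Fallback path: emit every subexpression's code in place.
  lazy val inline: Seq[String] = exprs.map(_.code)
  if (exprs.map(_.code.length).sum > splitThreshold) {
    if (exprs.forall(_.paramCount <= maxParams)) {
      // Split path: wrap each subexpression in its own function, emit a call.
      exprs.zipWithIndex.map { case (e, i) =>
        s"subExprFn$i(); // function body would hold: ${e.code}"
      }
    } else {
      // Some parameter list would exceed the JVM limit: give up splitting.
      inline
    }
  } else {
    inline
  }
}
```

The laziness mirrors a real constraint in the patch: `nonSplitExprCode` registers `SubExprEliminationState`s into `localSubExprEliminationExprs` as a side effect, so it must be evaluated at most once and only when the inline path is actually taken.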
@@ -1620,7 +1674,7 @@ object CodeGenerator extends Logging {
   def getLocalInputVariableValues(
       ctx: CodegenContext,
       expr: Expression,
-      subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = {
+      subExprs: Map[Expression, SubExprEliminationState] = Map.empty): Set[VariableValue] = {
     val argSet = mutable.Set[VariableValue]()
     if (ctx.INPUT_ROW != null) {
       argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
@@ -1775,6 +1829,10 @@ object CodeGenerator extends Logging {
    * length less than a pre-defined constant.
    */
   def isValidParamLength(paramLength: Int): Boolean = {
-    paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
+    // This config is only for testing
+    SQLConf.get.getConfString("spark.sql.CodeGenerator.validParamLength", null) match {
+      case null | "" => paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
+      case validLength => paramLength <= validLength.toInt
+    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 9242583d36717..c24d6f141a48d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -261,14 +261,6 @@ case class HashAggregateExec(
     """.stripMargin
   }

-  private def isValidParamLength(paramLength: Int): Boolean = {
-    // This config is only for testing
-    sqlContext.getConf("spark.sql.HashAggregateExec.validParamLength", null) match {
-      case null | "" => CodeGenerator.isValidParamLength(paramLength)
-      case validLength => paramLength <= validLength.toInt
-    }
-  }
-
   // Splits aggregate code into small functions because the most of JVM implementations
   // can not compile too long functions. Returns None if we are not able to split the given code.
   //
@@ -294,7 +286,7 @@ case class HashAggregateExec(
       val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)

       // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
-      if (isValidParamLength(paramLength)) {
+      if (CodeGenerator.isValidParamLength(paramLength)) {
         Some(inputVarsForOneFunc)
       } else {
         None
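The change above consolidates the testing-only override: the operator-local hook `spark.sql.HashAggregateExec.validParamLength` is removed in favour of `spark.sql.CodeGenerator.validParamLength` inside the shared `CodeGenerator.isValidParamLength`, so both the aggregate-splitting path and the new subexpression-splitting path honour the same knob. The limit being guarded is the JVM's 255-slot cap on a method's parameters, where `long` and `double` occupy two slots each and the implicit `this` one. A hedged sketch of that slot accounting (illustrative only; `paramSlots` and `fitsJvmLimit` are hypothetical helpers, Spark's own versions are `calculateParamLength` and `isValidParamLength`):

```scala
// Hypothetical re-derivation of the JVM parameter-slot rule, not Spark API.
def paramSlots(paramTypes: Seq[Class[_]]): Int = {
  // The implicit `this` of an instance method takes one slot up front.
  1 + paramTypes.map { t =>
    // long and double each occupy two slots in a JVM method descriptor.
    if (t == java.lang.Long.TYPE || t == java.lang.Double.TYPE) 2 else 1
  }.sum
}

def fitsJvmLimit(paramTypes: Seq[Class[_]]): Boolean =
  paramSlots(paramTypes) <= 255 // MAX_JVM_METHOD_PARAMS_LENGTH in Spark

// e.g. fitsJvmLimit(Seq(classOf[Int], java.lang.Long.TYPE)) == true
```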
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index d8727d5b584f1..2ed9df778cae1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -403,7 +403,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
     withSQLConf(
       SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true",
       SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
-      "spark.sql.HashAggregateExec.validParamLength" -> "0") {
+      "spark.sql.CodeGenerator.validParamLength" -> "0") {
       withTable("t") {
         val expectedErrMsg = "Failed to split aggregate code into small functions"
         Seq(
@@ -419,4 +419,27 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("Give up splitting subexpression code if a parameter length goes over the limit") {
+    withSQLConf(
+      SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "false",
+      SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
+      "spark.sql.CodeGenerator.validParamLength" -> "0") {
+      withTable("t") {
+        val expectedErrMsg = "Failed to split subexpression code into small functions"
+        Seq(
+          // Test case without keys
+          "SELECT AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1)) t(a, b, c)",
+          // Test case with keys
+          "SELECT k, AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1, 1)) t(k, a, b, c) " +
+            "GROUP BY k").foreach { query =>
+          val e = intercept[Exception] {
+            sql(query).collect
+          }.getCause
+          assert(e.isInstanceOf[IllegalStateException])
+          assert(e.getMessage.contains(expectedErrMsg))
+        }
+      }
+    }
+  }
 }
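One note on the assertion shape in the new test: under `Utils.isTesting` the generator throws `IllegalStateException`, and the error reaches the caller wrapped in an outer execution exception, hence the unwrap via `.getCause` before checking the type and message. The test also disables `CODEGEN_SPLIT_AGGREGATE_FUNC`, presumably so the aggregate-splitting check does not trip first and the subexpression path is the one exercised. A small sketch of the unwrap-then-assert pattern in isolation (assuming ScalaTest; `runQuery` is a hypothetical stand-in for `sql(query).collect`):

```scala
import org.scalatest.Assertions._

// Hypothetical helper mirroring the test's pattern: run the query, unwrap
// one level of exception, then assert on the cause's type and message.
def assertSplitFailure(runQuery: () => Unit, expectedErrMsg: String): Unit = {
  val e = intercept[Exception] { runQuery() }.getCause
  assert(e.isInstanceOf[IllegalStateException])
  assert(e.getMessage.contains(expectedErrMsg))
}
```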