Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ private[codegen] case class NewFunctionSpec(
* A context for codegen, tracking a list of objects that could be passed into generated Java
* function.
*/
class CodegenContext {
class CodegenContext extends Logging {

import CodeGenerator._

Expand Down Expand Up @@ -1028,13 +1028,67 @@ class CodegenContext {
// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
val codes = commonExprs.map { e =>
val expr = e.head
// Generate the code for this expression tree.
val eval = expr.genCode(this)
val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(localSubExprEliminationExprs.put(_, state))
eval.code.toString
val commonExprVals = commonExprs.map(_.head.genCode(this))

lazy val nonSplitExprCode = {
commonExprs.zip(commonExprVals).map { case (exprs, eval) =>
// Generate the code for this expression tree.
val state = SubExprEliminationState(eval.isNull, eval.value)
exprs.foreach(localSubExprEliminationExprs.put(_, state))
eval.code.toString
}
}

val codes = if (commonExprVals.map(_.code.length).sum > SQLConf.get.methodSplitThreshold) {
Copy link
Member

@viirya viirya Sep 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although the original method should contain not only common expressions, this is probably good enough.

if (commonExprs.map(calculateParamLength).forall(isValidParamLength)) {
commonExprs.zipWithIndex.map { case (exprs, i) =>
val expr = exprs.head
val eval = commonExprVals(i)

val isNullLiteral = eval.isNull match {
case TrueLiteral | FalseLiteral => true
case _ => false
}
val (isNull, isNullEvalCode) = if (!isNullLiteral) {
val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
(JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};")
} else {
(eval.isNull, "")
}

// Generate the code for this expression tree and wrap it in a function.
val fnName = freshName("subExpr")
val inputVars = getLocalInputVariableValues(this, expr).toSeq
val argList = inputVars.map(v => s"${v.javaType.getName} ${v.variableName}")
val returnType = javaType(expr.dataType)
val fn =
s"""
|private $returnType $fnName(${argList.mkString(", ")}) {
| ${eval.code}
| $isNullEvalCode
| return ${eval.value};
|}
""".stripMargin
Copy link
Member Author

@maropu maropu Sep 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, we can do it in a followup.


val value = freshName("subExprValue")
val state = SubExprEliminationState(isNull, JavaCode.variable(value, expr.dataType))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One advantage of a global variable is that we don't need to care how this expr value is used later. It is OK even if it is used in a split function. Making it a local variable means we need to be careful and guarantee that these expressions are only used in the same scope.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I see. But I just don't want to add more pressure on the constant pool... WDYT? @cloud-fan

Copy link
Contributor

@cloud-fan cloud-fan Sep 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK the code of common subexpression execution is always put together, not split. I don't think we need to worry about it now.

BTW I think one principle is: for corner cases where it is really hard to generate code, we should just fall back to interpreted mode.

exprs.foreach(localSubExprEliminationExprs.put(_, state))
val inputVariables = inputVars.map(_.variableName).mkString(", ")
s"$returnType $value = ${addNewFunction(fnName, fn)}($inputVariables);"
}
} else {
val errMsg = "Failed to split subexpression code into small functions because the " +
"parameter length of at least one split function went over the JVM limit: " +
MAX_JVM_METHOD_PARAMS_LENGTH
if (Utils.isTesting) {
throw new IllegalStateException(errMsg)
} else {
logInfo(errMsg)
nonSplitExprCode
}
}
} else {
nonSplitExprCode
}
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
}
Expand Down Expand Up @@ -1620,7 +1674,7 @@ object CodeGenerator extends Logging {
def getLocalInputVariableValues(
ctx: CodegenContext,
expr: Expression,
subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = {
subExprs: Map[Expression, SubExprEliminationState] = Map.empty): Set[VariableValue] = {
val argSet = mutable.Set[VariableValue]()
if (ctx.INPUT_ROW != null) {
argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
Expand Down Expand Up @@ -1775,6 +1829,10 @@ object CodeGenerator extends Logging {
* length less than a pre-defined constant.
*/
def isValidParamLength(paramLength: Int): Boolean = {
paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
// This config is only for testing
SQLConf.get.getConfString("spark.sql.CodeGenerator.validParamLength", null) match {
case null | "" => paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
case validLength => paramLength <= validLength.toInt
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,6 @@ case class HashAggregateExec(
""".stripMargin
}

private def isValidParamLength(paramLength: Int): Boolean = {
// This config is only for testing
sqlContext.getConf("spark.sql.HashAggregateExec.validParamLength", null) match {
case null | "" => CodeGenerator.isValidParamLength(paramLength)
case validLength => paramLength <= validLength.toInt
}
}

// Splits aggregate code into small functions because the most of JVM implementations
// can not compile too long functions. Returns None if we are not able to split the given code.
//
Expand All @@ -294,7 +286,7 @@ case class HashAggregateExec(
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)

// Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
if (isValidParamLength(paramLength)) {
if (CodeGenerator.isValidParamLength(paramLength)) {
Some(inputVarsForOneFunc)
} else {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
withSQLConf(
SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true",
SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
"spark.sql.HashAggregateExec.validParamLength" -> "0") {
"spark.sql.CodeGenerator.validParamLength" -> "0") {
withTable("t") {
val expectedErrMsg = "Failed to split aggregate code into small functions"
Seq(
Expand All @@ -419,4 +419,27 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
}
}
}

test("Give up splitting subexpression code if a parameter length goes over the limit") {
withSQLConf(
SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "false",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Must this test be run with CODEGEN_SPLIT_AGGREGATE_FUNC = false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
"spark.sql.CodeGenerator.validParamLength" -> "0") {
withTable("t") {
val expectedErrMsg = "Failed to split subexpression code into small functions"
Seq(
// Test case without keys
"SELECT AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1)) t(a, b, c)",
// Test case with keys
"SELECT k, AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1, 1)) t(k, a, b, c) " +
"GROUP BY k").foreach { query =>
val e = intercept[Exception] {
sql(query).collect
}.getCause
assert(e.isInstanceOf[IllegalStateException])
assert(e.getMessage.contains(expectedErrMsg))
}
}
}
}
}