[SPARK-29008][SQL] Define an individual method for each common subexpression in HashAggregateExec #25710
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -112,7 +112,7 @@ private[codegen] case class NewFunctionSpec( | |
| * A context for codegen, tracking a list of objects that could be passed into generated Java | ||
| * function. | ||
| */ | ||
| class CodegenContext { | ||
| class CodegenContext extends Logging { | ||
|
|
||
| import CodeGenerator._ | ||
|
|
||
|
|
@@ -1028,13 +1028,67 @@ class CodegenContext { | |
| // Get all the expressions that appear at least twice and set up the state for subexpression | ||
| // elimination. | ||
| val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) | ||
| val codes = commonExprs.map { e => | ||
| val expr = e.head | ||
| // Generate the code for this expression tree. | ||
| val eval = expr.genCode(this) | ||
| val state = SubExprEliminationState(eval.isNull, eval.value) | ||
| e.foreach(localSubExprEliminationExprs.put(_, state)) | ||
| eval.code.toString | ||
| val commonExprVals = commonExprs.map(_.head.genCode(this)) | ||
|
|
||
| lazy val nonSplitExprCode = { | ||
| commonExprs.zip(commonExprVals).map { case (exprs, eval) => | ||
| // Generate the code for this expression tree. | ||
| val state = SubExprEliminationState(eval.isNull, eval.value) | ||
| exprs.foreach(localSubExprEliminationExprs.put(_, state)) | ||
| eval.code.toString | ||
| } | ||
| } | ||
|
|
||
| val codes = if (commonExprVals.map(_.code.length).sum > SQLConf.get.methodSplitThreshold) { | ||
| if (commonExprs.map(calculateParamLength).forall(isValidParamLength)) { | ||
| commonExprs.zipWithIndex.map { case (exprs, i) => | ||
| val expr = exprs.head | ||
| val eval = commonExprVals(i) | ||
|
|
||
| val isNullLiteral = eval.isNull match { | ||
| case TrueLiteral | FalseLiteral => true | ||
| case _ => false | ||
| } | ||
| val (isNull, isNullEvalCode) = if (!isNullLiteral) { | ||
| val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") | ||
| (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") | ||
| } else { | ||
| (eval.isNull, "") | ||
| } | ||
|
|
||
| // Generate the code for this expression tree and wrap it in a function. | ||
| val fnName = freshName("subExpr") | ||
| val inputVars = getLocalInputVariableValues(this, expr).toSeq | ||
| val argList = inputVars.map(v => s"${v.javaType.getName} ${v.variableName}") | ||
| val returnType = javaType(expr.dataType) | ||
| val fn = | ||
| s""" | ||
| |private $returnType $fnName(${argList.mkString(", ")}) { | ||
| | ${eval.code} | ||
| | $isNullEvalCode | ||
| | return ${eval.value}; | ||
| |} | ||
| """.stripMargin | ||
|
Review comment: ISTM we might be able to apply the same change in https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala#L1060-L1069
Reply: Yea, we can do it in a followup. |
||
|
|
||
| val value = freshName("subExprValue") | ||
| val state = SubExprEliminationState(isNull, JavaCode.variable(value, expr.dataType)) | ||
|
Review comment: One advantage of a global variable is that we don't care how this expr value is used later. It is OK even if it is used in a split function. Making it a local variable means we need to be careful and guarantee that these expressions are only used in the same scope.
Reply: Yea, I see. But I just don't want to add more pressure on the constant pool... WDYT? @cloud-fan
Reply: AFAIK the code of common subexpression execution is always put together, not split. I don't think we need to worry about it now. BTW I think one principle is: for corner cases where it is really hard to generate code, we should just fall back to interpreted mode. |
||
| exprs.foreach(localSubExprEliminationExprs.put(_, state)) | ||
| val inputVariables = inputVars.map(_.variableName).mkString(", ") | ||
| s"$returnType $value = ${addNewFunction(fnName, fn)}($inputVariables);" | ||
| } | ||
| } else { | ||
| val errMsg = "Failed to split subexpression code into small functions because the " + | ||
| "parameter length of at least one split function went over the JVM limit: " + | ||
| MAX_JVM_METHOD_PARAMS_LENGTH | ||
| if (Utils.isTesting) { | ||
| throw new IllegalStateException(errMsg) | ||
| } else { | ||
| logInfo(errMsg) | ||
| nonSplitExprCode | ||
| } | ||
| } | ||
| } else { | ||
| nonSplitExprCode | ||
| } | ||
| SubExprCodes(codes, localSubExprEliminationExprs.toMap) | ||
| } | ||
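The wrapping step in the hunk above can be sketched in isolation. This is a minimal illustration, not Spark's implementation: `InputVar`, `Eval`, and `SubExprWrapSketch.wrap` are hypothetical stand-ins for `VariableValue`, `ExprCode`, and the inline string template, and plain strings replace Spark's `JavaCode` types.

```scala
// Hypothetical, simplified stand-ins; these are not Spark's real classes.
final case class InputVar(javaType: String, name: String)
final case class Eval(code: String, isNull: String, value: String)

object SubExprWrapSketch {
  /** Returns (Java function definition, call-site statement). */
  def wrap(
      fnName: String,
      returnType: String,
      resultVar: String,
      isNullField: Option[String],
      inputs: Seq[InputVar],
      eval: Eval): (String, String) = {
    val params = inputs.map(v => s"${v.javaType} ${v.name}").mkString(", ")
    val args = inputs.map(_.name).mkString(", ")
    // Only a non-literal isNull is copied into a field; a literal needs no state.
    val saveIsNull = isNullField.map(f => s"$f = ${eval.isNull};").getOrElse("")
    val fn =
      s"""private $returnType $fnName($params) {
         |  ${eval.code}
         |  $saveIsNull
         |  return ${eval.value};
         |}""".stripMargin
    // The result is kept in a local variable at the call site.
    val call = s"$returnType $resultVar = $fnName($args);"
    (fn, call)
  }
}
```

The null flag goes to a field because the generated function can return only one value, while the value itself stays local to the caller, which is the trade-off discussed in the review comments.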
|
|
@@ -1620,7 +1674,7 @@ object CodeGenerator extends Logging { | |
| def getLocalInputVariableValues( | ||
| ctx: CodegenContext, | ||
| expr: Expression, | ||
| subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = { | ||
| subExprs: Map[Expression, SubExprEliminationState] = Map.empty): Set[VariableValue] = { | ||
| val argSet = mutable.Set[VariableValue]() | ||
| if (ctx.INPUT_ROW != null) { | ||
| argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow]) | ||
|
|
@@ -1775,6 +1829,10 @@ object CodeGenerator extends Logging { | |
| * length less than a pre-defined constant. | ||
| */ | ||
| def isValidParamLength(paramLength: Int): Boolean = { | ||
| paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH | ||
| // This config is only for testing | ||
| SQLConf.get.getConfString("spark.sql.CodeGenerator.validParamLength", null) match { | ||
| case null | "" => paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH | ||
| case validLength => paramLength <= validLength.toInt | ||
| } | ||
| } | ||
| } | ||
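For context on the limit being checked: the JVM allows at most 255 parameter slots per method, where `long` and `double` take two slots each and `this` takes one for instance methods. A rough sketch of such a check follows, assuming a hypothetical `Param` type and an optional override in place of the testing-only config; Spark's actual `calculateParamLength` works on expressions and may count additional slots.

```scala
object ParamLengthSketch {
  // JVM limit on parameter slots in a method descriptor.
  val MaxJvmMethodParamsLength = 255

  // Hypothetical description of one generated-method parameter.
  final case class Param(javaType: String)

  def calculateParamLength(params: Seq[Param]): Int = {
    // `long` and `double` occupy two slots, every other type one.
    val slots = params.map { p =>
      if (p.javaType == "long" || p.javaType == "double") 2 else 1
    }.sum
    1 + slots // one extra slot for `this` in an instance method
  }

  // `testOverride` plays the role of the testing-only config in the diff.
  def isValidParamLength(paramLength: Int, testOverride: Option[Int] = None): Boolean =
    paramLength <= testOverride.getOrElse(MaxJvmMethodParamsLength)
}
```

Setting the override to 0 makes every parameter list invalid, which is how the test suite forces the failure path.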
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -403,7 +403,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession { | |
| withSQLConf( | ||
| SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true", | ||
| SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", | ||
| "spark.sql.HashAggregateExec.validParamLength" -> "0") { | ||
| "spark.sql.CodeGenerator.validParamLength" -> "0") { | ||
| withTable("t") { | ||
| val expectedErrMsg = "Failed to split aggregate code into small functions" | ||
| Seq( | ||
|
|
@@ -419,4 +419,27 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession { | |
| } | ||
| } | ||
| } | ||
|
|
||
| test("Give up splitting subexpression code if a parameter length goes over the limit") { | ||
| withSQLConf( | ||
| SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "false", | ||
|
Review comment: This test must be run under CODEGEN_SPLIT_AGGREGATE_FUNC = false?
Reply: Yea, we need to. If that flag is true, |
||
| SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", | ||
| "spark.sql.CodeGenerator.validParamLength" -> "0") { | ||
| withTable("t") { | ||
| val expectedErrMsg = "Failed to split subexpression code into small functions" | ||
| Seq( | ||
| // Test case without keys | ||
| "SELECT AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1)) t(a, b, c)", | ||
| // Test case with keys | ||
| "SELECT k, AVG(a + b), SUM(a + b + c) FROM VALUES((1, 1, 1, 1)) t(k, a, b, c) " + | ||
| "GROUP BY k").foreach { query => | ||
| val e = intercept[Exception] { | ||
| sql(query).collect | ||
| }.getCause | ||
| assert(e.isInstanceOf[IllegalStateException]) | ||
| assert(e.getMessage.contains(expectedErrMsg)) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
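To see why the test sets the split threshold to 1 and the valid parameter length to 0, the branching in the subexpression code path can be modeled as a small pure function. This is only a sketch of the control flow shown in the diff; `SplitDecisionSketch`, `Decision`, and `decide` are hypothetical names, and the real code generates Java source rather than returning a tag.

```scala
object SplitDecisionSketch {
  sealed trait Decision
  case object Inline extends Decision                 // keep the code in the caller
  case object Split extends Decision                  // one function per subexpression
  final case class Fail(msg: String) extends Decision // raised only when testing

  def decide(
      codeLengths: Seq[Int],
      paramLengths: Seq[Int],
      splitThreshold: Int,
      validParamLength: Int,
      isTesting: Boolean): Decision = {
    if (codeLengths.sum <= splitThreshold) Inline
    else if (paramLengths.forall(_ <= validParamLength)) Split
    else if (isTesting) Fail("Failed to split subexpression code into small functions")
    else Inline // production: log the message and fall back to unsplit code
  }
}
```

With a threshold of 1 any non-trivial code exceeds it, and with a valid length of 0 no parameter list passes, so under testing the call ends in the `Fail` branch that the test expects.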
Although the original method should contain not only common expressions, this is probably good enough.