
Commit cb0cddf

maropu authored and cloud-fan committed
[SPARK-21870][SQL] Split aggregation code into small functions
## What changes were proposed in this pull request?

This pr proposes to split aggregation code into small functions in `HashAggregateExec`. In #18810, we saw a performance regression when the JVM did not JIT-compile functions that were too long. I checked and found that the code generated by `HashAggregateExec` frequently goes over that limit when a query has many aggregate functions (e.g., q66 in TPCDS). The current master places all the generated aggregation code in a single function. In this pr, I modified the code to assign an individual function to each aggregate function (e.g., `SUM` and `AVG`). For example, for the query `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, the proposed change defines two functions, one for `SUM(a)` and one for `AVG(a)`, as follows:

- generated code with this pr (https://gist.github.com/maropu/812990012bc967a78364be0fa793f559):

```
/* 173 */   private void agg_doConsume_0(InternalRow inputadapter_row_0, long agg_expr_0_0, boolean agg_exprIsNull_0_0, double agg_expr_1_0, boolean agg_exprIsNull_1_0, long agg_expr_2_0, boolean agg_exprIsNull_2_0) throws java.io.IOException {
/* 174 */     // do aggregate
/* 175 */     // common sub-expressions
/* 176 */
/* 177 */     // evaluate aggregate functions and update aggregation buffers
/* 178 */     agg_doAggregate_sum_0(agg_exprIsNull_0_0, agg_expr_0_0);
/* 179 */     agg_doAggregate_avg_0(agg_expr_1_0, agg_exprIsNull_1_0, agg_exprIsNull_2_0, agg_expr_2_0);
/* 180 */
/* 181 */   }
...
/* 071 */   private void agg_doAggregate_avg_0(double agg_expr_1_0, boolean agg_exprIsNull_1_0, boolean agg_exprIsNull_2_0, long agg_expr_2_0) throws java.io.IOException {
/* 072 */     // do aggregate for avg
/* 073 */     // evaluate aggregate function
/* 074 */     boolean agg_isNull_19 = true;
/* 075 */     double agg_value_19 = -1.0;
...
/* 114 */   private void agg_doAggregate_sum_0(boolean agg_exprIsNull_0_0, long agg_expr_0_0) throws java.io.IOException {
/* 115 */     // do aggregate for sum
/* 116 */     // evaluate aggregate function
/* 117 */     agg_agg_isNull_11_0 = true;
/* 118 */     long agg_value_11 = -1L;
```

- generated code in the current master (https://gist.github.com/maropu/e9d772af2c98d8991a6a5f0af7841760):

```
/* 059 */   private void agg_doConsume_0(InternalRow localtablescan_row_0, int agg_expr_0_0) throws java.io.IOException {
/* 060 */     // do aggregate
/* 061 */     // common sub-expressions
/* 062 */     boolean agg_isNull_4 = false;
/* 063 */     long agg_value_4 = -1L;
/* 064 */     if (!false) {
/* 065 */       agg_value_4 = (long) agg_expr_0_0;
/* 066 */     }
/* 067 */     // evaluate aggregate function
/* 068 */     agg_agg_isNull_7_0 = true;
/* 069 */     long agg_value_7 = -1L;
/* 070 */     do {
/* 071 */       if (!agg_bufIsNull_0) {
/* 072 */         agg_agg_isNull_7_0 = false;
/* 073 */         agg_value_7 = agg_bufValue_0;
/* 074 */         continue;
/* 075 */       }
/* 076 */
/* 077 */       boolean agg_isNull_9 = false;
/* 078 */       long agg_value_9 = -1L;
/* 079 */       if (!false) {
/* 080 */         agg_value_9 = (long) 0;
/* 081 */       }
/* 082 */       if (!agg_isNull_9) {
/* 083 */         agg_agg_isNull_7_0 = false;
/* 084 */         agg_value_7 = agg_value_9;
/* 085 */         continue;
/* 086 */       }
/* 087 */
/* 088 */     } while (false);
/* 089 */
/* 090 */     long agg_value_6 = -1L;
/* 091 */
/* 092 */     agg_value_6 = agg_value_7 + agg_value_4;
/* 093 */     boolean agg_isNull_11 = true;
/* 094 */     double agg_value_11 = -1.0;
/* 095 */
/* 096 */     if (!agg_bufIsNull_1) {
/* 097 */       agg_agg_isNull_13_0 = true;
/* 098 */       double agg_value_13 = -1.0;
/* 099 */       do {
/* 100 */         boolean agg_isNull_14 = agg_isNull_4;
/* 101 */         double agg_value_14 = -1.0;
/* 102 */         if (!agg_isNull_4) {
/* 103 */           agg_value_14 = (double) agg_value_4;
/* 104 */         }
/* 105 */         if (!agg_isNull_14) {
/* 106 */           agg_agg_isNull_13_0 = false;
/* 107 */           agg_value_13 = agg_value_14;
/* 108 */           continue;
/* 109 */         }
/* 110 */
/* 111 */         boolean agg_isNull_15 = false;
/* 112 */         double agg_value_15 = -1.0;
/* 113 */         if (!false) {
/* 114 */           agg_value_15 = (double) 0;
/* 115 */         }
/* 116 */         if (!agg_isNull_15) {
/* 117 */           agg_agg_isNull_13_0 = false;
/* 118 */           agg_value_13 = agg_value_15;
/* 119 */           continue;
/* 120 */         }
/* 121 */
/* 122 */       } while (false);
/* 123 */
/* 124 */       agg_isNull_11 = false; // resultCode could change nullability.
/* 125 */
/* 126 */       agg_value_11 = agg_bufValue_1 + agg_value_13;
/* 127 */
/* 128 */     }
/* 129 */     boolean agg_isNull_17 = false;
/* 130 */     long agg_value_17 = -1L;
/* 131 */     if (!false && agg_isNull_4) {
/* 132 */       agg_isNull_17 = agg_bufIsNull_2;
/* 133 */       agg_value_17 = agg_bufValue_2;
/* 134 */     } else {
/* 135 */       boolean agg_isNull_20 = true;
/* 136 */       long agg_value_20 = -1L;
/* 137 */
/* 138 */       if (!agg_bufIsNull_2) {
/* 139 */         agg_isNull_20 = false; // resultCode could change nullability.
/* 140 */
/* 141 */         agg_value_20 = agg_bufValue_2 + 1L;
/* 142 */
/* 143 */       }
/* 144 */       agg_isNull_17 = agg_isNull_20;
/* 145 */       agg_value_17 = agg_value_20;
/* 146 */     }
/* 147 */     // update aggregation buffer
/* 148 */     agg_bufIsNull_0 = false;
/* 149 */     agg_bufValue_0 = agg_value_6;
/* 150 */
/* 151 */     agg_bufIsNull_1 = agg_isNull_11;
/* 152 */     agg_bufValue_1 = agg_value_11;
/* 153 */
/* 154 */     agg_bufIsNull_2 = agg_isNull_17;
/* 155 */     agg_bufValue_2 = agg_value_17;
/* 156 */
/* 157 */   }
```

You can check the previous discussion in #19082.

## How was this patch tested?

Existing tests

Closes #20965 from maropu/SPARK-21870-2.

Authored-by: Takeshi Yamamuro <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
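To inspect the generated code locally, a minimal sketch is below; it assumes a Spark shell session (`spark`) built with this patch and uses the `debugCodegen` helper from Spark's debug package to print the whole-stage generated Java, in which the per-aggregate methods should appear:

```scala
// Minimal sketch (assumes a running SparkSession named `spark`).
import org.apache.spark.sql.execution.debug._

val df = spark.sql("SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)")
// Prints the whole-stage codegen output; with this patch the split methods
// (agg_doAggregate_sum_0, agg_doAggregate_avg_0) show up in it.
df.debugCodegen()
```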
1 parent 36f8e53 commit cb0cddf

File tree

7 files changed, +348 -69 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 7 additions & 1 deletion
```diff
@@ -115,7 +115,13 @@ package object dsl {
     def getField(fieldName: String): UnresolvedExtractValue =
       UnresolvedExtractValue(expr, Literal(fieldName))
 
-    def cast(to: DataType): Expression = Cast(expr, to)
+    def cast(to: DataType): Expression = {
+      if (expr.resolved && expr.dataType.sameType(to)) {
+        expr
+      } else {
+        Cast(expr, to)
+      }
+    }
 
     def asc: SortOrder = SortOrder(expr, Ascending)
     def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
```
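The `cast` change above means the DSL no longer wraps an expression in a redundant `Cast` when it already has the requested type. A small sketch of the intended behavior (illustrative only, using a hand-built attribute):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.types.{IntegerType, LongType}

val a = AttributeReference("a", IntegerType)()   // resolved attribute of IntegerType

// Same type: the expression is returned as-is, no Cast node is added.
assert(a.cast(IntegerType) eq a)

// Different type: a Cast is still inserted, as before.
assert(a.cast(LongType).isInstanceOf[Cast])
```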

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 51 additions & 0 deletions
```diff
@@ -1612,6 +1612,48 @@ object CodeGenerator extends Logging {
     }
   }
 
+  /**
+   * Extracts all the input variables from references and subexpression elimination states
+   * for a given `expr`. This result will be used to split the generated code of
+   * expressions into multiple functions.
+   */
+  def getLocalInputVariableValues(
+      ctx: CodegenContext,
+      expr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = {
+    val argSet = mutable.Set[VariableValue]()
+    if (ctx.INPUT_ROW != null) {
+      argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
+    }
+
+    // Collects local variables from a given `expr` tree
+    val collectLocalVariable = (ev: ExprValue) => ev match {
+      case vv: VariableValue => argSet += vv
+      case _ =>
+    }
+
+    val stack = mutable.Stack[Expression](expr)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val SubExprEliminationState(isNull, value) = subExprs(e)
+          collectLocalVariable(value)
+          collectLocalVariable(isNull)
+
+        case ref: BoundReference if ctx.currentVars != null &&
+            ctx.currentVars(ref.ordinal) != null =>
+          val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
+          collectLocalVariable(value)
+          collectLocalVariable(isNull)
+
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
   /**
    * Returns the name used in accessor and setter for a Java primitive type.
    */
@@ -1719,6 +1761,15 @@
     1 + params.map(paramLengthForExpr).sum
   }
 
+  def calculateParamLengthFromExprValues(params: Seq[ExprValue]): Int = {
+    def paramLengthForExpr(input: ExprValue): Int = input.javaType match {
+      case java.lang.Long.TYPE | java.lang.Double.TYPE => 2
+      case _ => 1
+    }
+    // Initial value is 1 for `this`.
+    1 + params.map(paramLengthForExpr).sum
+  }
+
   /**
    * In Java, a method descriptor is valid only if it represents method parameters with a total
    * length less than a pre-defined constant.
```
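As a quick illustration of the slot accounting in `calculateParamLengthFromExprValues`: `long` and `double` parameters occupy two JVM slots each, everything else one, plus one slot for `this`. A hedged sketch (the parameter names are made up; `JavaCode.variable(name, javaClass)` is the helper already used in the diff above):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, JavaCode}

// Hypothetical parameter list for a split aggregate function:
// a long value (2 slots) and its null flag (1 slot), plus 1 slot for `this`.
val params = Seq(
  JavaCode.variable("agg_expr_0_0", java.lang.Long.TYPE),
  JavaCode.variable("agg_exprIsNull_0_0", java.lang.Boolean.TYPE))

assert(CodeGenerator.calculateParamLengthFromExprValues(params) == 4)  // 1 + 2 + 1
```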

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala

Lines changed: 4 additions & 1 deletion
```diff
@@ -143,7 +143,10 @@ trait Block extends TreeNode[Block] with JavaCode {
     case _ => code.trim
   }
 
-  def length: Int = toString.length
+  def length: Int = {
+    // Returns a code length without comments
+    CodeFormatter.stripExtraNewLinesAndComments(toString).length
+  }
 
   def isEmpty: Boolean = toString.isEmpty
 
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 7 additions & 5 deletions
```diff
@@ -354,12 +354,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
-    val value = eval.isNull match {
-      case TrueLiteral => FalseLiteral
-      case FalseLiteral => TrueLiteral
-      case v => JavaCode.isNullExpression(s"!$v")
+    val (value, newCode) = eval.isNull match {
+      case TrueLiteral => (FalseLiteral, EmptyBlock)
+      case FalseLiteral => (TrueLiteral, EmptyBlock)
+      case v =>
+        val value = ctx.freshName("value")
+        (JavaCode.variable(value, BooleanType), code"boolean $value = !$v;")
     }
-    ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
+    ExprCode(code = eval.code + newCode, isNull = FalseLiteral, value = value)
   }
 
   override def sql: String = s"(${child.sql} IS NOT NULL)"
```

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 11 additions & 0 deletions
```diff
@@ -1080,6 +1080,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val CODEGEN_SPLIT_AGGREGATE_FUNC =
+    buildConf("spark.sql.codegen.aggregate.splitAggregateFunc.enabled")
+      .internal()
+      .doc("When true, the code generator would split aggregate code into individual methods " +
+        "instead of a single big method. This can be used to avoid oversized function that " +
+        "can miss the opportunity of JIT optimization.")
+      .booleanConf
+      .createWithDefault(true)
+
   val MAX_NESTED_VIEW_DEPTH =
     buildConf("spark.sql.view.maxNestedViewDepth")
       .internal()
@@ -2353,6 +2362,8 @@ class SQLConf extends Serializable with Logging {
   def cartesianProductExecBufferSpillThreshold: Int =
     getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
 
+  def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC)
+
   def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
 
   def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
```
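The new flag is internal and defaults to `true`, so the split is on by default. To compare against the old single-method codegen, it can be toggled per session; a sketch assuming a `SparkSession` named `spark`:

```scala
// Fall back to the previous single-method aggregate codegen for this session.
spark.conf.set("spark.sql.codegen.aggregate.splitAggregateFunc.enabled", "false")

// Restore the default per-aggregate-function split.
spark.conf.set("spark.sql.codegen.aggregate.splitAggregateFunc.enabled", "true")
```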
