[SPARK-21870][SQL] Split aggregation code into small functions #19082
Conversation
I checked the TPCDS performance;
Test build #81235 has finished for PR 19082 at commit
Jenkins, retest this please.
Test build #81248 has finished for PR 19082 at commit
Aggregation is not the only source of overly long generated code; any sufficiently large query under the current whole-stage codegen can produce generated code that is too long. Rather than a special-case split for aggregation code like this, I'd prefer to push #18931, which is a more general fix.
I haven't looked into the code yet, but could #18931 also solve a query like the one below?
What does it mean to solve the query? Is the issue that the generated code is too long?
Yes, and does #18931 cover all the cases this PR addresses? (I feel these two PRs solve different too-long generated code issues.)
After quickly scanning this change, I found that this and #18931 are actually orthogonal. This PR splits the aggregation expressions into functions, so it prevents long aggregation code caused by long generated expression code. In contrast, #18931 breaks up the long generated code caused by chains of execution operators, by splitting each operator's generated code into functions.
Yeah, I'm with you; it seems they are orthogonal. But if we had a more general way to solve both, that would be best.
Since `$doAggVal` is a non-static method, this number should be 255.
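For background: the JVM caps a method at 255 parameter slots, where the implicit `this` of an instance (non-static) method takes one slot and `long`/`double` parameters take two slots each (JVMS §4.3.3). A hypothetical sketch of that bound, with illustrative names not taken from the PR:

```java
// Hypothetical helper illustrating the JVM's 255-parameter-slot limit
// (JVMS §4.3.3). Names are illustrative, not from Spark's codegen.
public class ParamSlotCheck {
    // Returns true if a method with the given parameter mix fits the limit.
    static boolean fitsJvmParamLimit(int longOrDoubleParams, int otherParams, boolean isStatic) {
        int slots = (isStatic ? 0 : 1)      // implicit `this` for instance methods
                  + 2 * longOrDoubleParams  // long/double occupy two slots each
                  + otherParams;            // everything else occupies one slot
        return slots <= 255;
    }
}
```

This is why an instance method like the generated `doAggVal` can take fewer declared parameters than a static one before hitting the limit.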
Oh, good catch! I'll fix it.
Can we add a check for the case where a user specifies a value greater than 255?
This is a test-only option, so I don't think we need to.
`OuterReference` actually has a special meaning in correlated subqueries, so this name can be confusing.
OK, I'll rename it.
rednaxelafx left a comment:
Nice! I like the direction this PR is going. Just some nits in inline comments:
I know the regular expression is tempting, but there's actually a better way to do this along the lines of your idea, within the current framework.
I've got a piece of code sitting in my own workspace that checks for Java identifiers:
```scala
object CodegenContext {
  private val javaKeywords = Set(
    "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class",
    "const", "continue", "default", "do", "double", "else", "extends", "false", "final",
    "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
    "interface", "long", "native", "new", "null", "package", "private", "protected", "public",
    "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this",
    "throw", "throws", "transient", "true", "try", "void", "volatile", "while"
  )

  def isJavaIdentifier(str: String): Boolean = str match {
    case null | "" => false
    case _ => java.lang.Character.isJavaIdentifierStart(str.charAt(0)) &&
      (1 until str.length).forall(
        i => java.lang.Character.isJavaIdentifierPart(str.charAt(i))) &&
      !javaKeywords.contains(str)
  }
}
```

Feel free to use it here if you'd like. This is the way `java.lang.Character.isJavaIdentifierStart()` and `java.lang.Character.isJavaIdentifierPart()` are supposed to be used anyway; nothing creative.
If you want to use it the way you're using the regular expression, just wrap the util above in an `unapply()`. But I'd say simply defining `def isVariable(nameId: String) = CodegenContext.isJavaIdentifier(nameId)` is clean enough.
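For illustration, here is a Java rendering of the same identifier check (a sketch; the actual Spark helper suggested above is Scala). It rejects null or empty strings, Java keywords and literals, and anything that is not a well-formed Java identifier:

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Java sketch of the identifier check suggested above (Spark's helper is Scala).
public class JavaIdentifiers {
    private static final Set<String> JAVA_KEYWORDS = new HashSet<>(Arrays.asList(
        "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class",
        "const", "continue", "default", "do", "double", "else", "extends", "false", "final",
        "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
        "interface", "long", "native", "new", "null", "package", "private", "protected",
        "public", "return", "short", "static", "strictfp", "super", "switch", "synchronized",
        "this", "throw", "throws", "transient", "true", "try", "void", "volatile", "while"));

    static boolean isJavaIdentifier(String s) {
        if (s == null || s.isEmpty() || JAVA_KEYWORDS.contains(s)) return false;
        if (!Character.isJavaIdentifierStart(s.charAt(0))) return false;
        for (int i = 1; i < s.length(); i++) {
            if (!Character.isJavaIdentifierPart(s.charAt(i))) return false;
        }
        return true;
    }
}
```

For example, a generated variable name like `agg_value_0` passes, while `class` or `1abc` is rejected.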
Good suggestion; I was also looking for a better approach. I'll fix it.
Hmm. Just a cosmetic style comment: I would have declared `addIfNotLiteral` with a `def` instead of making it a `scala.Function2[String, String, Unit]`.
BTW, can we add a comment to `val argSet` explaining what the two fields of the `Tuple2[String, String]` mean? And then also make `addIfNotLiteral` take its arguments in the same order as the tuple.
Force-pushed from e432136 to 57c8280.
Test build #81330 has finished for PR 19082 at commit
Test build #81331 has finished for PR 19082 at commit
Test build #81329 has finished for PR 19082 at commit
Test build #81328 has finished for PR 19082 at commit
Test build #81332 has finished for PR 19082 at commit
Test build #81351 has finished for PR 19082 at commit
Jenkins, retest this please.
Test build #81354 has finished for PR 19082 at commit
@viirya @rednaxelafx @kiszk OK, could you check again?
If line 314 uses `<=`, this should be 254; in the previous commit, `<` was used.
Test build #81400 has finished for PR 19082 at commit
Test build #84717 has finished for PR 19082 at commit
Jenkins, retest this please.
Test build #84724 has finished for PR 19082 at commit
Just FYI, the test failure is not related to this PR.
```scala
ctx.currentVars = bufVars ++ input
// We need to copy the aggregation buffer to local variables first because each aggregate
// function directly updates the buffer when it finishes.
```
Just FYI: we do need the local copies, per this discussion too: #19865
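To illustrate the hazard behind that comment, here is a minimal sketch with hypothetical names (not Spark's actual generated code): each split aggregate function writes its result straight into a shared buffer field when it finishes, so a later function that reads the buffer directly would observe the already-updated value. Copying the buffer to locals first keeps every function working on the row's original buffer values:

```java
// Minimal sketch of why the generated code copies the aggregation buffer
// to local variables first. Field and method names are hypothetical.
public class BufferCopySketch {
    static long buf0 = 10L;  // shared aggregation buffer slot read by both functions

    // Each split function updates the buffer when it finishes.
    static void doAggA(long local0) { buf0 = local0 + 1; }
    static long doAggB(long local0) { return local0 * 2; }

    // Correct shape: copy the buffer to a local, then call the functions.
    static long consumeWithLocalCopy() {
        buf0 = 10L;
        long local0 = buf0;      // local copy taken before any update
        doAggA(local0);
        return doAggB(local0);   // still sees the original value
    }

    // Buggy shape: reading the buffer directly between updates.
    static long consumeWithoutCopy() {
        buf0 = 10L;
        doAggA(buf0);            // buffer is mutated here
        return doAggB(buf0);     // sees the already-updated value
    }
}
```

With the local copy the second function computes from the original buffer value (10), while the direct-read version computes from the updated one (11), which is exactly the corruption the comment guards against.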
Test build #84734 has finished for PR 19082 at commit
Retest this please.
Test build #84770 has finished for PR 19082 at commit
Test build #84774 has finished for PR 19082 at commit
Test build #84807 has finished for PR 19082 at commit
Since #19813 fixes this issue, I'll close this one. Thanks!
@gatorsmile can you close the JIRA ticket with @viirya assigned? Thanks!
```scala
val updateRowInRegularHashMap: String = {
  ctx.INPUT_ROW = unsafeRowBuffer
  // We need to copy the aggregation row buffer to a local row first because each aggregate
  // function directly updates the buffer when it finishes.
```
Why does this matter? We should avoid unnecessary data copies as much as possible.
We need this copy; see #19082 (comment).
Sorry for missing the excellent discussions here. After more thought, I feel this patch still makes sense. Currently we have 3 cases that may cause large methods with whole-stage codegen:
All of them are orthogonal but non-trivial; I don't see how any one of them can cover all the other cases.
rednaxelafx left a comment:
Rebasing this PR to the latest master would be interestingly involved...
```scala
stack.pop() match {
  case e if subExprs.contains(e) =>
    val exprCode = subExprs(e)
    if (CodegenContext.isJavaIdentifier(exprCode.value)) {
```
hey, good news! Thanks for letting me know ;)
I tracked this down from the JIRA ticket SPARK-23791, which contains a test case. Test results of 3 rounds:
Yeah, this issue is not fixed yet. As you know, it's a long story...
Thanks for your reply.
## What changes were proposed in this pull request?

This PR proposes to split aggregation code into small functions in `HashAggregateExec`. In #18810, we got a performance regression when JVMs didn't compile too-long functions. I checked and found that the codegen of `HashAggregateExec` frequently goes over the limit when a query has many aggregate functions (e.g., q66 in TPCDS). The current master places all the generated aggregation code in a single function. In this PR, I modified the code to assign an individual function to each aggregate function (e.g., `SUM` and `AVG`). For example, for the query `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, the proposed code defines two functions for `SUM(a)` and `AVG(a)` as follows:

- generated code with this PR (https://gist.github.com/maropu/812990012bc967a78364be0fa793f559):

```java
/* 173 */   private void agg_doConsume_0(InternalRow inputadapter_row_0, long agg_expr_0_0, boolean agg_exprIsNull_0_0, double agg_expr_1_0, boolean agg_exprIsNull_1_0, long agg_expr_2_0, boolean agg_exprIsNull_2_0) throws java.io.IOException {
/* 174 */     // do aggregate
/* 175 */     // common sub-expressions
/* 176 */
/* 177 */     // evaluate aggregate functions and update aggregation buffers
/* 178 */     agg_doAggregate_sum_0(agg_exprIsNull_0_0, agg_expr_0_0);
/* 179 */     agg_doAggregate_avg_0(agg_expr_1_0, agg_exprIsNull_1_0, agg_exprIsNull_2_0, agg_expr_2_0);
/* 180 */
/* 181 */   }
...
/* 071 */   private void agg_doAggregate_avg_0(double agg_expr_1_0, boolean agg_exprIsNull_1_0, boolean agg_exprIsNull_2_0, long agg_expr_2_0) throws java.io.IOException {
/* 072 */     // do aggregate for avg
/* 073 */     // evaluate aggregate function
/* 074 */     boolean agg_isNull_19 = true;
/* 075 */     double agg_value_19 = -1.0;
...
/* 114 */   private void agg_doAggregate_sum_0(boolean agg_exprIsNull_0_0, long agg_expr_0_0) throws java.io.IOException {
/* 115 */     // do aggregate for sum
/* 116 */     // evaluate aggregate function
/* 117 */     agg_agg_isNull_11_0 = true;
/* 118 */     long agg_value_11 = -1L;
...
```

- generated code in the current master (https://gist.github.com/maropu/e9d772af2c98d8991a6a5f0af7841760):

```java
/* 059 */   private void agg_doConsume_0(InternalRow localtablescan_row_0, int agg_expr_0_0) throws java.io.IOException {
/* 060 */     // do aggregate
/* 061 */     // common sub-expressions
/* 062 */     boolean agg_isNull_4 = false;
/* 063 */     long agg_value_4 = -1L;
/* 064 */     if (!false) {
/* 065 */       agg_value_4 = (long) agg_expr_0_0;
/* 066 */     }
/* 067 */     // evaluate aggregate function
/* 068 */     agg_agg_isNull_7_0 = true;
/* 069 */     long agg_value_7 = -1L;
/* 070 */     do {
/* 071 */       if (!agg_bufIsNull_0) {
/* 072 */         agg_agg_isNull_7_0 = false;
/* 073 */         agg_value_7 = agg_bufValue_0;
/* 074 */         continue;
/* 075 */       }
/* 076 */
/* 077 */       boolean agg_isNull_9 = false;
/* 078 */       long agg_value_9 = -1L;
/* 079 */       if (!false) {
/* 080 */         agg_value_9 = (long) 0;
/* 081 */       }
/* 082 */       if (!agg_isNull_9) {
/* 083 */         agg_agg_isNull_7_0 = false;
/* 084 */         agg_value_7 = agg_value_9;
/* 085 */         continue;
/* 086 */       }
/* 087 */
/* 088 */     } while (false);
/* 089 */
/* 090 */     long agg_value_6 = -1L;
/* 091 */
/* 092 */     agg_value_6 = agg_value_7 + agg_value_4;
/* 093 */     boolean agg_isNull_11 = true;
/* 094 */     double agg_value_11 = -1.0;
/* 095 */
/* 096 */     if (!agg_bufIsNull_1) {
/* 097 */       agg_agg_isNull_13_0 = true;
/* 098 */       double agg_value_13 = -1.0;
/* 099 */       do {
/* 100 */         boolean agg_isNull_14 = agg_isNull_4;
/* 101 */         double agg_value_14 = -1.0;
/* 102 */         if (!agg_isNull_4) {
/* 103 */           agg_value_14 = (double) agg_value_4;
/* 104 */         }
/* 105 */         if (!agg_isNull_14) {
/* 106 */           agg_agg_isNull_13_0 = false;
/* 107 */           agg_value_13 = agg_value_14;
/* 108 */           continue;
/* 109 */         }
/* 110 */
/* 111 */         boolean agg_isNull_15 = false;
/* 112 */         double agg_value_15 = -1.0;
/* 113 */         if (!false) {
/* 114 */           agg_value_15 = (double) 0;
/* 115 */         }
/* 116 */         if (!agg_isNull_15) {
/* 117 */           agg_agg_isNull_13_0 = false;
/* 118 */           agg_value_13 = agg_value_15;
/* 119 */           continue;
/* 120 */         }
/* 121 */
/* 122 */       } while (false);
/* 123 */
/* 124 */       agg_isNull_11 = false; // resultCode could change nullability.
/* 125 */
/* 126 */       agg_value_11 = agg_bufValue_1 + agg_value_13;
/* 127 */
/* 128 */     }
/* 129 */     boolean agg_isNull_17 = false;
/* 130 */     long agg_value_17 = -1L;
/* 131 */     if (!false && agg_isNull_4) {
/* 132 */       agg_isNull_17 = agg_bufIsNull_2;
/* 133 */       agg_value_17 = agg_bufValue_2;
/* 134 */     } else {
/* 135 */       boolean agg_isNull_20 = true;
/* 136 */       long agg_value_20 = -1L;
/* 137 */
/* 138 */       if (!agg_bufIsNull_2) {
/* 139 */         agg_isNull_20 = false; // resultCode could change nullability.
/* 140 */
/* 141 */         agg_value_20 = agg_bufValue_2 + 1L;
/* 142 */
/* 143 */       }
/* 144 */       agg_isNull_17 = agg_isNull_20;
/* 145 */       agg_value_17 = agg_value_20;
/* 146 */     }
/* 147 */     // update aggregation buffer
/* 148 */     agg_bufIsNull_0 = false;
/* 149 */     agg_bufValue_0 = agg_value_6;
/* 150 */
/* 151 */     agg_bufIsNull_1 = agg_isNull_11;
/* 152 */     agg_bufValue_1 = agg_value_11;
/* 153 */
/* 154 */     agg_bufIsNull_2 = agg_isNull_17;
/* 155 */     agg_bufValue_2 = agg_value_17;
/* 156 */
/* 157 */   }
```

You can check the previous discussion in #19082.

## How was this patch tested?

Existing tests.

Closes #20965 from maropu/SPARK-21870-2.

Authored-by: Takeshi Yamamuro <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
What changes were proposed in this pull request?
This PR proposes to split aggregation code into pieces in `HashAggregateExec`. In #18810, we got a performance regression when JVMs didn't compile too-long functions (the limit is 8000 in bytecode size). I checked and found that the codegen of `HashAggregateExec` frequently goes over the limit; for example, this query goes over the limit, and its actual bytecode size is 12356. This PR splits the aggregation code into small separate functions, and in a simple example;
I also did performance checks;
How was this patch tested?
Added tests in `WholeStageCodegenSuite`.