-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-13373] [SQL] generate sort merge join #11248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Generated code for val df1 = sqlContext.range(N).selectExpr(s"id as k1")
val df2 = sqlContext.range(N).selectExpr(s"id as k2")
df1.join(df2, col("k1") === col("k2")).count()/* 001 */
// Entry point of the generated code: wraps the supplied references array in a new
// GeneratedIterator and returns it (typed as Object).
/* 002 */ public Object generate(Object[] references) {
/* 003 */ return new GeneratedIterator(references);
/* 004 */ }
/* 005 */
/* 006 */ /** Codegened pipeline for:
/* 007 */ * TungstenAggregate(key=[], functions=[(count(1),mode=Partial,isDistinct=false)], output=[count#175L])
/* 008 */ +- Project
/* 009 */ +- SortMergeJoin [k1#169L], [k2#171L], None
/* 010 */ :- INPUT
/* 011 */ +- INPUT
/* 012 */ */
/* 013 */ class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 014 */ private Object[] references;
/* 015 */ private boolean agg_initAgg;
/* 016 */ private boolean agg_bufIsNull;
/* 017 */ private long agg_bufValue;
/* 018 */ private scala.collection.Iterator smj_leftInput;
/* 019 */ private scala.collection.Iterator smj_rightInput;
/* 020 */ private InternalRow smj_leftRow;
/* 021 */ private InternalRow smj_rightRow;
/* 022 */ private long smj_value2;
/* 023 */ private java.util.ArrayList smj_matches;
/* 024 */ private long smj_value3;
/* 025 */ private org.apache.spark.sql.execution.metric.LongSQLMetric smj_numOutputRows;
/* 026 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue smj_metricValue;
/* 027 */ private org.apache.spark.sql.execution.metric.LongSQLMetric agg_numOutputRows;
/* 028 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue agg_metricValue;
/* 029 */ private UnsafeRow agg_result;
/* 030 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 031 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 032 */
// Only stores the references array; init() later casts references[0] and references[1]
// to LongSQLMetric.
/* 033 */ public GeneratedIterator(Object[] references) {
/* 034 */ this.references = references;
/* 035 */ }
/* 036 */
// Binds the two join inputs, resets the join/aggregate state, and resolves the metric
// objects from the references array.
/* 037 */ public void init(scala.collection.Iterator inputs[]) {
// Cleared so that processNext() runs the aggregation on its first call.
/* 038 */ agg_initAgg = false;
/* 039 */
// inputs[0] feeds the left side of the join, inputs[1] the right side.
/* 040 */ smj_leftInput = inputs[0];
/* 041 */ smj_rightInput = inputs[1];
/* 042 */
// null means "no pending right row"; findNextInnerJoinRows() fetches a new one when it sees null.
/* 043 */ smj_rightRow = null;
/* 044 */
// Buffer that findNextInnerJoinRows() fills with copies of the right rows matching the current key.
/* 045 */ smj_matches = new java.util.ArrayList();
/* 046 */
// references[0] backs the join's output-row metric, references[1] the aggregate's.
/* 047 */ this.smj_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[0];
/* 048 */ smj_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) smj_numOutputRows.localValue();
/* 049 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[1];
/* 050 */ agg_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) agg_numOutputRows.localValue();
// Result row and the writer (backed by agg_holder) that processNext() uses to fill column 0.
/* 051 */ agg_result = new UnsafeRow(1);
/* 052 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 053 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 054 */ }
/* 055 */
// Drives the join and folds its output into the count: agg_bufValue is incremented once for
// every (left row, matching right row) pair produced by findNextInnerJoinRows().
/* 056 */ private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 057 */ // initialize aggregation buffer
/* 058 */
/* 059 */ agg_bufIsNull = false;
/* 060 */ agg_bufValue = 0L;
/* 061 */
// Each successful call leaves the current left row in smj_leftRow and its matching right
// rows in smj_matches.
/* 062 */ while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) {
/* 063 */ /* input[0, bigint] */
// The left column is loaded once, outside the per-match loop. smj_value4 is not referenced
// again in this method.
/* 064 */ boolean smj_isNull2 = smj_leftRow.isNullAt(0);
/* 065 */ long smj_value4 = smj_isNull2 ? -1L : (smj_leftRow.getLong(0));
/* 066 */ int smj_size = smj_matches.size();
/* 067 */ for (int smj_i = 0; smj_i < smj_size; smj_i ++) {
/* 068 */ InternalRow smj_rightRow1 = (InternalRow) smj_matches.get(smj_i);
/* 069 */ /* input[0, bigint] */
// smj_value5 is likewise loaded but not referenced again in this method.
/* 070 */ boolean smj_isNull3 = smj_rightRow1.isNullAt(0);
/* 071 */ long smj_value5 = smj_isNull3 ? -1L : (smj_rightRow1.getLong(0));
/* 072 */
// Constant condition that is always false, so no pair is ever skipped here.
// NOTE(review): presumably the slot for the extra join condition (the plan shows None) - confirm.
/* 073 */ if (false || !true) continue;
// One joined output row per (left, right) pair.
/* 074 */ smj_metricValue.add(1);
/* 075 */
/* 076 */ // do aggregate
/* 077 */ /* (input[0, bigint] + 1) */
/* 078 */ long agg_value1 = -1L;
/* 079 */ agg_value1 = agg_bufValue + 1L;
/* 080 */ // update aggregation buffer
/* 081 */ agg_bufIsNull = false;
/* 082 */ agg_bufValue = agg_value1;
/* 083 */
/* 084 */ }
// Checked once per left row, after all of its matches were consumed.
// NOTE(review): shouldStop() is not defined in this dump; presumably inherited from
// BufferedRowIterator - confirm.
/* 085 */ if (shouldStop()) return;
/* 086 */ }
/* 087 */
/* 088 */ }
/* 089 */
// Advances the merge to the next left row that has at least one matching right row.
// Returns true with that row in smj_leftRow and copies of its matching right rows in
// smj_matches; returns false when the left input is exhausted, or when the right input is
// exhausted and nothing matched the current left row.
// State kept across calls: smj_rightRow (pending, not yet consumed right row), smj_value2
// (key of that pending row), smj_value3 (key recorded when this method returns with
// buffered matches).
/* 090 */ private boolean findNextInnerJoinRows(
/* 091 */ scala.collection.Iterator leftIter,
/* 092 */ scala.collection.Iterator rightIter) {
/* 093 */ smj_leftRow = null;
/* 094 */ int comp = 0;
// Outer loop: keep pulling left rows until one is accepted (or a return fires).
/* 095 */ while (smj_leftRow == null) {
/* 096 */ if (!leftIter.hasNext()) return false;
/* 097 */ smj_leftRow = (InternalRow) leftIter.next();
/* 098 */ /* input[0, bigint] */
/* 099 */ boolean smj_isNull = smj_leftRow.isNullAt(0);
/* 100 */ long smj_value = smj_isNull ? -1L : (smj_leftRow.getLong(0));
// A left row with a null key is dropped and the next left row is fetched.
/* 101 */ if (smj_isNull) {
/* 102 */ smj_leftRow = null;
/* 103 */ continue;
/* 104 */ }
// Matches are still buffered from the previous call: if this left row has the same key
// (smj_value3), they are reused as-is; otherwise they are discarded.
/* 105 */ if (!smj_matches.isEmpty()) {
// NOTE(review): the "comp = 0; if (comp == 0)" shape looks like a template that chains one
// comparison per join key; only one key exists here - confirm against the generator.
/* 106 */ comp = 0;
/* 107 */ if (comp == 0) {
/* 108 */ comp = (smj_value > smj_value3 ? 1 : smj_value < smj_value3 ? -1 : 0);
/* 109 */ }
/* 110 */
/* 111 */ if (comp == 0) {
/* 112 */ return true;
/* 113 */ }
/* 114 */ smj_matches.clear();
/* 115 */ }
/* 116 */
// Inner loop: walk the right side while the current left row is still a candidate.
/* 117 */ do {
/* 118 */ if (smj_rightRow == null) {
// Right side exhausted: remember the key and report whether anything was buffered.
/* 119 */ if (!rightIter.hasNext()) {
/* 120 */ smj_value3 = smj_value;
/* 121 */
/* 122 */ return !smj_matches.isEmpty();
/* 123 */ }
/* 124 */ smj_rightRow = (InternalRow) rightIter.next();
/* 125 */ /* input[0, bigint] */
/* 126 */ boolean smj_isNull1 = smj_rightRow.isNullAt(0);
/* 127 */ long smj_value1 = smj_isNull1 ? -1L : (smj_rightRow.getLong(0));
// A right row with a null key is dropped; "continue" jumps to the do-while condition, which
// still holds because smj_leftRow is non-null, so the next right row is fetched.
/* 128 */ if (smj_isNull1) {
/* 129 */ smj_rightRow = null;
/* 130 */ continue;
/* 131 */ }
/* 132 */
// Cache the pending right row's key so it can be compared again on a later call.
/* 133 */ smj_value2 = smj_value1;
/* 134 */
/* 135 */ }
/* 136 */
/* 137 */ comp = 0;
/* 138 */ if (comp == 0) {
/* 139 */ comp = (smj_value > smj_value2 ? 1 : smj_value < smj_value2 ? -1 : 0);
/* 140 */ }
/* 141 */
// Left key is larger: discard the pending right row so the next one is fetched.
/* 142 */ if (comp > 0) {
/* 143 */ smj_rightRow = null;
// Left key is smaller: the pending right row stays buffered in smj_rightRow.
/* 144 */ } else if (comp < 0) {
// Matches were collected for this left row, so it is ready to be emitted.
/* 145 */ if (!smj_matches.isEmpty()) {
/* 146 */ smj_value3 = smj_value;
/* 147 */
/* 148 */ return true;
/* 149 */ }
// No match for this left row: drop it, which ends the inner loop and re-enters the outer one.
/* 150 */ smj_leftRow = null;
// Keys are equal: buffer a copy of the right row and move on to the next right row.
/* 151 */ } else {
/* 152 */ smj_matches.add(smj_rightRow.copy());
/* 153 */ smj_rightRow = null;;
/* 154 */ }
/* 155 */ } while (smj_leftRow != null);
/* 156 */ }
/* 157 */ return false; // unreachable
/* 158 */ }
/* 159 */
// Runs the whole join + count exactly once (guarded by agg_initAgg) and then emits a single
// result row holding the count; later calls do nothing.
/* 160 */ protected void processNext() throws java.io.IOException {
/* 161 */ if (!agg_initAgg) {
/* 162 */ agg_initAgg = true;
/* 163 */ agg_doAggregateWithoutKey();
/* 164 */
/* 165 */ // output the result
/* 166 */
// One output row for the aggregate.
/* 167 */ agg_metricValue.add(1);
/* 168 */ agg_rowWriter.zeroOutNullBytes();
/* 169 */
// agg_doAggregateWithoutKey() only ever assigns false to agg_bufIsNull, so the null branch
// is not taken for this plan.
/* 170 */ if (agg_bufIsNull) {
/* 171 */ agg_rowWriter.setNullAt(0);
/* 172 */ } else {
/* 173 */ agg_rowWriter.write(0, agg_bufValue);
/* 174 */ }
// A copy of agg_result is handed to append().
// NOTE(review): append() is not defined in this dump; presumably inherited from
// BufferedRowIterator - confirm.
/* 175 */ append(agg_result.copy());
/* 176 */ }
/* 177 */ }
/* 178 */ } |
|
Test build #51471 has finished for PR 11248 at commit
|
|
Test build #51477 has finished for PR 11248 at commit
|
|
Test build #51478 has finished for PR 11248 at commit
|
|
Shouldn't we add a more realistic benchmark then? |
|
@hvanhovell I think the purpose of a micro benchmark is to show the baseline cost of each operator. The results will be very different from case to case, so it's hard to say which one is more realistic. |
|
Here is part of the code generated from Q72 /* 160 */ private boolean findNextInnerJoinRows(
// Advances the merge to the next left row that has at least one matching right row.
// Returns true with that row in smj_leftRow and copies of its matching right rows in
// smj_matches; returns false when the left input is exhausted, or when the right input is
// exhausted and nothing matched the current left row.
// In this excerpt the join key is int column 7 of the left row and int column 1 of the right row.
/* 161 */ scala.collection.Iterator leftIter,
/* 162 */ scala.collection.Iterator rightIter) {
/* 163 */ smj_leftRow = null;
/* 164 */ int comp = 0;
// Outer loop: keep pulling left rows until one is accepted (or a return fires).
/* 165 */ while (smj_leftRow == null) {
/* 166 */ if (!leftIter.hasNext()) return false;
/* 167 */ smj_leftRow = (InternalRow) leftIter.next();
/* 168 */ /* input[7, int] */
/* 169 */ boolean smj_isNull = smj_leftRow.isNullAt(7);
/* 170 */ int smj_value = smj_isNull ? -1 : (smj_leftRow.getInt(7));
// A left row with a null key is dropped and the next left row is fetched.
/* 171 */ if (smj_isNull) {
/* 172 */ smj_leftRow = null;
/* 173 */ continue;
/* 174 */ }
// Matches are still buffered from the previous call: reuse them if this left row has the
// same key (smj_value3), otherwise discard them.
/* 175 */ if (!smj_matches.isEmpty()) {
/* 176 */ comp = 0;
/* 177 */ if (comp == 0) {
/* 178 */ comp = (smj_value > smj_value3 ? 1 : smj_value < smj_value3 ? -1 : 0);
/* 179 */ }
/* 180 */
/* 181 */ if (comp == 0) {
/* 182 */ return true;
/* 183 */ }
/* 184 */ smj_matches.clear();
/* 185 */ }
/* 186 */
// Inner loop: walk the right side while the current left row is still a candidate.
/* 187 */ do {
/* 188 */ if (smj_rightRow == null) {
// Right side exhausted: remember the key and report whether anything was buffered.
/* 189 */ if (!rightIter.hasNext()) {
/* 190 */ smj_value3 = smj_value;
/* 191 */
/* 192 */ return !smj_matches.isEmpty();
/* 193 */ }
/* 194 */ smj_rightRow = (InternalRow) rightIter.next();
/* 195 */ /* input[1, int] */
/* 196 */ boolean smj_isNull1 = smj_rightRow.isNullAt(1);
/* 197 */ int smj_value1 = smj_isNull1 ? -1 : (smj_rightRow.getInt(1));
// A right row with a null key is dropped; "continue" jumps to the do-while condition, which
// still holds because smj_leftRow is non-null, so the next right row is fetched.
/* 198 */ if (smj_isNull1) {
/* 199 */ smj_rightRow = null;
/* 200 */ continue;
/* 201 */ }
/* 202 */
// Cache the pending right row's key so it can be compared again on a later call.
/* 203 */ smj_value2 = smj_value1;
/* 204 */
/* 205 */ }
/* 206 */
/* 207 */ comp = 0;
/* 208 */ if (comp == 0) {
/* 209 */ comp = (smj_value > smj_value2 ? 1 : smj_value < smj_value2 ? -1 : 0);
/* 210 */ }
/* 211 */
// Left key is larger: discard the pending right row so the next one is fetched.
/* 212 */ if (comp > 0) {
/* 213 */ smj_rightRow = null;
// Left key is smaller: the pending right row stays buffered in smj_rightRow.
/* 214 */ } else if (comp < 0) {
// Matches were collected for this left row, so it is ready to be emitted.
/* 215 */ if (!smj_matches.isEmpty()) {
/* 216 */ smj_value3 = smj_value;
/* 217 */
/* 218 */ return true;
/* 219 */ }
// No match for this left row: drop it, which ends the inner loop and re-enters the outer one.
/* 220 */ smj_leftRow = null;
// Keys are equal: buffer a copy of the right row and move on to the next right row.
/* 221 */ } else {
/* 222 */ smj_matches.add(smj_rightRow.copy());
/* 223 */ smj_rightRow = null;;
/* 224 */ }
/* 225 */ } while (smj_leftRow != null);
/* 226 */ }
/* 227 */ return false; // unreachable
/* 228 */ }
/* 229 */
/* 230 */ protected void processNext() throws java.io.IOException {
/* 231 */ while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) {
/* 232 */ int smj_size = smj_matches.size();
/* 233 */ boolean smj_loaded = false;
/* 234 */
/* 235 */ smj_isNull7 = smj_leftRow.isNullAt(5);
/* 236 */ smj_value9 = smj_isNull7 ? -1 : (smj_leftRow.getInt(5));
/* 237 */
/* 238 */ for (int smj_i = 0; smj_i < smj_size; smj_i ++) {
/* 239 */ InternalRow smj_rightRow1 = (InternalRow) smj_matches.get(smj_i);
/* 240 */ /* input[3, int] */
/* 241 */ boolean smj_isNull13 = smj_rightRow1.isNullAt(3);
/* 242 */ int smj_value15 = smj_isNull13 ? -1 : (smj_rightRow1.getInt(3));
/* 243 */ /* (input[11, int] < input[5, int]) */
/* 244 */ boolean smj_isNull14 = true;
/* 245 */ boolean smj_value16 = false;
/* 246 */
/* 247 */ if (!smj_isNull13) {
/* 248 */ if (!smj_isNull7) {
/* 249 */ smj_isNull14 = false; // resultCode could change nullability.
/* 250 */ smj_value16 = smj_value15 < smj_value9;
/* 251 */
/* 252 */ }
/* 253 */
/* 254 */ }
/* 255 */ if (smj_isNull14 || !smj_value16) continue;
/* 256 */ if (!smj_loaded) {
/* 257 */ smj_loaded = true;
/* 258 */
/* 259 */ smj_isNull2 = smj_leftRow.isNullAt(0);
/* 260 */ smj_value4 = smj_isNull2 ? -1 : (smj_leftRow.getInt(0));
/* 261 */
/* 262 */ smj_isNull3 = smj_leftRow.isNullAt(1);
/* 263 */ smj_value5 = smj_isNull3 ? -1 : (smj_leftRow.getInt(1));
/* 264 */
/* 265 */ smj_isNull4 = smj_leftRow.isNullAt(2);
/* 266 */ smj_value6 = smj_isNull4 ? -1 : (smj_leftRow.getInt(2));
/* 267 */
/* 268 */ smj_isNull5 = smj_leftRow.isNullAt(3);
/* 269 */ smj_value7 = smj_isNull5 ? -1 : (smj_leftRow.getInt(3));
/* 270 */
/* 271 */ smj_isNull6 = smj_leftRow.isNullAt(4);
/* 272 */ smj_value8 = smj_isNull6 ? -1 : (smj_leftRow.getInt(4));
/* 273 */
/* 274 */ smj_isNull8 = smj_leftRow.isNullAt(6);
/* 275 */ smj_value10 = smj_isNull8 ? -1 : (smj_leftRow.getInt(6));
/* 276 */
/* 277 */ smj_isNull9 = smj_leftRow.isNullAt(7);
/* 278 */ smj_value11 = smj_isNull9 ? -1 : (smj_leftRow.getInt(7));
/* 279 */
/* 280 */ }
/* 281 */ /* input[0, int] */
/* 282 */ boolean smj_isNull10 = smj_rightRow1.isNullAt(0);
/* 283 */ int smj_value12 = smj_isNull10 ? -1 : (smj_rightRow1.getInt(0));
/* 284 */ /* input[1, int] */
/* 285 */ boolean smj_isNull11 = smj_rightRow1.isNullAt(1);
/* 286 */ int smj_value13 = smj_isNull11 ? -1 : (smj_rightRow1.getInt(1));
/* 287 */ /* input[2, int] */
/* 288 */ boolean smj_isNull12 = smj_rightRow1.isNullAt(2);
/* 289 */ int smj_value14 = smj_isNull12 ? -1 : (smj_rightRow1.getInt(2));
/* 290 */ smj_metricValue.add(1); |
|
Test build #2544 has finished for PR 11248 at commit
|
|
Test build #51490 has finished for PR 11248 at commit
|
|
Test build #51495 has finished for PR 11248 at commit
|
|
Test build #2545 has finished for PR 11248 at commit
|
|
@davies this is linked to the wrong JIRA |
Conflicts: sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
|
Test build #51607 has finished for PR 11248 at commit
|
|
|
||
| override def doProduce(ctx: CodegenContext): String = { | ||
| val input = ctx.freshName("input") | ||
| // Right now, Range is only used when there is one upstream. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Range -> InputAdapter?
|
Does it make sense to do something like this and remove the copies? |
|
@nongli We already avoid the copy until a match is found. In order to support N left rows matching M right rows, we have to copy some of them. Right now, we copy the keys from the left and the rows from the right. It's possible to make fewer copies, but I'd like to leave that out of this PR, because SMJ usually requires Exchange and Sort, which are more expensive than copying the matched rows from the right. |
|
Test build #51771 has finished for PR 11248 at commit
|
| class BenchmarkWholeStageCodegen extends SparkFunSuite { | ||
| lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") | ||
| .set("spark.sql.shuffle.partitions", "1") | ||
| .set("spark.sql.autoBroadcastJoinThreshold", "0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be set inside the sort merge join benchmark?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will use BroadcastHint when testing with broadcast joins, so it's better to disable it for all of them.
|
LGTM |
|
Merging this into master, thanks! |
What changes were proposed in this pull request?
Generates code for SortMergeJoin.
How was this patch tested?
Unit tests, plus manual testing with TPCDS Q72, which showed a 70% performance improvement (from 42s to 25s). The micro benchmark only shows minor improvements; the gain may depend on the distribution of the data and the number of columns.