davies commented Feb 18, 2016

What changes were proposed in this pull request?

Generates code for SortMergeJoin.

How was this patch tested?

Unit tests, plus a manual run of TPCDS Q72, which improved by about 70% (from 42s to 25s, roughly 1.7x faster). The micro benchmark shows only minor improvements; the gain may depend on the data distribution and the number of columns.

davies commented Feb 18, 2016

Generated code for the following query:

      val df1 = sqlContext.range(N).selectExpr(s"id as k1")
      val df2 = sqlContext.range(N).selectExpr(s"id as k2")
      df1.join(df2, col("k1") === col("k2")).count()
/* 001 */
/* 002 */ public Object generate(Object[] references) {
/* 003 */   return new GeneratedIterator(references);
/* 004 */ }
/* 005 */
/* 006 */ /** Codegened pipeline for:
/* 007 */ * TungstenAggregate(key=[], functions=[(count(1),mode=Partial,isDistinct=false)], output=[count#175L])
/* 008 */ +- Project
/* 009 */ +- SortMergeJoin [k1#169L], [k2#171L], None
/* 010 */ :- INPUT
/* 011 */ +- INPUT
/* 012 */ */
/* 013 */ class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 014 */   private Object[] references;
/* 015 */   private boolean agg_initAgg;
/* 016 */   private boolean agg_bufIsNull;
/* 017 */   private long agg_bufValue;
/* 018 */   private scala.collection.Iterator smj_leftInput;
/* 019 */   private scala.collection.Iterator smj_rightInput;
/* 020 */   private InternalRow smj_leftRow;
/* 021 */   private InternalRow smj_rightRow;
/* 022 */   private long smj_value2;
/* 023 */   private java.util.ArrayList smj_matches;
/* 024 */   private long smj_value3;
/* 025 */   private org.apache.spark.sql.execution.metric.LongSQLMetric smj_numOutputRows;
/* 026 */   private org.apache.spark.sql.execution.metric.LongSQLMetricValue smj_metricValue;
/* 027 */   private org.apache.spark.sql.execution.metric.LongSQLMetric agg_numOutputRows;
/* 028 */   private org.apache.spark.sql.execution.metric.LongSQLMetricValue agg_metricValue;
/* 029 */   private UnsafeRow agg_result;
/* 030 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 031 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 032 */
/* 033 */   public GeneratedIterator(Object[] references) {
/* 034 */     this.references = references;
/* 035 */   }
/* 036 */
/* 037 */   public void init(scala.collection.Iterator inputs[]) {
/* 038 */     agg_initAgg = false;
/* 039 */
/* 040 */     smj_leftInput = inputs[0];
/* 041 */     smj_rightInput = inputs[1];
/* 042 */
/* 043 */     smj_rightRow = null;
/* 044 */
/* 045 */     smj_matches = new java.util.ArrayList();
/* 046 */
/* 047 */     this.smj_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[0];
/* 048 */     smj_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) smj_numOutputRows.localValue();
/* 049 */     this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[1];
/* 050 */     agg_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) agg_numOutputRows.localValue();
/* 051 */     agg_result = new UnsafeRow(1);
/* 052 */     this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 053 */     this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 054 */   }
/* 055 */
/* 056 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 057 */     // initialize aggregation buffer
/* 058 */
/* 059 */     agg_bufIsNull = false;
/* 060 */     agg_bufValue = 0L;
/* 061 */
/* 062 */     while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) {
/* 063 */       /* input[0, bigint] */
/* 064 */       boolean smj_isNull2 = smj_leftRow.isNullAt(0);
/* 065 */       long smj_value4 = smj_isNull2 ? -1L : (smj_leftRow.getLong(0));
/* 066 */       int smj_size = smj_matches.size();
/* 067 */       for (int smj_i = 0; smj_i < smj_size; smj_i ++) {
/* 068 */         InternalRow smj_rightRow1 = (InternalRow) smj_matches.get(smj_i);
/* 069 */         /* input[0, bigint] */
/* 070 */         boolean smj_isNull3 = smj_rightRow1.isNullAt(0);
/* 071 */         long smj_value5 = smj_isNull3 ? -1L : (smj_rightRow1.getLong(0));
/* 072 */
/* 073 */         if (false || !true) continue;
/* 074 */         smj_metricValue.add(1);
/* 075 */
/* 076 */         // do aggregate
/* 077 */         /* (input[0, bigint] + 1) */
/* 078 */         long agg_value1 = -1L;
/* 079 */         agg_value1 = agg_bufValue + 1L;
/* 080 */         // update aggregation buffer
/* 081 */         agg_bufIsNull = false;
/* 082 */         agg_bufValue = agg_value1;
/* 083 */
/* 084 */       }
/* 085 */       if (shouldStop()) return;
/* 086 */     }
/* 087 */
/* 088 */   }
/* 089 */
/* 090 */   private boolean findNextInnerJoinRows(
/* 091 */     scala.collection.Iterator leftIter,
/* 092 */     scala.collection.Iterator rightIter) {
/* 093 */     smj_leftRow = null;
/* 094 */     int comp = 0;
/* 095 */     while (smj_leftRow == null) {
/* 096 */       if (!leftIter.hasNext()) return false;
/* 097 */       smj_leftRow = (InternalRow) leftIter.next();
/* 098 */       /* input[0, bigint] */
/* 099 */       boolean smj_isNull = smj_leftRow.isNullAt(0);
/* 100 */       long smj_value = smj_isNull ? -1L : (smj_leftRow.getLong(0));
/* 101 */       if (smj_isNull) {
/* 102 */         smj_leftRow = null;
/* 103 */         continue;
/* 104 */       }
/* 105 */       if (!smj_matches.isEmpty()) {
/* 106 */         comp = 0;
/* 107 */         if (comp == 0) {
/* 108 */           comp = (smj_value > smj_value3 ? 1 : smj_value < smj_value3 ? -1 : 0);
/* 109 */         }
/* 110 */
/* 111 */         if (comp == 0) {
/* 112 */           return true;
/* 113 */         }
/* 114 */         smj_matches.clear();
/* 115 */       }
/* 116 */
/* 117 */       do {
/* 118 */         if (smj_rightRow == null) {
/* 119 */           if (!rightIter.hasNext()) {
/* 120 */             smj_value3 = smj_value;
/* 121 */
/* 122 */             return !smj_matches.isEmpty();
/* 123 */           }
/* 124 */           smj_rightRow = (InternalRow) rightIter.next();
/* 125 */           /* input[0, bigint] */
/* 126 */           boolean smj_isNull1 = smj_rightRow.isNullAt(0);
/* 127 */           long smj_value1 = smj_isNull1 ? -1L : (smj_rightRow.getLong(0));
/* 128 */           if (smj_isNull1) {
/* 129 */             smj_rightRow = null;
/* 130 */             continue;
/* 131 */           }
/* 132 */
/* 133 */           smj_value2 = smj_value1;
/* 134 */
/* 135 */         }
/* 136 */
/* 137 */         comp = 0;
/* 138 */         if (comp == 0) {
/* 139 */           comp = (smj_value > smj_value2 ? 1 : smj_value < smj_value2 ? -1 : 0);
/* 140 */         }
/* 141 */
/* 142 */         if (comp > 0) {
/* 143 */           smj_rightRow = null;
/* 144 */         } else if (comp < 0) {
/* 145 */           if (!smj_matches.isEmpty()) {
/* 146 */             smj_value3 = smj_value;
/* 147 */
/* 148 */             return true;
/* 149 */           }
/* 150 */           smj_leftRow = null;
/* 151 */         } else {
/* 152 */           smj_matches.add(smj_rightRow.copy());
/* 153 */           smj_rightRow = null;;
/* 154 */         }
/* 155 */       } while (smj_leftRow != null);
/* 156 */     }
/* 157 */     return false; // unreachable
/* 158 */   }
/* 159 */
/* 160 */   protected void processNext() throws java.io.IOException {
/* 161 */     if (!agg_initAgg) {
/* 162 */       agg_initAgg = true;
/* 163 */       agg_doAggregateWithoutKey();
/* 164 */
/* 165 */       // output the result
/* 166 */
/* 167 */       agg_metricValue.add(1);
/* 168 */       agg_rowWriter.zeroOutNullBytes();
/* 169 */
/* 170 */       if (agg_bufIsNull) {
/* 171 */         agg_rowWriter.setNullAt(0);
/* 172 */       } else {
/* 173 */         agg_rowWriter.write(0, agg_bufValue);
/* 174 */       }
/* 175 */       append(agg_result.copy());
/* 176 */     }
/* 177 */   }
/* 178 */ }
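
To make the control flow of `findNextInnerJoinRows` easier to follow, here is a minimal standalone sketch of the same idea on plain sorted `long` arrays. It is not the generated code: the class name `SortMergeJoinSketch` and the method `countMatches` are hypothetical, null keys are ignored, and only the key is buffered rather than a whole row. What it keeps is the shape of the algorithm: buffer the right rows for the current key, then reuse that buffer while the left key repeats.

```java
import java.util.ArrayList;
import java.util.List;

class SortMergeJoinSketch {
    /**
     * Counts inner-join matches between two key arrays that are already
     * sorted ascending. Duplicate keys are allowed on both sides.
     */
    static long countMatches(long[] left, long[] right) {
        long count = 0;
        int r = 0;                              // cursor into the right side
        List<Long> matches = new ArrayList<>(); // buffered right rows for the current key
        long matchedKey = 0;                    // key that `matches` was built for

        for (long leftKey : left) {
            // Reuse the buffer when the left key repeats; otherwise rebuild it.
            if (matches.isEmpty() || leftKey != matchedKey) {
                matches.clear();
                // Skip right rows whose key is smaller than the current left key.
                while (r < right.length && right[r] < leftKey) {
                    r++;
                }
                // Buffer every right row that carries the same key.
                while (r < right.length && right[r] == leftKey) {
                    matches.add(right[r]);
                    r++;
                }
                matchedKey = leftKey;
            }
            // Each buffered right row pairs with the current left row.
            count += matches.size();
        }
        return count;
    }

    public static void main(String[] args) {
        long[] left = {1, 2, 2, 4};
        long[] right = {2, 2, 3, 4};
        System.out.println(countMatches(left, right)); // prints 5
    }
}
```

The two left rows with key 2 each pair with both right rows with key 2 (four matches), and key 4 contributes one more, which is why the buffer has to survive across consecutive left rows.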

SparkQA commented Feb 18, 2016

Test build #51471 has finished for PR 11248 at commit 2a83ff5.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • public abstract class BufferedRowIterator

SparkQA commented Feb 18, 2016

Test build #51477 has finished for PR 11248 at commit 3c8b35c.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

SparkQA commented Feb 18, 2016

Test build #51478 has finished for PR 11248 at commit b3669f2.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • public abstract class BufferedRowIterator

@hvanhovell

Shouldn't we add a more realistic benchmark then?

davies commented Feb 18, 2016

@hvanhovell I think the purpose of a micro benchmark is to show the baseline cost of each operator. Results vary a lot from case to case, so it's hard to say which one is more realistic.

davies commented Feb 18, 2016

Here is part of the code generated for Q72:

/* 160 */   private boolean findNextInnerJoinRows(
/* 161 */     scala.collection.Iterator leftIter,
/* 162 */     scala.collection.Iterator rightIter) {
/* 163 */     smj_leftRow = null;
/* 164 */     int comp = 0;
/* 165 */     while (smj_leftRow == null) {
/* 166 */       if (!leftIter.hasNext()) return false;
/* 167 */       smj_leftRow = (InternalRow) leftIter.next();
/* 168 */       /* input[7, int] */
/* 169 */       boolean smj_isNull = smj_leftRow.isNullAt(7);
/* 170 */       int smj_value = smj_isNull ? -1 : (smj_leftRow.getInt(7));
/* 171 */       if (smj_isNull) {
/* 172 */         smj_leftRow = null;
/* 173 */         continue;
/* 174 */       }
/* 175 */       if (!smj_matches.isEmpty()) {
/* 176 */         comp = 0;
/* 177 */         if (comp == 0) {
/* 178 */           comp = (smj_value > smj_value3 ? 1 : smj_value < smj_value3 ? -1 : 0);
/* 179 */         }
/* 180 */
/* 181 */         if (comp == 0) {
/* 182 */           return true;
/* 183 */         }
/* 184 */         smj_matches.clear();
/* 185 */       }
/* 186 */
/* 187 */       do {
/* 188 */         if (smj_rightRow == null) {
/* 189 */           if (!rightIter.hasNext()) {
/* 190 */             smj_value3 = smj_value;
/* 191 */
/* 192 */             return !smj_matches.isEmpty();
/* 193 */           }
/* 194 */           smj_rightRow = (InternalRow) rightIter.next();
/* 195 */           /* input[1, int] */
/* 196 */           boolean smj_isNull1 = smj_rightRow.isNullAt(1);
/* 197 */           int smj_value1 = smj_isNull1 ? -1 : (smj_rightRow.getInt(1));
/* 198 */           if (smj_isNull1) {
/* 199 */             smj_rightRow = null;
/* 200 */             continue;
/* 201 */           }
/* 202 */
/* 203 */           smj_value2 = smj_value1;
/* 204 */
/* 205 */         }
/* 206 */
/* 207 */         comp = 0;
/* 208 */         if (comp == 0) {
/* 209 */           comp = (smj_value > smj_value2 ? 1 : smj_value < smj_value2 ? -1 : 0);
/* 210 */         }
/* 211 */
/* 212 */         if (comp > 0) {
/* 213 */           smj_rightRow = null;
/* 214 */         } else if (comp < 0) {
/* 215 */           if (!smj_matches.isEmpty()) {
/* 216 */             smj_value3 = smj_value;
/* 217 */
/* 218 */             return true;
/* 219 */           }
/* 220 */           smj_leftRow = null;
/* 221 */         } else {
/* 222 */           smj_matches.add(smj_rightRow.copy());
/* 223 */           smj_rightRow = null;;
/* 224 */         }
/* 225 */       } while (smj_leftRow != null);
/* 226 */     }
/* 227 */     return false; // unreachable
/* 228 */   }
/* 229 */
/* 230 */   protected void processNext() throws java.io.IOException {
/* 231 */     while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) {
/* 232 */       int smj_size = smj_matches.size();
/* 233 */       boolean smj_loaded = false;
/* 234 */
/* 235 */       smj_isNull7 = smj_leftRow.isNullAt(5);
/* 236 */       smj_value9 = smj_isNull7 ? -1 : (smj_leftRow.getInt(5));
/* 237 */
/* 238 */       for (int smj_i = 0; smj_i < smj_size; smj_i ++) {
/* 239 */         InternalRow smj_rightRow1 = (InternalRow) smj_matches.get(smj_i);
/* 240 */         /* input[3, int] */
/* 241 */         boolean smj_isNull13 = smj_rightRow1.isNullAt(3);
/* 242 */         int smj_value15 = smj_isNull13 ? -1 : (smj_rightRow1.getInt(3));
/* 243 */         /* (input[11, int] < input[5, int]) */
/* 244 */         boolean smj_isNull14 = true;
/* 245 */         boolean smj_value16 = false;
/* 246 */
/* 247 */         if (!smj_isNull13) {
/* 248 */           if (!smj_isNull7) {
/* 249 */             smj_isNull14 = false; // resultCode could change nullability.
/* 250 */             smj_value16 = smj_value15 < smj_value9;
/* 251 */
/* 252 */           }
/* 253 */
/* 254 */         }
/* 255 */         if (smj_isNull14 || !smj_value16) continue;
/* 256 */         if (!smj_loaded) {
/* 257 */           smj_loaded = true;
/* 258 */
/* 259 */           smj_isNull2 = smj_leftRow.isNullAt(0);
/* 260 */           smj_value4 = smj_isNull2 ? -1 : (smj_leftRow.getInt(0));
/* 261 */
/* 262 */           smj_isNull3 = smj_leftRow.isNullAt(1);
/* 263 */           smj_value5 = smj_isNull3 ? -1 : (smj_leftRow.getInt(1));
/* 264 */
/* 265 */           smj_isNull4 = smj_leftRow.isNullAt(2);
/* 266 */           smj_value6 = smj_isNull4 ? -1 : (smj_leftRow.getInt(2));
/* 267 */
/* 268 */           smj_isNull5 = smj_leftRow.isNullAt(3);
/* 269 */           smj_value7 = smj_isNull5 ? -1 : (smj_leftRow.getInt(3));
/* 270 */
/* 271 */           smj_isNull6 = smj_leftRow.isNullAt(4);
/* 272 */           smj_value8 = smj_isNull6 ? -1 : (smj_leftRow.getInt(4));
/* 273 */
/* 274 */           smj_isNull8 = smj_leftRow.isNullAt(6);
/* 275 */           smj_value10 = smj_isNull8 ? -1 : (smj_leftRow.getInt(6));
/* 276 */
/* 277 */           smj_isNull9 = smj_leftRow.isNullAt(7);
/* 278 */           smj_value11 = smj_isNull9 ? -1 : (smj_leftRow.getInt(7));
/* 279 */
/* 280 */         }
/* 281 */         /* input[0, int] */
/* 282 */         boolean smj_isNull10 = smj_rightRow1.isNullAt(0);
/* 283 */         int smj_value12 = smj_isNull10 ? -1 : (smj_rightRow1.getInt(0));
/* 284 */         /* input[1, int] */
/* 285 */         boolean smj_isNull11 = smj_rightRow1.isNullAt(1);
/* 286 */         int smj_value13 = smj_isNull11 ? -1 : (smj_rightRow1.getInt(1));
/* 287 */         /* input[2, int] */
/* 288 */         boolean smj_isNull12 = smj_rightRow1.isNullAt(2);
/* 289 */         int smj_value14 = smj_isNull12 ? -1 : (smj_rightRow1.getInt(2));
/* 290 */         smj_metricValue.add(1);
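
The `smj_loaded` flag in the excerpt above defers reading most left-side columns until the join condition has passed for at least one buffered right row. A minimal sketch of that pattern, under stated assumptions: rows are plain `int[]` arrays, the condition is `right[0] < left[0]`, and `loadRemaining` is a hypothetical stand-in for decoding the other left columns. None of these names come from Spark.

```java
import java.util.ArrayList;
import java.util.List;

class LazyColumnLoadSketch {
    // Counts how many times the left row's remaining columns were decoded.
    static int loads = 0;

    /** Hypothetical stand-in for decoding the left-side columns not used by the condition. */
    static int[] loadRemaining(int[] leftRow) {
        loads++;
        return leftRow.clone();
    }

    /**
     * Joins one left row against its buffered right matches, keeping pairs
     * where right[0] < left[0]. The left row's other columns are decoded at
     * most once, and only if some pair passes the condition.
     */
    static List<int[]> joinOneLeftRow(int[] leftRow, int[][] matches) {
        List<int[]> output = new ArrayList<>();
        boolean loaded = false;
        int[] leftColumns = null;
        int conditionColumn = leftRow[0]; // the only left column the condition needs

        for (int[] rightRow : matches) {
            // Evaluate the condition before touching any other column.
            if (!(rightRow[0] < conditionColumn)) {
                continue;
            }
            if (!loaded) {
                loaded = true;
                leftColumns = loadRemaining(leftRow);
            }
            output.add(new int[] {leftColumns[1], rightRow[1]});
        }
        return output;
    }

    public static void main(String[] args) {
        int[] left = {5, 100};
        int[][] matches = {{1, 7}, {9, 8}, {3, 9}};
        List<int[]> out = joinOneLeftRow(left, matches);
        System.out.println(out.size() + " rows, " + loads + " load"); // prints "2 rows, 1 load"
    }
}
```

When no buffered row passes the condition, the left row's remaining columns are never read at all, which is the saving this ordering is after.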

SparkQA commented Feb 18, 2016

Test build #2544 has finished for PR 11248 at commit 8898957.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

SparkQA commented Feb 18, 2016

Test build #51490 has finished for PR 11248 at commit 8898957.

  • This patch fails from timeout after a configured wait of 250m.
  • This patch merges cleanly.
  • This patch adds no public classes.

SparkQA commented Feb 18, 2016

Test build #51495 has finished for PR 11248 at commit 09a97d4.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

SparkQA commented Feb 19, 2016

Test build #2545 has finished for PR 11248 at commit 09a97d4.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

srowen commented Feb 19, 2016

@davies this is linked to the wrong JIRA

davies changed the title from "[SPARK-13375] [SQL] generate sort merge join" to "[SPARK-13373] [SQL] generate sort merge join" on Feb 19, 2016
Conflicts:
	sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
	sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

SparkQA commented Feb 21, 2016

Test build #51607 has finished for PR 11248 at commit 55148d4.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


override def doProduce(ctx: CodegenContext): String = {
  val input = ctx.freshName("input")
  // Right now, Range is only used when there is one upstream.
Contributor

Range -> InputAdapter?

nongli commented Feb 23, 2016

Does it make sense to do something like this and remove the copies?

/* 056 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 057 */     // initialize aggregation buffer
/* 058 */
/* 059 */     agg_bufIsNull = false;
/* 060 */     agg_bufValue = 0L;
/* 061 */

              smj_left = leftIter.next();
              smj_right = rightIter.next();
              while (true) {
                int comp = compare(smj_left, smj_right);
                if (comp > 0) {
                  // left key is larger: advance the right side
                  if (!rightIter.hasNext()) break;
                  smj_right = rightIter.next();
                  continue;
                } else if (comp < 0) {
                  // left key is smaller: advance the left side
                  if (!leftIter.hasNext()) break;
                  smj_left = leftIter.next();
                  continue;
                }

                // Match
/* 071 */         long smj_value5 = smj_isNull3 ? -1L : (smj_rightRow1.getLong(0));
/* 072 */
/* 073 */         if (false || !true) continue;
/* 074 */         smj_metricValue.add(1);
/* 075 */
/* 076 */         // do aggregate
/* 077 */         /* (input[0, bigint] + 1) */
/* 078 */         long agg_value1 = -1L;
/* 079 */         agg_value1 = agg_bufValue + 1L;
/* 080 */         // update aggregation buffer
/* 081 */         agg_bufIsNull = false;
/* 082 */         agg_bufValue = agg_value1;
/* 083 */

                // Advance 1
                if (!leftIter.hasNext()) break;
                smj_left = leftIter.next();
              }
/* 088 */   }

davies commented Feb 23, 2016

@nongli We already avoid the copy until a match is found. To support N left rows matching M right rows, we have to copy some of them. Right now, we copy the keys from the left side and the rows from the right side. It's possible to make fewer copies, but I'd like to leave that out of this PR, because SMJ usually requires Exchange and Sort, which are more expensive than copying the matched rows from the right.
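
A small sketch of why buffered rows need a copy at all. It assumes, purely for illustration, that the upstream iterator hands back the same mutable row object on every call; `reusingIterator`, `buffer`, and `sum` are hypothetical names, and a `long[]` stands in for a row.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

class RowCopySketch {
    /** Hypothetical iterator that returns the SAME array on every call to save allocations. */
    static Iterator<long[]> reusingIterator(long[] values) {
        return new Iterator<long[]>() {
            private final long[] row = new long[1]; // shared, mutable buffer
            private int i = 0;

            @Override
            public boolean hasNext() {
                return i < values.length;
            }

            @Override
            public long[] next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                row[0] = values[i++]; // overwrites the previous row in place
                return row;
            }
        };
    }

    /** Buffers every row from the iterator, optionally copying each one first. */
    static List<long[]> buffer(Iterator<long[]> rows, boolean copy) {
        List<long[]> out = new ArrayList<>();
        while (rows.hasNext()) {
            long[] row = rows.next();
            out.add(copy ? row.clone() : row);
        }
        return out;
    }

    /** Sums the first column of every buffered row. */
    static long sum(List<long[]> rows) {
        long total = 0;
        for (long[] row : rows) {
            total += row[0];
        }
        return total;
    }

    public static void main(String[] args) {
        long[] values = {10, 20, 30};
        System.out.println(sum(buffer(reusingIterator(values), true)));  // prints 60
        System.out.println(sum(buffer(reusingIterator(values), false))); // prints 90
    }
}
```

Without the copy, all three buffered entries alias one array that ends up holding the last value, so the buffered matches are silently wrong. Copying only once a match is found keeps that cost off the rows that never join.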

SparkQA commented Feb 23, 2016

Test build #51771 has finished for PR 11248 at commit 0aa075f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

class BenchmarkWholeStageCodegen extends SparkFunSuite {
  lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
    .set("spark.sql.shuffle.partitions", "1")
    .set("spark.sql.autoBroadcastJoinThreshold", "0")
Contributor

Should this be set inside the sort merge join benchmark?

Contributor Author

We will use BroadcastHint when testing with broadcast join, so it's better to disable automatic broadcast for all benchmarks.

nongli commented Feb 23, 2016

LGTM

davies commented Feb 23, 2016

Merging this into master, thanks!

asfgit closed this in 9cdd867 on Feb 23, 2016