
Conversation

@kiszk (Member) commented Mar 9, 2018

What changes were proposed in this pull request?

This PR fixes a runtime error that occurs for a large query when the generated code is split into multiple classes. The issue is that append(), stopEarly(), and other methods are not accessible from the split classes, because they are not subclasses of BufferedRowIterator.
This PR fixes the issue by making those methods public.
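
For illustration, here is a minimal sketch (with hypothetical class and package names, not the Spark ones) of the JVM access rule involved: a protected method in another package may only be invoked from a subclass, so an inner class of the subclass needs a compiler-generated accessor bridge, which the split classes evidently do not get.

  // pkg_a/Base.java -- stands in for BufferedRowIterator's package
  package pkg_a;

  public class Base {
      // Protected: subclasses in other packages may call this, but only
      // from code that is itself inside the subclass.
      protected boolean shouldStop() { return false; }
  }

  // pkg_b/Sub.java -- a separate compilation unit
  package pkg_b;

  public class Sub extends pkg_a.Base {
      class Nested {  // Nested does NOT extend Base
          void run() {
              // javac makes this legal by emitting a synthetic accessor in
              // Sub; a compiler that omits that bridge produces bytecode the
              // JVM rejects with IllegalAccessError, as in the trace below.
              boolean stop = shouldStop();
          }
      }
  }

Making the methods public removes the need for any bridge, so the nested classes can call them directly.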

Before this PR, running the attached program with CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD set to -1 produces the following exception.

  test("SPARK-23598") {
    // When CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD is set to -1, an exception is thrown
    val df_pet_age = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name")
    df_pet_age.groupBy("name").avg("age").show()
  }

Exception:

19:40:52.591 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
19:41:32.319 ERROR org.apache.spark.executor.Executor: Exception in task 0.0 in stage 0.0 (TID 0)
java.lang.IllegalAccessError: tried to access method org.apache.spark.sql.execution.BufferedRowIterator.shouldStop()Z from class org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1.agg_doAggregateWithKeys$(generated.java:203)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(generated.java:160)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:616)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
	at java.lang.Thread.run(Thread.java:745)
...

Generated code (line 195 calls stopEarly()).

/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean agg_initAgg;
/* 010 */   private boolean agg_bufIsNull;
/* 011 */   private double agg_bufValue;
/* 012 */   private boolean agg_bufIsNull1;
/* 013 */   private long agg_bufValue1;
/* 014 */   private agg_FastHashMap agg_fastHashMap;
/* 015 */   private org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> agg_fastHashMapIter;
/* 016 */   private org.apache.spark.unsafe.KVIterator agg_mapIter;
/* 017 */   private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap;
/* 018 */   private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter;
/* 019 */   private scala.collection.Iterator inputadapter_input;
/* 020 */   private boolean agg_agg_isNull11;
/* 021 */   private boolean agg_agg_isNull25;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[] agg_mutableStateArray1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[2];
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] agg_mutableStateArray2 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 024 */   private UnsafeRow[] agg_mutableStateArray = new UnsafeRow[2];
/* 025 */
/* 026 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */
/* 034 */     agg_fastHashMap = new agg_FastHashMap(((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getTaskMemoryManager(), ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getEmptyAggregationBuffer());
/* 035 */     agg_hashMap = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).createHashMap();
/* 036 */     inputadapter_input = inputs[0];
/* 037 */     agg_mutableStateArray[0] = new UnsafeRow(1);
/* 038 */     agg_mutableStateArray1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[0], 32);
/* 039 */     agg_mutableStateArray2[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[0], 1);
/* 040 */     agg_mutableStateArray[1] = new UnsafeRow(3);
/* 041 */     agg_mutableStateArray1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[1], 32);
/* 042 */     agg_mutableStateArray2[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[1], 3);
/* 043 */
/* 044 */   }
/* 045 */
/* 046 */   public class agg_FastHashMap {
/* 047 */     private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
/* 048 */     private int[] buckets;
/* 049 */     private int capacity = 1 << 16;
/* 050 */     private double loadFactor = 0.5;
/* 051 */     private int numBuckets = (int) (capacity / loadFactor);
/* 052 */     private int maxSteps = 2;
/* 053 */     private int numRows = 0;
/* 054 */     private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[1] /* keyName */), org.apache.spark.sql.types.DataTypes.StringType);
/* 055 */     private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[2] /* keyName */), org.apache.spark.sql.types.DataTypes.DoubleType)
/* 056 */     .add(((java.lang.String) references[3] /* keyName */), org.apache.spark.sql.types.DataTypes.LongType);
/* 057 */     private Object emptyVBase;
/* 058 */     private long emptyVOff;
/* 059 */     private int emptyVLen;
/* 060 */     private boolean isBatchFull = false;
/* 061 */
/* 062 */     public agg_FastHashMap(
/* 063 */       org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
/* 064 */       InternalRow emptyAggregationBuffer) {
/* 065 */       batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
/* 066 */       .allocate(keySchema, valueSchema, taskMemoryManager, capacity);
/* 067 */
/* 068 */       final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema);
/* 069 */       final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
/* 070 */
/* 071 */       emptyVBase = emptyBuffer;
/* 072 */       emptyVOff = Platform.BYTE_ARRAY_OFFSET;
/* 073 */       emptyVLen = emptyBuffer.length;
/* 074 */
/* 075 */       buckets = new int[numBuckets];
/* 076 */       java.util.Arrays.fill(buckets, -1);
/* 077 */     }
/* 078 */
/* 079 */     public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(UTF8String agg_key) {
/* 080 */       long h = hash(agg_key);
/* 081 */       int step = 0;
/* 082 */       int idx = (int) h & (numBuckets - 1);
/* 083 */       while (step < maxSteps) {
/* 084 */         // Return bucket index if it's either an empty slot or already contains the key
/* 085 */         if (buckets[idx] == -1) {
/* 086 */           if (numRows < capacity && !isBatchFull) {
/* 087 */             // creating the unsafe for new entry
/* 088 */             UnsafeRow agg_result = new UnsafeRow(1);
/* 089 */             org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder
/* 090 */             = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result,
/* 091 */               32);
/* 092 */             org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
/* 093 */             = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
/* 094 */               agg_holder,
/* 095 */               1);
/* 096 */             agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed
/* 097 */             agg_rowWriter.zeroOutNullBytes();
/* 098 */             agg_rowWriter.write(0, agg_key);
/* 099 */             agg_result.setTotalSize(agg_holder.totalSize());
/* 100 */             Object kbase = agg_result.getBaseObject();
/* 101 */             long koff = agg_result.getBaseOffset();
/* 102 */             int klen = agg_result.getSizeInBytes();
/* 103 */
/* 104 */             UnsafeRow vRow
/* 105 */             = batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen);
/* 106 */             if (vRow == null) {
/* 107 */               isBatchFull = true;
/* 108 */             } else {
/* 109 */               buckets[idx] = numRows++;
/* 110 */             }
/* 111 */             return vRow;
/* 112 */           } else {
/* 113 */             // No more space
/* 114 */             return null;
/* 115 */           }
/* 116 */         } else if (equals(idx, agg_key)) {
/* 117 */           return batch.getValueRow(buckets[idx]);
/* 118 */         }
/* 119 */         idx = (idx + 1) & (numBuckets - 1);
/* 120 */         step++;
/* 121 */       }
/* 122 */       // Didn't find it
/* 123 */       return null;
/* 124 */     }
/* 125 */
/* 126 */     private boolean equals(int idx, UTF8String agg_key) {
/* 127 */       UnsafeRow row = batch.getKeyRow(buckets[idx]);
/* 128 */       return (row.getUTF8String(0).equals(agg_key));
/* 129 */     }
/* 130 */
/* 131 */     private long hash(UTF8String agg_key) {
/* 132 */       long agg_hash = 0;
/* 133 */
/* 134 */       int agg_result = 0;
/* 135 */       byte[] agg_bytes = agg_key.getBytes();
/* 136 */       for (int i = 0; i < agg_bytes.length; i++) {
/* 137 */         int agg_hash1 = agg_bytes[i];
/* 138 */         agg_result = (agg_result ^ (0x9e3779b9)) + agg_hash1 + (agg_result << 6) + (agg_result >>> 2);
/* 139 */       }
/* 140 */
/* 141 */       agg_hash = (agg_hash ^ (0x9e3779b9)) + agg_result + (agg_hash << 6) + (agg_hash >>> 2);
/* 142 */
/* 143 */       return agg_hash;
/* 144 */     }
/* 145 */
/* 146 */     public org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() {
/* 147 */       return batch.rowIterator();
/* 148 */     }
/* 149 */
/* 150 */     public void close() {
/* 151 */       batch.close();
/* 152 */     }
/* 153 */
/* 154 */   }
/* 155 */
/* 156 */   protected void processNext() throws java.io.IOException {
/* 157 */     if (!agg_initAgg) {
/* 158 */       agg_initAgg = true;
/* 159 */       long wholestagecodegen_beforeAgg = System.nanoTime();
/* 160 */       agg_nestedClassInstance1.agg_doAggregateWithKeys();
/* 161 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[8] /* aggTime */).add((System.nanoTime() - wholestagecodegen_beforeAgg) / 1000000);
/* 162 */     }
/* 163 */
/* 164 */     // output the result
/* 165 */
/* 166 */     while (agg_fastHashMapIter.next()) {
/* 167 */       UnsafeRow agg_aggKey = (UnsafeRow) agg_fastHashMapIter.getKey();
/* 168 */       UnsafeRow agg_aggBuffer = (UnsafeRow) agg_fastHashMapIter.getValue();
/* 169 */       wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer);
/* 170 */
/* 171 */       if (shouldStop()) return;
/* 172 */     }
/* 173 */     agg_fastHashMap.close();
/* 174 */
/* 175 */     while (agg_mapIter.next()) {
/* 176 */       UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey();
/* 177 */       UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue();
/* 178 */       wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer);
/* 179 */
/* 180 */       if (shouldStop()) return;
/* 181 */     }
/* 182 */
/* 183 */     agg_mapIter.close();
/* 184 */     if (agg_sorter == null) {
/* 185 */       agg_hashMap.free();
/* 186 */     }
/* 187 */   }
/* 188 */
/* 189 */   private wholestagecodegen_NestedClass wholestagecodegen_nestedClassInstance = new wholestagecodegen_NestedClass();
/* 190 */   private agg_NestedClass1 agg_nestedClassInstance1 = new agg_NestedClass1();
/* 191 */   private agg_NestedClass agg_nestedClassInstance = new agg_NestedClass();
/* 192 */
/* 193 */   private class agg_NestedClass1 {
/* 194 */     private void agg_doAggregateWithKeys() throws java.io.IOException {
/* 195 */       while (inputadapter_input.hasNext() && !stopEarly()) {
/* 196 */         InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 197 */         int inputadapter_value = inputadapter_row.getInt(0);
/* 198 */         boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 199 */         UTF8String inputadapter_value1 = inputadapter_isNull1 ?
/* 200 */         null : (inputadapter_row.getUTF8String(1));
/* 201 */
/* 202 */         agg_nestedClassInstance.agg_doConsume(inputadapter_row, inputadapter_value, inputadapter_value1, inputadapter_isNull1);
/* 203 */         if (shouldStop()) return;
/* 204 */       }
/* 205 */
/* 206 */       agg_fastHashMapIter = agg_fastHashMap.rowIterator();
/* 207 */       agg_mapIter = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).finishAggregate(agg_hashMap, agg_sorter, ((org.apache.spark.sql.execution.metric.SQLMetric) references[4] /* peakMemory */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[5] /* spillSize */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[6] /* avgHashProbe */));
/* 208 */
/* 209 */     }
/* 210 */
/* 211 */   }
/* 212 */
/* 213 */   private class wholestagecodegen_NestedClass {
/* 214 */     private void agg_doAggregateWithKeysOutput(UnsafeRow agg_keyTerm, UnsafeRow agg_bufferTerm)
/* 215 */     throws java.io.IOException {
/* 216 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[7] /* numOutputRows */).add(1);
/* 217 */
/* 218 */       boolean agg_isNull35 = agg_keyTerm.isNullAt(0);
/* 219 */       UTF8String agg_value37 = agg_isNull35 ?
/* 220 */       null : (agg_keyTerm.getUTF8String(0));
/* 221 */       boolean agg_isNull36 = agg_bufferTerm.isNullAt(0);
/* 222 */       double agg_value38 = agg_isNull36 ?
/* 223 */       -1.0 : (agg_bufferTerm.getDouble(0));
/* 224 */       boolean agg_isNull37 = agg_bufferTerm.isNullAt(1);
/* 225 */       long agg_value39 = agg_isNull37 ?
/* 226 */       -1L : (agg_bufferTerm.getLong(1));
/* 227 */
/* 228 */       agg_mutableStateArray1[1].reset();
/* 229 */
/* 230 */       agg_mutableStateArray2[1].zeroOutNullBytes();
/* 231 */
/* 232 */       if (agg_isNull35) {
/* 233 */         agg_mutableStateArray2[1].setNullAt(0);
/* 234 */       } else {
/* 235 */         agg_mutableStateArray2[1].write(0, agg_value37);
/* 236 */       }
/* 237 */
/* 238 */       if (agg_isNull36) {
/* 239 */         agg_mutableStateArray2[1].setNullAt(1);
/* 240 */       } else {
/* 241 */         agg_mutableStateArray2[1].write(1, agg_value38);
/* 242 */       }
/* 243 */
/* 244 */       if (agg_isNull37) {
/* 245 */         agg_mutableStateArray2[1].setNullAt(2);
/* 246 */       } else {
/* 247 */         agg_mutableStateArray2[1].write(2, agg_value39);
/* 248 */       }
/* 249 */       agg_mutableStateArray[1].setTotalSize(agg_mutableStateArray1[1].totalSize());
/* 250 */       append(agg_mutableStateArray[1]);
/* 251 */
/* 252 */     }
/* 253 */
/* 254 */   }
/* 255 */
/* 256 */   private class agg_NestedClass {
/* 257 */     private void agg_doConsume(InternalRow inputadapter_row, int agg_expr_0, UTF8String agg_expr_1, boolean agg_exprIsNull_1) throws java.io.IOException {
/* 258 */       UnsafeRow agg_unsafeRowAggBuffer = null;
/* 259 */       UnsafeRow agg_fastAggBuffer = null;
/* 260 */
/* 261 */       if (true) {
/* 262 */         if (!agg_exprIsNull_1) {
/* 263 */           agg_fastAggBuffer = agg_fastHashMap.findOrInsert(
/* 264 */             agg_expr_1);
/* 265 */         }
/* 266 */       }
/* 267 */       // Cannot find the key in fast hash map, try regular hash map.
/* 268 */       if (agg_fastAggBuffer == null) {
/* 269 */         // generate grouping key
/* 270 */         agg_mutableStateArray1[0].reset();
/* 271 */
/* 272 */         agg_mutableStateArray2[0].zeroOutNullBytes();
/* 273 */
/* 274 */         if (agg_exprIsNull_1) {
/* 275 */           agg_mutableStateArray2[0].setNullAt(0);
/* 276 */         } else {
/* 277 */           agg_mutableStateArray2[0].write(0, agg_expr_1);
/* 278 */         }
/* 279 */         agg_mutableStateArray[0].setTotalSize(agg_mutableStateArray1[0].totalSize());
/* 280 */         int agg_value7 = 42;
/* 281 */
/* 282 */         if (!agg_exprIsNull_1) {
/* 283 */           agg_value7 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(agg_expr_1.getBaseObject(), agg_expr_1.getBaseOffset(), agg_expr_1.numBytes(), agg_value7);
/* 284 */         }
/* 285 */         if (true) {
/* 286 */           // try to get the buffer from hash map
/* 287 */           agg_unsafeRowAggBuffer =
/* 288 */           agg_hashMap.getAggregationBufferFromUnsafeRow(agg_mutableStateArray[0], agg_value7);
/* 289 */         }
/* 290 */         // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
/* 291 */         // aggregation after processing all input rows.
/* 292 */         if (agg_unsafeRowAggBuffer == null) {
/* 293 */           if (agg_sorter == null) {
/* 294 */             agg_sorter = agg_hashMap.destructAndCreateExternalSorter();
/* 295 */           } else {
/* 296 */             agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter());
/* 297 */           }
/* 298 */
/* 299 */           // the hash map had be spilled, it should have enough memory now,
/* 300 */           // try to allocate buffer again.
/* 301 */           agg_unsafeRowAggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(
/* 302 */             agg_mutableStateArray[0], agg_value7);
/* 303 */           if (agg_unsafeRowAggBuffer == null) {
/* 304 */             // failed to allocate the first page
/* 305 */             throw new OutOfMemoryError("No enough memory for aggregation");
/* 306 */           }
/* 307 */         }
/* 308 */
/* 309 */       }
/* 310 */
/* 311 */       if (agg_fastAggBuffer != null) {
/* 312 */         // common sub-expressions
/* 313 */         boolean agg_isNull21 = false;
/* 314 */         long agg_value23 = -1L;
/* 315 */         if (!false) {
/* 316 */           agg_value23 = (long) agg_expr_0;
/* 317 */         }
/* 318 */         // evaluate aggregate function
/* 319 */         boolean agg_isNull23 = true;
/* 320 */         double agg_value25 = -1.0;
/* 321 */
/* 322 */         boolean agg_isNull24 = agg_fastAggBuffer.isNullAt(0);
/* 323 */         double agg_value26 = agg_isNull24 ?
/* 324 */         -1.0 : (agg_fastAggBuffer.getDouble(0));
/* 325 */         if (!agg_isNull24) {
/* 326 */           agg_agg_isNull25 = true;
/* 327 */           double agg_value27 = -1.0;
/* 328 */           do {
/* 329 */             boolean agg_isNull26 = agg_isNull21;
/* 330 */             double agg_value28 = -1.0;
/* 331 */             if (!agg_isNull21) {
/* 332 */               agg_value28 = (double) agg_value23;
/* 333 */             }
/* 334 */             if (!agg_isNull26) {
/* 335 */               agg_agg_isNull25 = false;
/* 336 */               agg_value27 = agg_value28;
/* 337 */               continue;
/* 338 */             }
/* 339 */
/* 340 */             boolean agg_isNull27 = false;
/* 341 */             double agg_value29 = -1.0;
/* 342 */             if (!false) {
/* 343 */               agg_value29 = (double) 0;
/* 344 */             }
/* 345 */             if (!agg_isNull27) {
/* 346 */               agg_agg_isNull25 = false;
/* 347 */               agg_value27 = agg_value29;
/* 348 */               continue;
/* 349 */             }
/* 350 */
/* 351 */           } while (false);
/* 352 */
/* 353 */           agg_isNull23 = false; // resultCode could change nullability.
/* 354 */           agg_value25 = agg_value26 + agg_value27;
/* 355 */
/* 356 */         }
/* 357 */         boolean agg_isNull29 = false;
/* 358 */         long agg_value31 = -1L;
/* 359 */         if (!false && agg_isNull21) {
/* 360 */           boolean agg_isNull31 = agg_fastAggBuffer.isNullAt(1);
/* 361 */           long agg_value33 = agg_isNull31 ?
/* 362 */           -1L : (agg_fastAggBuffer.getLong(1));
/* 363 */           agg_isNull29 = agg_isNull31;
/* 364 */           agg_value31 = agg_value33;
/* 365 */         } else {
/* 366 */           boolean agg_isNull32 = true;
/* 367 */           long agg_value34 = -1L;
/* 368 */
/* 369 */           boolean agg_isNull33 = agg_fastAggBuffer.isNullAt(1);
/* 370 */           long agg_value35 = agg_isNull33 ?
/* 371 */           -1L : (agg_fastAggBuffer.getLong(1));
/* 372 */           if (!agg_isNull33) {
/* 373 */             agg_isNull32 = false; // resultCode could change nullability.
/* 374 */             agg_value34 = agg_value35 + 1L;
/* 375 */
/* 376 */           }
/* 377 */           agg_isNull29 = agg_isNull32;
/* 378 */           agg_value31 = agg_value34;
/* 379 */         }
/* 380 */         // update fast row
/* 381 */         if (!agg_isNull23) {
/* 382 */           agg_fastAggBuffer.setDouble(0, agg_value25);
/* 383 */         } else {
/* 384 */           agg_fastAggBuffer.setNullAt(0);
/* 385 */         }
/* 386 */
/* 387 */         if (!agg_isNull29) {
/* 388 */           agg_fastAggBuffer.setLong(1, agg_value31);
/* 389 */         } else {
/* 390 */           agg_fastAggBuffer.setNullAt(1);
/* 391 */         }
/* 392 */       } else {
/* 393 */         // common sub-expressions
/* 394 */         boolean agg_isNull7 = false;
/* 395 */         long agg_value9 = -1L;
/* 396 */         if (!false) {
/* 397 */           agg_value9 = (long) agg_expr_0;
/* 398 */         }
/* 399 */         // evaluate aggregate function
/* 400 */         boolean agg_isNull9 = true;
/* 401 */         double agg_value11 = -1.0;
/* 402 */
/* 403 */         boolean agg_isNull10 = agg_unsafeRowAggBuffer.isNullAt(0);
/* 404 */         double agg_value12 = agg_isNull10 ?
/* 405 */         -1.0 : (agg_unsafeRowAggBuffer.getDouble(0));
/* 406 */         if (!agg_isNull10) {
/* 407 */           agg_agg_isNull11 = true;
/* 408 */           double agg_value13 = -1.0;
/* 409 */           do {
/* 410 */             boolean agg_isNull12 = agg_isNull7;
/* 411 */             double agg_value14 = -1.0;
/* 412 */             if (!agg_isNull7) {
/* 413 */               agg_value14 = (double) agg_value9;
/* 414 */             }
/* 415 */             if (!agg_isNull12) {
/* 416 */               agg_agg_isNull11 = false;
/* 417 */               agg_value13 = agg_value14;
/* 418 */               continue;
/* 419 */             }
/* 420 */
/* 421 */             boolean agg_isNull13 = false;
/* 422 */             double agg_value15 = -1.0;
/* 423 */             if (!false) {
/* 424 */               agg_value15 = (double) 0;
/* 425 */             }
/* 426 */             if (!agg_isNull13) {
/* 427 */               agg_agg_isNull11 = false;
/* 428 */               agg_value13 = agg_value15;
/* 429 */               continue;
/* 430 */             }
/* 431 */
/* 432 */           } while (false);
/* 433 */
/* 434 */           agg_isNull9 = false; // resultCode could change nullability.
/* 435 */           agg_value11 = agg_value12 + agg_value13;
/* 436 */
/* 437 */         }
/* 438 */         boolean agg_isNull15 = false;
/* 439 */         long agg_value17 = -1L;
/* 440 */         if (!false && agg_isNull7) {
/* 441 */           boolean agg_isNull17 = agg_unsafeRowAggBuffer.isNullAt(1);
/* 442 */           long agg_value19 = agg_isNull17 ?
/* 443 */           -1L : (agg_unsafeRowAggBuffer.getLong(1));
/* 444 */           agg_isNull15 = agg_isNull17;
/* 445 */           agg_value17 = agg_value19;
/* 446 */         } else {
/* 447 */           boolean agg_isNull18 = true;
/* 448 */           long agg_value20 = -1L;
/* 449 */
/* 450 */           boolean agg_isNull19 = agg_unsafeRowAggBuffer.isNullAt(1);
/* 451 */           long agg_value21 = agg_isNull19 ?
/* 452 */           -1L : (agg_unsafeRowAggBuffer.getLong(1));
/* 453 */           if (!agg_isNull19) {
/* 454 */             agg_isNull18 = false; // resultCode could change nullability.
/* 455 */             agg_value20 = agg_value21 + 1L;
/* 456 */
/* 457 */           }
/* 458 */           agg_isNull15 = agg_isNull18;
/* 459 */           agg_value17 = agg_value20;
/* 460 */         }
/* 461 */         // update unsafe row buffer
/* 462 */         if (!agg_isNull9) {
/* 463 */           agg_unsafeRowAggBuffer.setDouble(0, agg_value11);
/* 464 */         } else {
/* 465 */           agg_unsafeRowAggBuffer.setNullAt(0);
/* 466 */         }
/* 467 */
/* 468 */         if (!agg_isNull15) {
/* 469 */           agg_unsafeRowAggBuffer.setLong(1, agg_value17);
/* 470 */         } else {
/* 471 */           agg_unsafeRowAggBuffer.setNullAt(1);
/* 472 */         }
/* 473 */
/* 474 */       }
/* 475 */
/* 476 */     }
/* 477 */
/* 478 */   }
/* 479 */
/* 480 */ }

How was this patch tested?

Added a unit test to WholeStageCodegenSuite.

@kiszk (Member Author) commented Mar 9, 2018

cc: @hvanhovell @mgaido91

@viirya (Member) commented Mar 9, 2018

Looks good.

@SparkQA commented Mar 9, 2018

Test build #88110 has finished for PR 20779 at commit e206fff.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@mgaido91 (Contributor) commented Mar 9, 2018

@kiszk can we add a test case for this?

@viirya (Member) commented Mar 9, 2018

Use the example in the PR description?

@kiszk (Member Author) commented Mar 9, 2018

Good point. If we could set a specific value for GENERATED_CLASS_SIZE_THRESHOLD, we could add a test case. For now, I updated the value by hand, built Spark, and ran the test.

How should we make GENERATED_CLASS_SIZE_THRESHOLD non-final for testing? For example, by adding setter and getter methods?
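
For concreteness, a sketch of what such a test hook might look like (purely hypothetical code; this approach was not adopted, see the next comment):

  // Hypothetical test-only accessors around the threshold (sketch only).
  object CodeGenerator {
    // Illustrative value; the real constant lives in Spark's CodeGenerator.
    private var generatedClassSizeThreshold: Int = 1600000

    def getGeneratedClassSizeThreshold: Int = generatedClassSizeThreshold

    // Test-only setter: lets a suite force aggressive class splitting.
    def setGeneratedClassSizeThreshold(value: Int): Unit = {
      generatedClassSizeThreshold = value
    }
  }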

@mgaido91 (Contributor) commented Mar 9, 2018

I'd rather avoid changing it, because by modifying it we can always trigger exceptions like the constant-pool size limit. Can't we reproduce this without changing that value?

@kiszk (Member Author) commented Mar 9, 2018

Under the current situation, I think we would have to create a very, very large query, since we have made codegen more stable.

@mgaido91 (Contributor) commented Mar 9, 2018

I am not sure. If you set it to -1, every new function generates a new inner class, i.e. one or more entries in the constant pool, so even a not-very-big query can crash. Moreover, since that is a constant, I think that if an error is not reproducible with its current value, then it is not an issue. Otherwise we can come up with many "theoretical" issues which in practice do not affect us; there are limits anyway which we are only making harder to reach, but cannot avoid.

@viirya (Member) commented Mar 9, 2018

This issue looks real, and we will hit it in some situations. It is currently hard to find a reproducible test case, but this change prevents the error, if I am not missing anything. Do we have any reason to block it?

@mgaido91 (Contributor) commented Mar 9, 2018

I don't want to block it; I just thought that adding a new test case for such situations would be better, also to prevent regressions in the future. Since the next release is not coming soon, there is no hurry, and we have time to try to find a good test case. But if you all think it is not necessary, we can merge it as is. The change itself of course LGTM, as I said on the JIRA.

@dvogelbacher (Contributor) commented Mar 9, 2018

I was able to reproduce just now without changing the value of the constant (i.e., with unmodified code from master):

 ➜  spark git:(master) ./bin/spark-shell
18/03/09 11:11:02 WARN Utils: Your hostname, dvogelbac resolves to a loopback address: 127.0.0.1; using 10.111.11.111 instead (on interface en0)
18/03/09 11:11:02 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
18/03/09 11:11:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Spark context Web UI available at http://10.111.11.111:4040
Spark context available as 'sc' (master = local[*], app id = local-1520593867110).
Spark session available as 'spark'.
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 2.4.0-SNAPSHOT
      /_/

Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_121)
Type in expressions to have them evaluated.
Type :help for more information.

scala> spark.conf.set("spark.sql.shuffle.partitions", 1)

scala> val df_pet_age = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name")
df_pet_age: org.apache.spark.sql.DataFrame = [age: int, name: string]

scala> df_pet_age.groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age")).limit(1).show()
[Stage 1:>                                                          (0 + 1) / 1]18/03/09 11:11:21 ERROR Executor: Exception in task 0.0 in stage 1.0 (TID 3)
java.lang.IllegalAccessError: tried to access method org.apache.spark.sql.execution.BufferedRowIterator.shouldStop()Z from class org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2$agg_NestedClass
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2$agg_NestedClass.agg_doAggregateWithKeys2$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2$agg_NestedClass.agg_doAggregateWithKeys1$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2$agg_NestedClass.agg_doAggregateWithKeys$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:616)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
	at java.lang.Thread.run(Thread.java:745)

@dvogelbacher (Contributor)

Note, though, that executing this also sometimes slays the compiler (and sometimes it doesn't, and instead leads to the error above):

That entry seems to have slain the compiler.  Shall I replay
your session? I can re-run each line except the last one.

@viirya (Member) commented Mar 9, 2018

Yeah, I agree that if we can find an appropriate test case, that'd be better. Let's see if @dvogelbacher's test case works. Thanks!

@kiszk (Member Author) commented Mar 9, 2018

@dvogelbacher Great, I have also been investigating a similar approach (i.e., concatenating agg()). Let me try your case.

}
}

test("SPARK-23598: Avoid compilation error with a lot of aggregation operations") {

Contributor:

It's not a compilation error but a runtime error. Maybe better to say: codegen works for lots of aggregation operations without runtime errors.

test("SPARK-23598: Avoid compilation error with a lot of aggregation operations") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name")
.groupBy("name").agg(avg("age").alias("age")).groupBy("name").agg(avg("age").alias("age"))

Contributor:

can we do this with a loop?
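
For reference, a sketch of how the repeated chain could be written with a loop (assuming the test suite's implicits such as toDF are in scope; the iteration count is illustrative):

  // Build the long aggregation chain with a fold instead of spelling out
  // .groupBy(...).agg(...) dozens of times by hand.
  import org.apache.spark.sql.functions.avg

  val base = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name")
  val df = (1 to 70).foldLeft(base) { (acc, _) =>
    acc.groupBy("name").agg(avg("age").alias("age"))
  }
  df.limit(1).show()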

@SparkQA commented Mar 9, 2018

Test build #88131 has finished for PR 20779 at commit dccae53.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@kiszk kiszk changed the title [SPARK-23598][SQL] Make methods in BufferedRowIterator public to avoid compilation for a large query [SPARK-23598][SQL] Make methods in BufferedRowIterator public to avoid runtime error for a large query Mar 9, 2018
@SparkQA commented Mar 10, 2018

Test build #88145 has finished for PR 20779 at commit 6e45791.

  • This patch fails due to an unknown error code, -9.
  • This patch merges cleanly.
  • This patch adds no public classes.

@viirya (Member) commented Mar 10, 2018

retest this please.

@mgaido91 (Contributor)

LGTM

@dvogelbacher (Contributor)

LGTM as well, thanks for making the PR @kiszk

@SparkQA commented Mar 10, 2018

Test build #88147 has finished for PR 20779 at commit 6e45791.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@mgaido91 (Contributor)

@kiszk the UT error is valid. How did you test it? Any idea about the reason for the OOM?

@kiszk (Member Author) commented Mar 10, 2018

Ah, I increased the heap size (to 4GB) in my environment with IntelliJ.
Should we create a class like #20636?

@mgaido91 (Contributor)

I don't think so. There is an option to change the heap size for test execution, but I am not sure we are allowed to do that, or that it is a good idea. Let's hear others' opinions...

@kiszk (Member Author) commented Mar 10, 2018

Let me reduce the number of iterations. Another option is to revert this change and use the non-loop version, which worked without an exception.

@SparkQA commented Mar 10, 2018

Test build #88154 has finished for PR 20779 at commit 603ce0f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@kiszk (Member Author) commented Mar 13, 2018

ping @hvanhovell

@viirya (Member) left a comment

LGTM

   * Append a row to currentRows.
   */
- protected void append(InternalRow row) {
+ public void append(InternalRow row) {

Member:

nit: Although we added the test, should we also add a short comment saying this is public so that inner classes can also access it?
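
For example, the added note might read something like this (a sketch, not the merged wording; the body follows the existing doc comment, which says the method appends to currentRows):

  /**
   * Append a row to currentRows. This is public (rather than protected) so
   * that the nested classes produced by code splitting, which do not extend
   * BufferedRowIterator, can also call it.
   */
  public void append(InternalRow row) {
    currentRows.add(row);
  }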

@SparkQA commented Mar 13, 2018

Test build #88212 has finished for PR 20779 at commit 8fb5df0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hvanhovell (Contributor) left a comment

LGTM - merging to master and 2.3

@asfgit asfgit closed this in 1098933 Mar 13, 2018
asfgit pushed a commit that referenced this pull request Mar 15, 2018
…d runtime error for a large query
/* 341 */             double agg_value29 = -1.0;
/* 342 */             if (!false) {
/* 343 */               agg_value29 = (double) 0;
/* 344 */             }
/* 345 */             if (!agg_isNull27) {
/* 346 */               agg_agg_isNull25 = false;
/* 347 */               agg_value27 = agg_value29;
/* 348 */               continue;
/* 349 */             }
/* 350 */
/* 351 */           } while (false);
/* 352 */
/* 353 */           agg_isNull23 = false; // resultCode could change nullability.
/* 354 */           agg_value25 = agg_value26 + agg_value27;
/* 355 */
/* 356 */         }
/* 357 */         boolean agg_isNull29 = false;
/* 358 */         long agg_value31 = -1L;
/* 359 */         if (!false && agg_isNull21) {
/* 360 */           boolean agg_isNull31 = agg_fastAggBuffer.isNullAt(1);
/* 361 */           long agg_value33 = agg_isNull31 ?
/* 362 */           -1L : (agg_fastAggBuffer.getLong(1));
/* 363 */           agg_isNull29 = agg_isNull31;
/* 364 */           agg_value31 = agg_value33;
/* 365 */         } else {
/* 366 */           boolean agg_isNull32 = true;
/* 367 */           long agg_value34 = -1L;
/* 368 */
/* 369 */           boolean agg_isNull33 = agg_fastAggBuffer.isNullAt(1);
/* 370 */           long agg_value35 = agg_isNull33 ?
/* 371 */           -1L : (agg_fastAggBuffer.getLong(1));
/* 372 */           if (!agg_isNull33) {
/* 373 */             agg_isNull32 = false; // resultCode could change nullability.
/* 374 */             agg_value34 = agg_value35 + 1L;
/* 375 */
/* 376 */           }
/* 377 */           agg_isNull29 = agg_isNull32;
/* 378 */           agg_value31 = agg_value34;
/* 379 */         }
/* 380 */         // update fast row
/* 381 */         if (!agg_isNull23) {
/* 382 */           agg_fastAggBuffer.setDouble(0, agg_value25);
/* 383 */         } else {
/* 384 */           agg_fastAggBuffer.setNullAt(0);
/* 385 */         }
/* 386 */
/* 387 */         if (!agg_isNull29) {
/* 388 */           agg_fastAggBuffer.setLong(1, agg_value31);
/* 389 */         } else {
/* 390 */           agg_fastAggBuffer.setNullAt(1);
/* 391 */         }
/* 392 */       } else {
/* 393 */         // common sub-expressions
/* 394 */         boolean agg_isNull7 = false;
/* 395 */         long agg_value9 = -1L;
/* 396 */         if (!false) {
/* 397 */           agg_value9 = (long) agg_expr_0;
/* 398 */         }
/* 399 */         // evaluate aggregate function
/* 400 */         boolean agg_isNull9 = true;
/* 401 */         double agg_value11 = -1.0;
/* 402 */
/* 403 */         boolean agg_isNull10 = agg_unsafeRowAggBuffer.isNullAt(0);
/* 404 */         double agg_value12 = agg_isNull10 ?
/* 405 */         -1.0 : (agg_unsafeRowAggBuffer.getDouble(0));
/* 406 */         if (!agg_isNull10) {
/* 407 */           agg_agg_isNull11 = true;
/* 408 */           double agg_value13 = -1.0;
/* 409 */           do {
/* 410 */             boolean agg_isNull12 = agg_isNull7;
/* 411 */             double agg_value14 = -1.0;
/* 412 */             if (!agg_isNull7) {
/* 413 */               agg_value14 = (double) agg_value9;
/* 414 */             }
/* 415 */             if (!agg_isNull12) {
/* 416 */               agg_agg_isNull11 = false;
/* 417 */               agg_value13 = agg_value14;
/* 418 */               continue;
/* 419 */             }
/* 420 */
/* 421 */             boolean agg_isNull13 = false;
/* 422 */             double agg_value15 = -1.0;
/* 423 */             if (!false) {
/* 424 */               agg_value15 = (double) 0;
/* 425 */             }
/* 426 */             if (!agg_isNull13) {
/* 427 */               agg_agg_isNull11 = false;
/* 428 */               agg_value13 = agg_value15;
/* 429 */               continue;
/* 430 */             }
/* 431 */
/* 432 */           } while (false);
/* 433 */
/* 434 */           agg_isNull9 = false; // resultCode could change nullability.
/* 435 */           agg_value11 = agg_value12 + agg_value13;
/* 436 */
/* 437 */         }
/* 438 */         boolean agg_isNull15 = false;
/* 439 */         long agg_value17 = -1L;
/* 440 */         if (!false && agg_isNull7) {
/* 441 */           boolean agg_isNull17 = agg_unsafeRowAggBuffer.isNullAt(1);
/* 442 */           long agg_value19 = agg_isNull17 ?
/* 443 */           -1L : (agg_unsafeRowAggBuffer.getLong(1));
/* 444 */           agg_isNull15 = agg_isNull17;
/* 445 */           agg_value17 = agg_value19;
/* 446 */         } else {
/* 447 */           boolean agg_isNull18 = true;
/* 448 */           long agg_value20 = -1L;
/* 449 */
/* 450 */           boolean agg_isNull19 = agg_unsafeRowAggBuffer.isNullAt(1);
/* 451 */           long agg_value21 = agg_isNull19 ?
/* 452 */           -1L : (agg_unsafeRowAggBuffer.getLong(1));
/* 453 */           if (!agg_isNull19) {
/* 454 */             agg_isNull18 = false; // resultCode could change nullability.
/* 455 */             agg_value20 = agg_value21 + 1L;
/* 456 */
/* 457 */           }
/* 458 */           agg_isNull15 = agg_isNull18;
/* 459 */           agg_value17 = agg_value20;
/* 460 */         }
/* 461 */         // update unsafe row buffer
/* 462 */         if (!agg_isNull9) {
/* 463 */           agg_unsafeRowAggBuffer.setDouble(0, agg_value11);
/* 464 */         } else {
/* 465 */           agg_unsafeRowAggBuffer.setNullAt(0);
/* 466 */         }
/* 467 */
/* 468 */         if (!agg_isNull15) {
/* 469 */           agg_unsafeRowAggBuffer.setLong(1, agg_value17);
/* 470 */         } else {
/* 471 */           agg_unsafeRowAggBuffer.setNullAt(1);
/* 472 */         }
/* 473 */
/* 474 */       }
/* 475 */
/* 476 */     }
/* 477 */
/* 478 */   }
/* 479 */
/* 480 */ }
```
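
For context, the fix itself is a small visibility change in `BufferedRowIterator`: the members that the generated split classes call become `public`. Below is a minimal sketch of that shape; the field and method bodies are simplified placeholders (e.g. `Object` instead of `InternalRow`), not Spark's actual implementation.

```
import java.util.LinkedList;

// Simplified sketch of BufferedRowIterator after the fix. The methods that
// generated split classes invoke are public rather than protected, so the
// JVM's access check passes even when the caller is a nested split class
// that is not a subclass of BufferedRowIterator.
public abstract class BufferedRowIterator {
  protected LinkedList<Object> currentRows = new LinkedList<>();

  public boolean shouldStop() {      // was: protected
    return !currentRows.isEmpty();
  }

  public boolean stopEarly() {       // was: protected
    return false;                    // placeholder; subclasses may override
  }

  public void append(Object row) {   // was: protected
    currentRows.add(row);
  }
}
```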

## How was this patch tested?

Added a unit test to `WholeStageCodegenSuite`

Author: Kazuaki Ishizaki <[email protected]>

Closes #20779 from kiszk/SPARK-23598.

(cherry picked from commit 1098933)
Signed-off-by: Herman van Hovell <[email protected]>
@rednaxelafx
Contributor

Sorry for the late comment. This PR itself is LGTM. I'd just like to make some side comments on why this bug is happening.

Janino uses a peculiar encoding to implement the bridge methods that extend accessibility from an inner class to members of its enclosing class. Here we're actually hitting a Janino bug: it misses creating those bridge methods on the enclosing class (GeneratedClass...) for protected members inherited from a base class (BufferedRowIterator).
I've seen this bug in Janino before, and I plan to fix it there soon. Once it's fixed, we can safely go back to using protected members such as append and stopEarly in nested classes within GeneratedClass... Would anybody be interested in switching these methods back to protected once the fix lands in Janino and Spark bumps its Janino dependency to that version?

Thanks!
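
For readers less familiar with the access rules involved, here is a minimal sketch of the class shape that trips this up. The names are illustrative, not taken from the generated code; `Base` stands in for `BufferedRowIterator`, and the sketch assumes `Base` lives in a different package than `Generated`, as `BufferedRowIterator` does relative to `GeneratedClass`.

```
// Base plays the role of BufferedRowIterator, in a different package.
abstract class Base {
  protected boolean shouldStop() { return false; }
}

// Generated plays the role of GeneratedIteratorForCodegenStage1.
final class Generated extends Base {
  // Split plays the role of a split class such as agg_NestedClass1.
  private class Split {
    void run() {
      // The language allows this call, but the JVM's protected-access check
      // requires the *calling class* to be a subclass of Base (or in its
      // package). Split is neither, so javac routes the call through a
      // synthetic accessor it adds to Generated. Janino missed emitting that
      // accessor for members inherited from a base class, so the raw call
      // fails at runtime with java.lang.IllegalAccessError.
      while (!shouldStop()) {
        break; // ... process rows ...
      }
    }
  }
}
```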

mstewart141 pushed a commit to mstewart141/spark that referenced this pull request Mar 24, 2018
…d runtime error for a large query

gitfy pushed a commit to gitfy/spark that referenced this pull request May 21, 2018
…d runtime error for a large query

/* 436 */
/* 437 */         }
/* 438 */         boolean agg_isNull15 = false;
/* 439 */         long agg_value17 = -1L;
/* 440 */         if (!false && agg_isNull7) {
/* 441 */           boolean agg_isNull17 = agg_unsafeRowAggBuffer.isNullAt(1);
/* 442 */           long agg_value19 = agg_isNull17 ?
/* 443 */           -1L : (agg_unsafeRowAggBuffer.getLong(1));
/* 444 */           agg_isNull15 = agg_isNull17;
/* 445 */           agg_value17 = agg_value19;
/* 446 */         } else {
/* 447 */           boolean agg_isNull18 = true;
/* 448 */           long agg_value20 = -1L;
/* 449 */
/* 450 */           boolean agg_isNull19 = agg_unsafeRowAggBuffer.isNullAt(1);
/* 451 */           long agg_value21 = agg_isNull19 ?
/* 452 */           -1L : (agg_unsafeRowAggBuffer.getLong(1));
/* 453 */           if (!agg_isNull19) {
/* 454 */             agg_isNull18 = false; // resultCode could change nullability.
/* 455 */             agg_value20 = agg_value21 + 1L;
/* 456 */
/* 457 */           }
/* 458 */           agg_isNull15 = agg_isNull18;
/* 459 */           agg_value17 = agg_value20;
/* 460 */         }
/* 461 */         // update unsafe row buffer
/* 462 */         if (!agg_isNull9) {
/* 463 */           agg_unsafeRowAggBuffer.setDouble(0, agg_value11);
/* 464 */         } else {
/* 465 */           agg_unsafeRowAggBuffer.setNullAt(0);
/* 466 */         }
/* 467 */
/* 468 */         if (!agg_isNull15) {
/* 469 */           agg_unsafeRowAggBuffer.setLong(1, agg_value17);
/* 470 */         } else {
/* 471 */           agg_unsafeRowAggBuffer.setNullAt(1);
/* 472 */         }
/* 473 */
/* 474 */       }
/* 475 */
/* 476 */     }
/* 477 */
/* 478 */   }
/* 479 */
/* 480 */ }
```
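
In the generated code above, `agg_doAggregateWithKeys()` (lines 195 and 203) and `agg_doAggregateWithKeysOutput()` (line 250) live in nested classes that do not extend `BufferedRowIterator`, yet they call `stopEarly()`, `shouldStop()`, and `append()` on the enclosing iterator. A minimal sketch of what the fix amounts to, with placeholder bodies and the exact member list as assumptions:

```
// Sketch only: illustrates the protected -> public change on
// org.apache.spark.sql.execution.BufferedRowIterator. Bodies are placeholders;
// the real signatures in Spark may differ.
import org.apache.spark.sql.catalyst.InternalRow;

public abstract class BufferedRowIterator {
  // Previously protected. Split nested classes of the generated iterator are
  // not subclasses of BufferedRowIterator, so the JVM rejected their calls
  // with java.lang.IllegalAccessError; making the members public fixes this.
  public boolean shouldStop() {
    return false; // placeholder
  }

  public boolean stopEarly() {
    return false; // placeholder
  }

  public void append(InternalRow row) {
    // placeholder: queue the produced row for hasNext()/next()
  }
}
```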

## How was this patch tested?

Added a unit test to `WholeStageCodegenSuite`.
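
For reference, a self-contained Java rendering of the reproducer above (the committed test itself lives in the Scala suite; the class and bean names here are illustrative, and forcing the split is assumed to be done by setting `CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD` to `-1` as described earlier):

```
// Sketch only: standalone version of the attached reproducer.
import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class Spark23598Repro {
  public static class PetAge implements Serializable {
    private final int age;
    private final String name;
    public PetAge(int age, String name) { this.age = age; this.name = name; }
    public int getAge() { return age; }
    public String getName() { return name; }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[1]").appName("SPARK-23598").getOrCreate();
    Dataset<Row> df = spark.createDataFrame(
        Arrays.asList(new PetAge(8, "bat"), new PetAge(15, "mouse"), new PetAge(5, "horse")),
        PetAge.class);
    // Before the fix, this aggregation failed at runtime with
    // java.lang.IllegalAccessError thrown from a split nested class.
    df.groupBy("name").avg("age").show();
    spark.stop();
  }
}
```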

Author: Kazuaki Ishizaki <[email protected]>

Closes apache#20779 from kiszk/SPARK-23598.
peter-toth pushed a commit to peter-toth/spark that referenced this pull request Oct 6, 2018
…d runtime error for a large query


(cherry picked from commit 1098933)
Signed-off-by: Herman van Hovell <[email protected]>
otterc pushed a commit to linkedin/spark that referenced this pull request Mar 22, 2023
…r public to avoid runtime error for a large query


(cherry picked from commit 1098933)

RB=1737161
BUG=LIHADOOP-48021
G=superfriends-reviewers
R=mshen,zolin,latang,yezhou,fli
A=fli