
Commit f8c9bec

sameeragarwal authored and yhuai committed
[SPARK-14394][SQL] Generate AggregateHashMap class for LongTypes during TungstenAggregate codegen
## What changes were proposed in this pull request?

This PR adds support for generating the `AggregateHashMap` class in `TungstenAggregate` if the aggregate group-by keys/values are of `LongType`. Note that the generated aggregate map is not actually used yet.

NB: This currently only supports `LongType` keys/values (please see `isAggregateHashMapSupported` in `TungstenAggregate`) and will be generalized to other data types in a subsequent PR.

## How was this patch tested?

Manually inspected the generated code. This is what the generated map looks like for 2 keys:

```java
public class agg_GeneratedAggregateHashMap {
  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
  private int[] buckets;
  private int numBuckets;
  private int maxSteps;
  private int numRows = 0;
  private org.apache.spark.sql.types.StructType schema =
    new org.apache.spark.sql.types.StructType()
      .add("k1", org.apache.spark.sql.types.DataTypes.LongType)
      .add("k2", org.apache.spark.sql.types.DataTypes.LongType)
      .add("sum", org.apache.spark.sql.types.DataTypes.LongType);

  public agg_GeneratedAggregateHashMap(int capacity, double loadFactor, int maxSteps) {
    assert (capacity > 0 && ((capacity & (capacity - 1)) == 0));
    this.maxSteps = maxSteps;
    numBuckets = (int) (capacity / loadFactor);
    batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
      org.apache.spark.memory.MemoryMode.ON_HEAP, capacity);
    buckets = new int[numBuckets];
    java.util.Arrays.fill(buckets, -1);
  }

  public agg_GeneratedAggregateHashMap() {
    new agg_GeneratedAggregateHashMap(1 << 16, 0.25, 5);
  }

  public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(long agg_key, long agg_key1) {
    long h = hash(agg_key, agg_key1);
    int step = 0;
    int idx = (int) h & (numBuckets - 1);
    while (step < maxSteps) {
      // Return bucket index if it's either an empty slot or already contains the key
      if (buckets[idx] == -1) {
        batch.column(0).putLong(numRows, agg_key);
        batch.column(1).putLong(numRows, agg_key1);
        batch.column(2).putLong(numRows, 0);
        buckets[idx] = numRows++;
        return batch.getRow(buckets[idx]);
      } else if (equals(idx, agg_key, agg_key1)) {
        return batch.getRow(buckets[idx]);
      }
      idx = (idx + 1) & (numBuckets - 1);
      step++;
    }
    // Didn't find it
    return null;
  }

  private boolean equals(int idx, long agg_key, long agg_key1) {
    return batch.column(0).getLong(buckets[idx]) == agg_key && batch.column(1).getLong(buckets[idx]) == agg_key1;
  }

  // TODO: Improve this Hash Function
  private long hash(long agg_key, long agg_key1) {
    return agg_key ^ agg_key1;
  }

}
```

Author: Sameer Agarwal <[email protected]>

Closes #12161 from sameeragarwal/tungsten-aggregate.
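For orientation before the diffs: the generator is a plain Scala class that emits Java source as a String, and `TungstenAggregate` splices the result into the `doAggregateWithKeys` function it already generates. A minimal sketch of that wiring, assuming a `CodegenContext` named `ctx` and the grouping/buffer schemas are in scope (`generatedMapSource` is an illustrative name; the rest match the TungstenAggregate hunk further down):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.execution.aggregate.ColumnarAggMapCodeGenerator

// Sketch, not new API: mirrors the wiring added to TungstenAggregate below.
val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap")
val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(
  ctx, aggregateHashMapClassName, groupingKeySchema, bufferSchema)
// generate() returns the Java source for the map class as a String, which is
// prepended to the body of the generated doAggregateWithKeys function.
val generatedMapSource: String = aggregateHashMapGenerator.generate()
```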
1 parent 0275753, commit f8c9bec

File tree: 2 files changed, +210 -3 lines
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala

Lines changed: 193 additions & 0 deletions (new file)
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.StructType

/**
 * This is a helper object to generate an append-only single-key/single-value aggregate hash
 * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates
 * (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in
 * TungstenAggregate to speed up aggregates w/ key.
 *
 * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the
 * key-value pairs. The index lookups in the array rely on linear probing (with a small number of
 * maximum tries) and use an inexpensive hash function which makes it really efficient for a
 * majority of lookups. However, using linear probing and an inexpensive hash function also makes
 * it less robust than the `BytesToBytesMap` (especially for a large number of keys or certain
 * distributions of keys) and requires us to fall back on the latter for correctness.
 */
class ColumnarAggMapCodeGenerator(
    ctx: CodegenContext,
    generatedClassName: String,
    groupingKeySchema: StructType,
    bufferSchema: StructType) {
  val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key")))
  val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value")))
  val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ")
```
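To make those three fields concrete: each grouping key and buffer value is tracked as a (Java type name, fresh variable name) pair, and `groupingKeySignature` joins the pairs into the parameter list reused by every generated method. A sketch of the values for the two-`LongType`-key example from the commit message (variable names are illustrative; `freshName` suffixes vary):

```scala
// For LongType, dataType.typeName is "long"; freshName("key") yields agg_key, agg_key1, ...
val groupingKeys = Seq(("long", "agg_key"), ("long", "agg_key1"))
val bufferValues = Seq(("long", "agg_value"))

// Equivalent to the productIterator/mkString line above:
val groupingKeySignature =
  groupingKeys.map { case (tpe, name) => s"$tpe $name" }.mkString(", ")
assert(groupingKeySignature == "long agg_key, long agg_key1")
```

`generate()` then stitches the four generated pieces (fields and constructors, `findOrInsert`, `equals`, `hash`) into one Java class: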
```scala
  def generate(): String = {
    s"""
       |public class $generatedClassName {
       |${initializeAggregateHashMap()}
       |
       |${generateFindOrInsert()}
       |
       |${generateEquals()}
       |
       |${generateHashFunction()}
       |}
     """.stripMargin
  }

  private def initializeAggregateHashMap(): String = {
    val generatedSchema: String =
      s"""
         |new org.apache.spark.sql.types.StructType()
         |${(groupingKeySchema ++ bufferSchema).map(key =>
              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
            .mkString("\n")};
       """.stripMargin

    s"""
       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
       |  private int[] buckets;
       |  private int numBuckets;
       |  private int maxSteps;
       |  private int numRows = 0;
       |  private org.apache.spark.sql.types.StructType schema = $generatedSchema
       |
       |  public $generatedClassName(int capacity, double loadFactor, int maxSteps) {
       |    assert (capacity > 0 && ((capacity & (capacity - 1)) == 0));
       |    this.maxSteps = maxSteps;
       |    numBuckets = (int) (capacity / loadFactor);
       |    batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
       |      org.apache.spark.memory.MemoryMode.ON_HEAP, capacity);
       |    buckets = new int[numBuckets];
       |    java.util.Arrays.fill(buckets, -1);
       |  }
       |
       |  public $generatedClassName() {
       |    new $generatedClassName(1 << 16, 0.25, 5);
       |  }
     """.stripMargin
  }
```
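Two details in the emitted initializer are worth spelling out. First, the sizing: `numBuckets = capacity / loadFactor` stays a power of two whenever `capacity` is one (given load factors like the 0.25 used here), which is what lets the probe code mask with `numBuckets - 1` instead of paying for a modulo. Second, the zero-arg constructor as emitted builds and discards a fresh instance rather than delegating with `this(1 << 16, 0.25, 5)`, so only the parameterized constructor actually initializes the map. A quick check of the default sizing arithmetic:

```scala
// Defaults from the zero-arg constructor: capacity = 1 << 16, loadFactor = 0.25, maxSteps = 5.
val capacity = 1 << 16                   // 65536 rows in the backing ColumnarBatch
val numBuckets = (capacity / 0.25).toInt // 262144 == 1 << 18, still a power of two
val mask = numBuckets - 1                // so `(int) h & mask` is a cheap modulo
```

Next comes the hash function used to pick the starting bucket: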
```scala
  /**
   * Generates a method that currently computes a hash by xor-ing all individual group-by keys.
   * For instance, if we have 2 long group-by keys, the generated function would be of the form:
   *
   * {{{
   * private long hash(long agg_key, long agg_key1) {
   *   return agg_key ^ agg_key1;
   * }
   * }}}
   */
  private def generateHashFunction(): String = {
    s"""
       |// TODO: Improve this hash function
       |private long hash($groupingKeySignature) {
       |  return ${groupingKeys.map(_._2).mkString(" ^ ")};
       |}
     """.stripMargin
  }
```
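The TODO is warranted: a plain xor of the keys collides on very natural key patterns, which is exactly the "certain distributions of keys" caveat from the class doc that forces the `BytesToBytesMap` fallback. A small illustration, mirroring the generated function in Scala:

```scala
def hash(aggKey: Long, aggKey1: Long): Long = aggKey ^ aggKey1 // same as the generated Java

assert(hash(1L, 2L) == hash(2L, 1L)) // xor is symmetric, so swapped keys collide
assert(hash(7L, 7L) == 0L)           // any (k, k) pair hashes to 0 ...
assert(hash(42L, 42L) == 0L)         // ... so all equal-key pairs land in one bucket
```

The last two methods generate the key comparison and the probe loop itself: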
```scala
  /**
   * Generates a method that returns true if the group-by keys exist at a given index in the
   * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
   * have 2 long group-by keys, the generated function would be of the form:
   *
   * {{{
   * private boolean equals(int idx, long agg_key, long agg_key1) {
   *   return batch.column(0).getLong(buckets[idx]) == agg_key &&
   *     batch.column(1).getLong(buckets[idx]) == agg_key1;
   * }
   * }}}
   */
  private def generateEquals(): String = {
    s"""
       |private boolean equals(int idx, $groupingKeySignature) {
       |  return ${groupingKeys.zipWithIndex.map(k =>
            s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")};
       |}
     """.stripMargin
  }

  /**
   * Generates a method that returns a mutable
   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the
   * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
   * generated method adds the corresponding row in the associated
   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
   * have 2 long group-by keys, the generated function would be of the form:
   *
   * {{{
   * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(
   *     long agg_key, long agg_key1) {
   *   long h = hash(agg_key, agg_key1);
   *   int step = 0;
   *   int idx = (int) h & (numBuckets - 1);
   *   while (step < maxSteps) {
   *     // Return bucket index if it's either an empty slot or already contains the key
   *     if (buckets[idx] == -1) {
   *       batch.column(0).putLong(numRows, agg_key);
   *       batch.column(1).putLong(numRows, agg_key1);
   *       batch.column(2).putLong(numRows, 0);
   *       buckets[idx] = numRows++;
   *       return batch.getRow(buckets[idx]);
   *     } else if (equals(idx, agg_key, agg_key1)) {
   *       return batch.getRow(buckets[idx]);
   *     }
   *     idx = (idx + 1) & (numBuckets - 1);
   *     step++;
   *   }
   *   // Didn't find it
   *   return null;
   * }
   * }}}
   */
  private def generateFindOrInsert(): String = {
    s"""
       |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
          groupingKeySignature}) {
       |  long h = hash(${groupingKeys.map(_._2).mkString(", ")});
       |  int step = 0;
       |  int idx = (int) h & (numBuckets - 1);
       |  while (step < maxSteps) {
       |    // Return bucket index if it's either an empty slot or already contains the key
       |    if (buckets[idx] == -1) {
       |      ${groupingKeys.zipWithIndex.map(k =>
              s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")}
       |      ${bufferValues.zipWithIndex.map(k =>
              s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);")
              .mkString("\n")}
       |      buckets[idx] = numRows++;
       |      return batch.getRow(buckets[idx]);
       |    } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) {
       |      return batch.getRow(buckets[idx]);
       |    }
       |    idx = (idx + 1) & (numBuckets - 1);
       |    step++;
       |  }
       |  // Didn't find it
       |  return null;
       |}
     """.stripMargin
  }
}
```
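To see the probing scheme in isolation, here is a self-contained Scala model of the generated `findOrInsert` for the two-key case (a toy with plain arrays in place of the `ColumnarBatch`; like the generated code, it does not guard against the batch filling up):

```scala
// Toy model of the generated map: linear probing with at most maxSteps tries,
// -1 in `buckets` marking an empty slot.
class ToyAggMap(capacity: Int = 1 << 16, loadFactor: Double = 0.25, maxSteps: Int = 5) {
  require(capacity > 0 && (capacity & (capacity - 1)) == 0)
  private val numBuckets = (capacity / loadFactor).toInt
  private val buckets = Array.fill(numBuckets)(-1)
  private val k1, k2, sum = new Array[Long](capacity)
  private var numRows = 0

  /** Row index for (key1, key2), inserting if absent; -1 means "caller must fall back". */
  def findOrInsert(key1: Long, key2: Long): Int = {
    var idx = (key1 ^ key2).toInt & (numBuckets - 1) // same xor hash, same mask
    var step = 0
    while (step < maxSteps) {
      if (buckets(idx) == -1) {            // empty slot: claim it for a new row
        k1(numRows) = key1; k2(numRows) = key2; sum(numRows) = 0L
        buckets(idx) = numRows
        numRows += 1
        return buckets(idx)
      } else if (k1(buckets(idx)) == key1 && k2(buckets(idx)) == key2) {
        return buckets(idx)                // existing row for these keys
      }
      idx = (idx + 1) & (numBuckets - 1)   // linear probe to the next bucket
      step += 1
    }
    -1                                     // maxSteps exhausted: fall back
  }
}
```

A caller would do `map.findOrInsert(key1, key2)` and, on a non-negative index, bump `sum` in place; a -1 corresponds to the generated code returning `null`, the case where TungstenAggregate falls back to the `BytesToBytesMap`.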

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 17 additions & 3 deletions
```diff
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{LongType, StructType}
 import org.apache.spark.unsafe.KVIterator
 
 case class TungstenAggregate(
@@ -64,8 +64,8 @@ case class TungstenAggregate(
 
   override def requiredChildDistribution: List[Distribution] = {
     requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
-      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
       case None => UnspecifiedDistribution :: Nil
     }
   }
@@ -437,6 +437,19 @@ case class TungstenAggregate(
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
 
+    // create AggregateHashMap
+    val isAggregateHashMapEnabled: Boolean = false
+    val isAggregateHashMapSupported: Boolean =
+      (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType)
+    val aggregateHashMapTerm = ctx.freshName("aggregateHashMap")
+    val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap")
+    val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName,
+      groupingKeySchema, bufferSchema)
+    if (isAggregateHashMapEnabled && isAggregateHashMapSupported) {
+      ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm,
+        s"$aggregateHashMapTerm = new $aggregateHashMapClassName();")
+    }
+
     // create hashMap
     val thisPlan = ctx.addReferenceObj("plan", this)
     hashMapTerm = ctx.freshName("hashMap")
@@ -452,6 +465,7 @@ case class TungstenAggregate(
     val doAgg = ctx.freshName("doAggregateWithKeys")
     ctx.addNewFunction(doAgg,
       s"""
+        ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
        private void $doAgg() throws java.io.IOException {
         ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
```
