Skip to content

Commit ec74328

Browse files
committed
CR
1 parent 927a7c8 commit ec74328

File tree

2 files changed

+79
-11
lines changed

2 files changed

+79
-11
lines changed
Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,20 @@ package org.apache.spark.sql.execution.aggregate
2020
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2121
import org.apache.spark.sql.types.StructType
2222

23-
class TungstenAggregateHashMap(
23+
/**
24+
* This is a helper class to generate an append-only single-key/single-value aggregate hash
25+
* map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates
26+
* (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in
27+
* TungstenAggregate to speed up aggregates with grouping keys.
28+
*
29+
* It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the
30+
* key-value pairs. The index lookups in the array rely on linear probing (with a small number of
31+
* maximum tries) and use an inexpensive hash function which makes it really efficient for a
32+
* majority of lookups. However, using linear probing and an inexpensive hash function also makes it
33+
* less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even
34+
* for certain distributions of keys) and requires us to fall back on the latter for correctness.
35+
*/
36+
class ColumnarAggMapCodeGenerator(
2437
ctx: CodegenContext,
2538
generatedClassName: String,
2639
groupingKeySchema: StructType,
@@ -43,7 +56,7 @@ class TungstenAggregateHashMap(
4356
""".stripMargin
4457
}
4558

46-
def initializeAggregateHashMap(): String = {
59+
private def initializeAggregateHashMap(): String = {
4760
val generatedSchema: String =
4861
s"""
4962
|new org.apache.spark.sql.types.StructType()
@@ -76,16 +89,38 @@ class TungstenAggregateHashMap(
7689
""".stripMargin
7790
}
7891

79-
def generateHashFunction(): String = {
92+
/**
93+
* Generates a method that computes a hash, currently by xor-ing all individual group-by keys. For
94+
* instance, if we have 2 long group-by keys, the generated function would be of the form:
95+
*
96+
* {{{
97+
* private long hash(long agg_key, long agg_key1) {
98+
* return agg_key ^ agg_key1;
99+
* }
100+
* }}}
101+
*/
102+
private def generateHashFunction(): String = {
80103
s"""
81-
|// TODO: Improve this Hash Function
104+
|// TODO: Improve this hash function
82105
|private long hash($groupingKeySignature) {
83106
| return ${groupingKeys.map(_._2).mkString(" ^ ")};
84107
|}
85108
""".stripMargin
86109
}
87110

88-
def generateEquals(): String = {
111+
/**
112+
* Generates a method that returns true if the group-by keys exist at a given index in the
113+
* associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
114+
* have 2 long group-by keys, the generated function would be of the form:
115+
*
116+
* {{{
117+
* private boolean equals(int idx, long agg_key, long agg_key1) {
118+
* return batch.column(0).getLong(buckets[idx]) == agg_key &&
119+
* batch.column(1).getLong(buckets[idx]) == agg_key1;
120+
* }
121+
* }}}
122+
*/
123+
private def generateEquals(): String = {
89124
s"""
90125
|private boolean equals(int idx, $groupingKeySignature) {
91126
| return ${groupingKeys.zipWithIndex.map(k =>
@@ -94,10 +129,43 @@ class TungstenAggregateHashMap(
94129
""".stripMargin
95130
}
96131

97-
def generateFindOrInsert(): String = {
132+
/**
133+
* Generates a method that returns a mutable
134+
* [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the
135+
* aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
136+
* generated method adds the corresponding row in the associated
137+
* [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
138+
* have 2 long group-by keys, the generated function would be of the form:
139+
*
140+
* {{{
141+
* public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(
142+
* long agg_key, long agg_key1) {
143+
* long h = hash(agg_key, agg_key1);
144+
* int step = 0;
145+
* int idx = (int) h & (numBuckets - 1);
146+
* while (step < maxSteps) {
147+
* // Return the row if the slot is either empty (inserting the key first) or already holds the key
148+
* if (buckets[idx] == -1) {
149+
* batch.column(0).putLong(numRows, agg_key);
150+
* batch.column(1).putLong(numRows, agg_key1);
151+
* batch.column(2).putLong(numRows, 0);
152+
* buckets[idx] = numRows++;
153+
* return batch.getRow(buckets[idx]);
154+
* } else if (equals(idx, agg_key, agg_key1)) {
155+
* return batch.getRow(buckets[idx]);
156+
* }
157+
* idx = (idx + 1) & (numBuckets - 1);
158+
* step++;
159+
* }
160+
* // Didn't find it
161+
* return null;
162+
* }
163+
* }}}
164+
*/
165+
private def generateFindOrInsert(): String = {
98166
s"""
99167
|public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
100-
groupingKeySignature}) {
168+
groupingKeySignature}) {
101169
| long h = hash(${groupingKeys.map(_._2).mkString(", ")});
102170
| int step = 0;
103171
| int idx = (int) h & (numBuckets - 1);
@@ -117,8 +185,8 @@ class TungstenAggregateHashMap(
117185
| idx = (idx + 1) & (numBuckets - 1);
118186
| step++;
119187
| }
120-
|// Didn't find it
121-
|return null;
188+
| // Didn't find it
189+
| return null;
122190
|}
123191
""".stripMargin
124192
}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,8 @@ case class TungstenAggregate(
443443
(groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType)
444444
val aggregateHashMapTerm = ctx.freshName("aggregateHashMap")
445445
val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap")
446-
val aggregateHashMapGenerator =
447-
new TungstenAggregateHashMap(ctx, aggregateHashMapClassName, groupingKeySchema, bufferSchema)
446+
val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName,
447+
groupingKeySchema, bufferSchema)
448448
if (isAggregateHashMapEnabled && isAggregateHashMapSupported) {
449449
ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm,
450450
s"$aggregateHashMapTerm = new $aggregateHashMapClassName();")

0 commit comments

Comments
 (0)