@@ -20,7 +20,20 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.types.StructType
 
-class TungstenAggregateHashMap(
+/**
+ * This is a helper object to generate an append-only single-key/single-value aggregate hash
+ * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates
+ * (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in
+ * TungstenAggregate to speed up aggregates with keys.
+ *
+ * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the
+ * key-value pairs. The index lookups in the array rely on linear probing (with a small maximum
+ * number of tries) and use an inexpensive hash function, which makes it very efficient for the
+ * majority of lookups. However, the linear probing and inexpensive hash function also make it
+ * less robust than the `BytesToBytesMap` (especially for a large number of keys or for certain
+ * distributions of keys), so we must fall back on the latter for correctness.
+ */
+class ColumnarAggMapCodeGenerator(
     ctx: CodegenContext,
     generatedClassName: String,
     groupingKeySchema: StructType,
@@ -43,7 +56,7 @@ class TungstenAggregateHashMap(
     """.stripMargin
   }
 
-  def initializeAggregateHashMap(): String = {
+  private def initializeAggregateHashMap(): String = {
     val generatedSchema: String =
       s"""
          |new org.apache.spark.sql.types.StructType()
@@ -76,16 +89,38 @@ class TungstenAggregateHashMap(
     """.stripMargin
   }
 
-  def generateHashFunction(): String = {
+  /**
+   * Generates a method that currently computes a hash by xor-ing all individual group-by keys.
+   * For instance, if we have 2 long group-by keys, the generated function would be of the form:
+   *
+   * {{{
+   * private long hash(long agg_key, long agg_key1) {
+   *   return agg_key ^ agg_key1;
+   * }
+   * }}}
+   */
+  private def generateHashFunction(): String = {
     s"""
-       |// TODO: Improve this Hash Function
+       |// TODO: Improve this hash function
        |private long hash($groupingKeySignature) {
        |  return ${groupingKeys.map(_._2).mkString(" ^ ")};
        |}
     """.stripMargin
   }
 
-  def generateEquals(): String = {
+  /**
+   * Generates a method that returns true if the group-by keys exist at a given index in the
+   * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+   * have 2 long group-by keys, the generated function would be of the form:
+   *
+   * {{{
+   * private boolean equals(int idx, long agg_key, long agg_key1) {
+   *   return batch.column(0).getLong(buckets[idx]) == agg_key &&
+   *     batch.column(1).getLong(buckets[idx]) == agg_key1;
+   * }
+   * }}}
+   */
+  private def generateEquals(): String = {
     s"""
        |private boolean equals(int idx, $groupingKeySignature) {
        |  return ${groupingKeys.zipWithIndex.map(k =>
@@ -94,10 +129,43 @@ class TungstenAggregateHashMap(
     """.stripMargin
   }
 
-  def generateFindOrInsert(): String = {
+  /**
+   * Generates a method that returns a mutable
+   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the
+   * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
+   * generated method adds the corresponding row in the associated
+   * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+   * have 2 long group-by keys, the generated function would be of the form:
+   *
+   * {{{
+   * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(
+   *     long agg_key, long agg_key1) {
+   *   long h = hash(agg_key, agg_key1);
+   *   int step = 0;
+   *   int idx = (int) h & (numBuckets - 1);
+   *   while (step < maxSteps) {
+   *     // Return bucket index if it's either an empty slot or already contains the key
+   *     if (buckets[idx] == -1) {
+   *       batch.column(0).putLong(numRows, agg_key);
+   *       batch.column(1).putLong(numRows, agg_key1);
+   *       batch.column(2).putLong(numRows, 0);
+   *       buckets[idx] = numRows++;
+   *       return batch.getRow(buckets[idx]);
+   *     } else if (equals(idx, agg_key, agg_key1)) {
+   *       return batch.getRow(buckets[idx]);
+   *     }
+   *     idx = (idx + 1) & (numBuckets - 1);
+   *     step++;
+   *   }
+   *   // Didn't find it
+   *   return null;
+   * }
+   * }}}
+   */
+  private def generateFindOrInsert(): String = {
     s"""
        |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
-      groupingKeySignature}) {
+        groupingKeySignature}) {
        |  long h = hash(${groupingKeys.map(_._2).mkString(", ")});
        |  int step = 0;
        |  int idx = (int) h & (numBuckets - 1);
@@ -117,8 +185,8 @@ class TungstenAggregateHashMap(
        |    idx = (idx + 1) & (numBuckets - 1);
        |    step++;
        |  }
-       |// Didn't find it
-       |return null;
+       |  // Didn't find it
+       |  return null;
        |}
     """.stripMargin
   }
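
For readers who want to see the generated map's behavior end to end, here is a minimal standalone Java sketch that collapses the three generated methods (hash, equals, findOrInsert) into one class for a single long key and a single long aggregate value. It is illustrative only: the names `SingleKeyAggMap`, `NUM_BUCKETS`, and `MAX_STEPS` are invented for this sketch, and the real generated code stores its key-value pairs in a `ColumnarBatch` rather than in flat arrays.

```java
// Hypothetical standalone model of the generated hash map; not Spark code.
import java.util.Arrays;

public class SingleKeyAggMap {
  private static final int NUM_BUCKETS = 1 << 16; // must be a power of 2
  private static final int MAX_STEPS = 2;         // bounded linear probing

  private final long[] keys = new long[NUM_BUCKETS];   // plays "column 0"
  private final long[] values = new long[NUM_BUCKETS]; // plays "column 1"
  private final int[] buckets = new int[NUM_BUCKETS];  // bucket -> row id
  private int numRows = 0;

  public SingleKeyAggMap() {
    Arrays.fill(buckets, -1); // -1 marks an empty bucket
  }

  // Inexpensive hash: with a single group-by key there is nothing to xor.
  private long hash(long aggKey) {
    return aggKey;
  }

  private boolean equals(int idx, long aggKey) {
    return keys[buckets[idx]] == aggKey;
  }

  // Returns the row id for the key, appending a fresh row (value 0) on a
  // miss; returns -1 after MAX_STEPS failed probes, signalling that the
  // caller must fall back to the slower-but-robust BytesToBytesMap.
  public int findOrInsert(long aggKey) {
    long h = hash(aggKey);
    int idx = (int) h & (NUM_BUCKETS - 1); // cheap modulo via masking
    for (int step = 0; step < MAX_STEPS; step++) {
      if (buckets[idx] == -1) {            // empty slot: append a new row
        keys[numRows] = aggKey;
        values[numRows] = 0;
        buckets[idx] = numRows++;
        return buckets[idx];
      } else if (equals(idx, aggKey)) {    // existing row for this key
        return buckets[idx];
      }
      idx = (idx + 1) & (NUM_BUCKETS - 1); // linear probe to the next bucket
    }
    return -1; // didn't find it
  }
}
```

Capping the probe count at `MAX_STEPS` is what keeps the fast path constant-time: a pathological key distribution simply misses and falls through to `BytesToBytesMap`, which is why the class comment stresses that the fallback is required for correctness.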