From 864497a5eab63b8db93701977d8996a5169170d6 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Thu, 31 Mar 2016 14:15:34 -0700 Subject: [PATCH 1/5] Generate AggregateHashMap class during TungstenAggregate codegen --- .../aggregate/TungstenAggregate.scala | 20 ++- .../aggregate/TungstenAggregateHashMap.scala | 130 ++++++++++++++++++ 2 files changed, 147 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 60027edc7c396..d4f1e5b084276 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( @@ -64,8 +64,8 @@ case class TungstenAggregate( override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } @@ -437,6 +437,19 @@ case class TungstenAggregate( val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + // create AggregateHashMap + val isAggregateHashMapEnabled: Boolean = false + val isAggregateHashMapSupported: Boolean = + (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) + val aggregateHashMapTerm = ctx.freshName("aggregateHashMap") + val aggregateHashMapClassName = ctx.freshName("AggregateHashMap") + val aggregateHashMapGenerator = + new TungstenAggregateHashMap(aggregateHashMapClassName, groupingKeySchema, bufferSchema) + if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { + ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, + s"$aggregateHashMapTerm = new $aggregateHashMapClassName();") + } + // create hashMap val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") @@ -452,6 +465,7 @@ case class TungstenAggregate( val doAgg = ctx.freshName("doAggregateWithKeys") ctx.addNewFunction(doAgg, s""" + ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala new file mode 100644 index 0000000000000..87ad946db9a96 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.types.StructType + +class TungstenAggregateHashMap( + generatedClassName: String, + groupingKeySchema: StructType, + bufferSchema: StructType) { + val groupingKeys = groupingKeySchema.map(key => (key.dataType.typeName, key.name)) + val bufferValues = bufferSchema.map(key => (key.name, key.dataType.typeName)) + val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ") + + def generate(): String = { + s""" + |public class $generatedClassName { + |${initializeAggregateHashMap()} + | + |${generateFindOrInsert()} + | + |${generateEquals()} + | + |${generateHashFunction()} + |} + """.stripMargin + } + + def initializeAggregateHashMap(): String = { + val generatedSchema: String = + s""" + |new org.apache.spark.sql.types.StructType() + |${(groupingKeySchema ++ bufferSchema).map(key => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""") + .mkString("\n")}; + """.stripMargin + + s""" + | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; + | private int[] buckets; + | private int numBuckets; + | private int maxSteps; + | private int numRows = 0; + | private org.apache.spark.sql.types.StructType schema = $generatedSchema + | + | public $generatedClassName(int capacity, double loadFactor, int maxSteps) { + | assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); + | this.maxSteps = maxSteps; + | numBuckets = (int) (capacity / loadFactor); + | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, + | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); + | buckets = new int[numBuckets]; + | java.util.Arrays.fill(buckets, -1); + | } + | + | public $generatedClassName() { + | new $generatedClassName(1 << 16, 0.25, 5); + | } + """.stripMargin + } + + def generateHashFunction(): String = { + s""" + |// TODO: Improve this Hash Function + |private long hash($groupingKeySignature) { + | return ${groupingKeys.map(_._2).mkString(" & ")}; + |} + """.stripMargin + } + + def generateEquals(): String = { + s""" + |private boolean equals(int idx, $groupingKeySignature) { + | return ${groupingKeys.zipWithIndex.map(key => + s"batch.column(${key._2}).getLong(buckets[idx]) == ${key._1._2}").mkString(" && ")}; + |} + """.stripMargin + } + + def generateFindOrInsert(): String = { + s""" + |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${ + groupingKeySignature}) { + | int idx = find(${groupingKeys.map(_._2).mkString(", ")}); + | if (idx != -1 && buckets[idx] == -1) { + | ${groupingKeys.zipWithIndex.map(key => + s"batch.column(${key._2}).putLong(numRows, ${key._1._2});").mkString("\n")} + | ${bufferValues.zipWithIndex.map(key => + s"batch.column(${groupingKeys.length + key._2}).putLong(numRows, 0);") + .mkString("\n")} + | buckets[idx] = numRows++; + | } + | return batch.getRow(buckets[idx]); + |} + | + |private int find($groupingKeySignature) { + | long h = hash(${groupingKeys.map(_._2).mkString(", ")}); + | int step = 0; + | int idx = (int) h & (numBuckets - 1); + | while (step < maxSteps) { + | // Return bucket index if it's either an empty slot or already contains the key + | if (buckets[idx] == -1) { + | return idx; + | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) { + | return idx; + | } + | idx = (idx + 1) & (numBuckets - 1); + | step++; + | } + |// Didn't find it + |return -1; + |} + """.stripMargin + } +} From 564310f0d4e3bce0b56c1d1e9e304c2432dce64a Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 5 Apr 2016 23:29:49 -0700 Subject: [PATCH 2/5] unique names --- .../spark/sql/execution/aggregate/TungstenAggregate.scala | 2 +- .../sql/execution/aggregate/TungstenAggregateHashMap.scala | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index d4f1e5b084276..9e5535c981639 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -444,7 +444,7 @@ case class TungstenAggregate( val aggregateHashMapTerm = ctx.freshName("aggregateHashMap") val aggregateHashMapClassName = ctx.freshName("AggregateHashMap") val aggregateHashMapGenerator = - new TungstenAggregateHashMap(aggregateHashMapClassName, groupingKeySchema, bufferSchema) + new TungstenAggregateHashMap(ctx, aggregateHashMapClassName, groupingKeySchema, bufferSchema) if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, s"$aggregateHashMapTerm = new $aggregateHashMapClassName();") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala index 87ad946db9a96..cb3d432472f48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types.StructType class TungstenAggregateHashMap( + ctx: CodegenContext, generatedClassName: String, groupingKeySchema: StructType, bufferSchema: StructType) { - val groupingKeys = groupingKeySchema.map(key => (key.dataType.typeName, key.name)) - val bufferValues = bufferSchema.map(key => (key.name, key.dataType.typeName)) + val groupingKeys = groupingKeySchema.map(key => (key.dataType.typeName, ctx.freshName("key"))) + val bufferValues = bufferSchema.map(key => (ctx.freshName("value"), key.dataType.typeName)) val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ") def generate(): String = { From 4e3baac834fcf363a953444704ab616aa503e0c9 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 6 Apr 2016 21:20:14 -0700 Subject: [PATCH 3/5] CR --- .../aggregate/TungstenAggregate.scala | 11 +++- .../aggregate/TungstenAggregateHashMap.scala | 51 +++++++------------ 2 files changed, 27 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 9e5535c981639..db113553bca43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -442,12 +442,19 @@ case class TungstenAggregate( val isAggregateHashMapSupported: Boolean = (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) val aggregateHashMapTerm = ctx.freshName("aggregateHashMap") - val aggregateHashMapClassName = ctx.freshName("AggregateHashMap") + val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap") val aggregateHashMapGenerator = new TungstenAggregateHashMap(ctx, aggregateHashMapClassName, groupingKeySchema, bufferSchema) if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, - s"$aggregateHashMapTerm = new $aggregateHashMapClassName();") + s""" + |org.apache.spark.sql.types.StructType schema = + | new org.apache.spark.sql.types.StructType() + | ${(groupingKeySchema ++ bufferSchema).map(k => + s""".add("${k.name}", org.apache.spark.sql.types.DataTypes.${k.dataType})""") + .mkString("\n")}; + |$aggregateHashMapTerm = new $aggregateHashMapClassName(schema); + """.stripMargin) } // create hashMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala index cb3d432472f48..7ef6f9edd7102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala @@ -25,8 +25,8 @@ class TungstenAggregateHashMap( generatedClassName: String, groupingKeySchema: StructType, bufferSchema: StructType) { - val groupingKeys = groupingKeySchema.map(key => (key.dataType.typeName, ctx.freshName("key"))) - val bufferValues = bufferSchema.map(key => (ctx.freshName("value"), key.dataType.typeName)) + val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key"))) + val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value"))) val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ") def generate(): String = { @@ -44,23 +44,15 @@ class TungstenAggregateHashMap( } def initializeAggregateHashMap(): String = { - val generatedSchema: String = - s""" - |new org.apache.spark.sql.types.StructType() - |${(groupingKeySchema ++ bufferSchema).map(key => - s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""") - .mkString("\n")}; - """.stripMargin - s""" | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; | private int[] buckets; | private int numBuckets; | private int maxSteps; | private int numRows = 0; - | private org.apache.spark.sql.types.StructType schema = $generatedSchema | - | public $generatedClassName(int capacity, double loadFactor, int maxSteps) { + | public $generatedClassName(org.apache.spark.sql.types.StructType schema, int capacity, + | double loadFactor, int maxSteps) { | assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); | this.maxSteps = maxSteps; | numBuckets = (int) (capacity / loadFactor); @@ -70,8 +62,8 @@ class TungstenAggregateHashMap( | java.util.Arrays.fill(buckets, -1); | } | - | public $generatedClassName() { - | new $generatedClassName(1 << 16, 0.25, 5); + | public $generatedClassName(org.apache.spark.sql.types.StructType schema) { + | new $generatedClassName(schema, 1 << 16, 0.25, 5); | } """.stripMargin } @@ -80,7 +72,7 @@ class TungstenAggregateHashMap( s""" |// TODO: Improve this Hash Function |private long hash($groupingKeySignature) { - | return ${groupingKeys.map(_._2).mkString(" & ")}; + | return ${groupingKeys.map(_._2).mkString(" ^ ")}; |} """.stripMargin } @@ -88,8 +80,8 @@ class TungstenAggregateHashMap( def generateEquals(): String = { s""" |private boolean equals(int idx, $groupingKeySignature) { - | return ${groupingKeys.zipWithIndex.map(key => - s"batch.column(${key._2}).getLong(buckets[idx]) == ${key._1._2}").mkString(" && ")}; + | return ${groupingKeys.zipWithIndex.map(k => + s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")}; |} """.stripMargin } @@ -98,34 +90,27 @@ class TungstenAggregateHashMap( s""" |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${ groupingKeySignature}) { - | int idx = find(${groupingKeys.map(_._2).mkString(", ")}); - | if (idx != -1 && buckets[idx] == -1) { - | ${groupingKeys.zipWithIndex.map(key => - s"batch.column(${key._2}).putLong(numRows, ${key._1._2});").mkString("\n")} - | ${bufferValues.zipWithIndex.map(key => - s"batch.column(${groupingKeys.length + key._2}).putLong(numRows, 0);") - .mkString("\n")} - | buckets[idx] = numRows++; - | } - | return batch.getRow(buckets[idx]); - |} - | - |private int find($groupingKeySignature) { | long h = hash(${groupingKeys.map(_._2).mkString(", ")}); | int step = 0; | int idx = (int) h & (numBuckets - 1); | while (step < maxSteps) { | // Return bucket index if it's either an empty slot or already contains the key | if (buckets[idx] == -1) { - | return idx; + | ${groupingKeys.zipWithIndex.map(k => + s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")} + | ${bufferValues.zipWithIndex.map(k => + s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);") + .mkString("\n")} + | buckets[idx] = numRows++; + | return batch.getRow(buckets[idx]); | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) { - | return idx; + | return batch.getRow(buckets[idx]); | } | idx = (idx + 1) & (numBuckets - 1); | step++; | } |// Didn't find it - |return -1; + |return null; |} """.stripMargin } From 927a7c8c620c70e07a3722a43f0833bb3cfdd1e9 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Thu, 7 Apr 2016 16:29:29 -0700 Subject: [PATCH 4/5] CR --- .../execution/aggregate/TungstenAggregate.scala | 9 +-------- .../aggregate/TungstenAggregateHashMap.scala | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index db113553bca43..6661ad7e8e62b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -447,14 +447,7 @@ case class TungstenAggregate( new TungstenAggregateHashMap(ctx, aggregateHashMapClassName, groupingKeySchema, bufferSchema) if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, - s""" - |org.apache.spark.sql.types.StructType schema = - | new org.apache.spark.sql.types.StructType() - | ${(groupingKeySchema ++ bufferSchema).map(k => - s""".add("${k.name}", org.apache.spark.sql.types.DataTypes.${k.dataType})""") - .mkString("\n")}; - |$aggregateHashMapTerm = new $aggregateHashMapClassName(schema); - """.stripMargin) + s"$aggregateHashMapTerm = new $aggregateHashMapClassName();") } // create hashMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala index 7ef6f9edd7102..0dbdbd880f530 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala @@ -44,15 +44,23 @@ class TungstenAggregateHashMap( } def initializeAggregateHashMap(): String = { + val generatedSchema: String = + s""" + |new org.apache.spark.sql.types.StructType() + |${(groupingKeySchema ++ bufferSchema).map(key => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""") + .mkString("\n")}; + """.stripMargin + s""" | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; | private int[] buckets; | private int numBuckets; | private int maxSteps; | private int numRows = 0; + | private org.apache.spark.sql.types.StructType schema = $generatedSchema | - | public $generatedClassName(org.apache.spark.sql.types.StructType schema, int capacity, - | double loadFactor, int maxSteps) { + | public $generatedClassName(int capacity, double loadFactor, int maxSteps) { | assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); | this.maxSteps = maxSteps; | numBuckets = (int) (capacity / loadFactor); @@ -62,8 +70,8 @@ class TungstenAggregateHashMap( | java.util.Arrays.fill(buckets, -1); | } | - | public $generatedClassName(org.apache.spark.sql.types.StructType schema) { - | new $generatedClassName(schema, 1 << 16, 0.25, 5); + | public $generatedClassName() { + | new $generatedClassName(1 << 16, 0.25, 5); | } """.stripMargin } From ec74328ab73766481d3aa7e566fe592bbde747eb Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Thu, 7 Apr 2016 22:15:07 -0700 Subject: [PATCH 5/5] CR --- ...cala => ColumnarAggMapCodeGenerator.scala} | 86 +++++++++++++++++-- .../aggregate/TungstenAggregate.scala | 4 +- 2 files changed, 79 insertions(+), 11 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/{TungstenAggregateHashMap.scala => ColumnarAggMapCodeGenerator.scala} (55%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala similarity index 55% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala index 0dbdbd880f530..e415dd8e6ac9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregateHashMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala @@ -20,7 +20,20 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types.StructType -class TungstenAggregateHashMap( +/** + * This is a helper object to generate an append-only single-key/single value aggregate hash + * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates + * (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in + * TungstenAggregate to speed up aggregates w/ key. + * + * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the + * key-value pairs. The index lookups in the array rely on linear probing (with a small number of + * maximum tries) and use an inexpensive hash function which makes it really efficient for a + * majority of lookups. However, using linear probing and an inexpensive hash function also makes it + * less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even + * for certain distribution of keys) and requires us to fall back on the latter for correctness. + */ +class ColumnarAggMapCodeGenerator( ctx: CodegenContext, generatedClassName: String, groupingKeySchema: StructType, @@ -43,7 +56,7 @@ class TungstenAggregateHashMap( """.stripMargin } - def initializeAggregateHashMap(): String = { + private def initializeAggregateHashMap(): String = { val generatedSchema: String = s""" |new org.apache.spark.sql.types.StructType() @@ -76,16 +89,38 @@ class TungstenAggregateHashMap( """.stripMargin } - def generateHashFunction(): String = { + /** + * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For + * instance, if we have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * private long hash(long agg_key, long agg_key1) { + * return agg_key ^ agg_key1; + * } + * }}} + */ + private def generateHashFunction(): String = { s""" - |// TODO: Improve this Hash Function + |// TODO: Improve this hash function |private long hash($groupingKeySignature) { | return ${groupingKeys.map(_._2).mkString(" ^ ")}; |} """.stripMargin } - def generateEquals(): String = { + /** + * Generates a method that returns true if the group-by keys exist at a given index in the + * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we + * have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * private boolean equals(int idx, long agg_key, long agg_key1) { + * return batch.column(0).getLong(buckets[idx]) == agg_key && + * batch.column(1).getLong(buckets[idx]) == agg_key1; + * } + * }}} + */ + private def generateEquals(): String = { s""" |private boolean equals(int idx, $groupingKeySignature) { | return ${groupingKeys.zipWithIndex.map(k => @@ -94,10 +129,43 @@ class TungstenAggregateHashMap( """.stripMargin } - def generateFindOrInsert(): String = { + /** + * Generates a method that returns a mutable + * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the + * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the + * generated method adds the corresponding row in the associated + * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we + * have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert( + * long agg_key, long agg_key1) { + * long h = hash(agg_key, agg_key1); + * int step = 0; + * int idx = (int) h & (numBuckets - 1); + * while (step < maxSteps) { + * // Return bucket index if it's either an empty slot or already contains the key + * if (buckets[idx] == -1) { + * batch.column(0).putLong(numRows, agg_key); + * batch.column(1).putLong(numRows, agg_key1); + * batch.column(2).putLong(numRows, 0); + * buckets[idx] = numRows++; + * return batch.getRow(buckets[idx]); + * } else if (equals(idx, agg_key, agg_key1)) { + * return batch.getRow(buckets[idx]); + * } + * idx = (idx + 1) & (numBuckets - 1); + * step++; + * } + * // Didn't find it + * return null; + * } + * }}} + */ + private def generateFindOrInsert(): String = { s""" |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${ - groupingKeySignature}) { + groupingKeySignature}) { | long h = hash(${groupingKeys.map(_._2).mkString(", ")}); | int step = 0; | int idx = (int) h & (numBuckets - 1); @@ -117,8 +185,8 @@ class TungstenAggregateHashMap( | idx = (idx + 1) & (numBuckets - 1); | step++; | } - |// Didn't find it - |return null; + | // Didn't find it + | return null; |} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 6661ad7e8e62b..0a5a72c52a372 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -443,8 +443,8 @@ case class TungstenAggregate( (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) val aggregateHashMapTerm = ctx.freshName("aggregateHashMap") val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap") - val aggregateHashMapGenerator = - new TungstenAggregateHashMap(ctx, aggregateHashMapClassName, groupingKeySchema, bufferSchema) + val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName, + groupingKeySchema, bufferSchema) if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, s"$aggregateHashMapTerm = new $aggregateHashMapClassName();")