diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 59fdf659c9e11..4feaf5102f80c 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
+import org.apache.spark.serializer.Serializer
 
 /**
  * :: DeveloperApi ::
@@ -40,8 +41,8 @@ case class Aggregator[K, V, C] (
   def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] =
     combineValuesByKey(iter, null)
 
-  def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]],
-                         context: TaskContext): Iterator[(K, C)] = {
+  def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], context: TaskContext,
+      serializer: Serializer = SparkEnv.get.serializer): Iterator[(K, C)] = {
     if (!externalSorting) {
       val combiners = new AppendOnlyMap[K,C]
       var kv: Product2[K, V] = null
@@ -54,7 +55,8 @@ case class Aggregator[K, V, C] (
       }
       combiners.iterator
     } else {
-      val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
+      val combiners =
+        new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners, serializer)
       while (iter.hasNext) {
         val (k, v) = iter.next()
         combiners.insert(k, v)
@@ -70,7 +72,8 @@ case class Aggregator[K, V, C] (
   def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] =
     combineCombinersByKey(iter, null)
 
-  def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = {
+  def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext,
+      serializer: Serializer = SparkEnv.get.serializer) : Iterator[(K, C)] = {
     if (!externalSorting) {
       val combiners = new AppendOnlyMap[K,C]
       var kc: Product2[K, C] = null
@@ -83,7 +86,8 @@ case class Aggregator[K, V, C] (
       }
       combiners.iterator
     } else {
-      val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
+      val combiners =
+        new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners, serializer)
       while (iter.hasNext) {
         val (k, c) = iter.next()
         combiners.insert(k, c)
@@ -94,4 +98,4 @@ case class Aggregator[K, V, C] (
       combiners.iterator
     }
   }
-}
+}
\ No newline at end of file
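The extra parameter is invisible to existing callers because it defaults to SparkEnv.get.serializer; only a caller that wants a specific on-disk format (such as the SQL operator below, which passes a SparkSqlSerializer) needs to supply it. A minimal, hypothetical sketch of overriding it with Kryo (the object name and the word-count combine functions are illustrative, not part of this patch, and the code assumes a live SparkEnv, i.e. it runs inside a task or a test with a local SparkContext):

    import org.apache.spark.{Aggregator, SparkConf}
    import org.apache.spark.serializer.KryoSerializer

    object SerializerOverrideSketch {
      def sumByWord(records: Iterator[(String, Int)], conf: SparkConf): Iterator[(String, Int)] = {
        val aggregator = new Aggregator[String, Int, Int](
          v => v,              // createCombiner
          (c, v) => c + v,     // mergeValue
          (c1, c2) => c1 + c2) // mergeCombiners
        // The new third argument overrides the serializer used when the spillable
        // ExternalAppendOnlyMap writes blocks to disk; passing null for the
        // TaskContext mirrors the existing convenience overload above.
        aggregator.combineValuesByKey(records, null, new KryoSerializer(conf))
      }
    }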
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 36b3b956da96c..eb16590ebb641 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -21,9 +21,13 @@ import java.util.HashMap
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.SparkContext
+import org.apache.spark.SparkEnv
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
+import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.sql.catalyst.expressions.MutableProjection
+import org.apache.spark.Aggregator
 
 /**
  * :: DeveloperApi ::
@@ -129,32 +133,68 @@ case class Aggregate(
     }
   }
 
-  override def execute() = attachTree(this, "execute") {
-    if (groupingExpressions.isEmpty) {
-      child.execute().mapPartitions { iter =>
-        val buffer = newAggregateBuffer()
-        var currentRow: Row = null
-        while (iter.hasNext) {
-          currentRow = iter.next()
-          var i = 0
-          while (i < buffer.length) {
-            buffer(i).update(currentRow)
-            i += 1
-          }
-        }
-        val resultProjection = new Projection(resultExpressions, computedSchema)
-        val aggregateResults = new GenericMutableRow(computedAggregates.length)
+  /**
+   * Implementation of aggregate using external sorting.
+   */
+  private[this] def aggregateWithExternalSorting() = {
+
+    def createCombiner(v: Row) = ArrayBuffer(v)
+    def mergeValue(buf: ArrayBuffer[Row], v: Row) = buf += v
+    def mergeCombiners(c1: ArrayBuffer[Row], c2: ArrayBuffer[Row]) = c1 ++ c2
+
+    child.execute().mapPartitionsWithContext {
+      (context, iter) =>
+        val aggregator =
+          new Aggregator[Row, Row, ArrayBuffer[Row]](createCombiner, mergeValue, mergeCombiners)
+
+        val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
+        val tuplesByGroups = iter.map(x => (groupingProjection(x).copy(), x.copy()))
+        val sortedProjectionGroups =
+          aggregator.combineValuesByKey(tuplesByGroups.toIterator,
+            context, new SparkSqlSerializer(SparkEnv.get.conf))
+
+        new Iterator[Row] {
+          private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
+          private[this] val resultProjection =
+            new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
+          private[this] val joinedRow = new JoinedRow
+
+          override final def hasNext: Boolean = sortedProjectionGroups.hasNext
 
-        var i = 0
-        while (i < buffer.length) {
-          aggregateResults(i) = buffer(i).eval(EmptyRow)
-          i += 1
+          override final def next(): Row = {
+            val currentEntry = sortedProjectionGroups.next()
+            val currentGroupKey = currentEntry._1
+            val currentGroupIterator = currentEntry._2.iterator
+
+            val currentBuffer = newAggregateBuffer()
+            while (currentGroupIterator.hasNext) {
+              val currentRow = currentGroupIterator.next()
+              var i = 0
+              while (i < currentBuffer.length) {
+                currentBuffer(i).update(currentRow)
+                i += 1
+              }
+            }
+
+            var i = 0
+            while (i < currentBuffer.length) {
+              // Evaluating an aggregate buffer returns the result. No row is required since we
+              // already added all rows in the group using update.
+              aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
+              i += 1
+            }
+            resultProjection(joinedRow(aggregateResults, currentGroupKey))
+          }
         }
+    }
+  }
 
-        Iterator(resultProjection(aggregateResults))
-      }
-    } else {
-      child.execute().mapPartitions { iter =>
+  /**
+   * Implementation of aggregate without external sorting.
+   */
+  private[this] def aggregate() = {
+    child.execute().mapPartitions {
+      iter =>
         val hashTable = new HashMap[Row, Array[AggregateFunction]]
         val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
 
@@ -199,6 +239,41 @@ case class Aggregate(
             resultProjection(joinedRow(aggregateResults, currentGroup))
           }
         }
+    }
+  }
+
+  override def execute() = attachTree(this, "execute") {
+    if (groupingExpressions.isEmpty) {
+      child.execute().mapPartitions {
+        iter =>
+          val buffer = newAggregateBuffer()
+          var currentRow: Row = null
+          while (iter.hasNext) {
+            currentRow = iter.next()
+            var i = 0
+            while (i < buffer.length) {
+              buffer(i).update(currentRow)
+              i += 1
+            }
+          }
+          val resultProjection = new Projection(resultExpressions, computedSchema)
+          val aggregateResults = new GenericMutableRow(computedAggregates.length)
+
+          var i = 0
+          while (i < buffer.length) {
+            aggregateResults(i) = buffer(i).eval(EmptyRow)
+            i += 1
+          }
+
+          Iterator(resultProjection(aggregateResults))
+      }
+    } else {
+      val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
+
+      if (externalSorting) {
+        aggregateWithExternalSorting()
+      } else {
+        aggregate()
       }
     }
   }
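The choice between the two code paths is made at execution time from the same flag that governs shuffle spilling, so the patch adds no new configuration. A hypothetical end-to-end sketch of exercising the spillable path, assuming the 1.0-era SchemaRDD DSL and the implicit conversions brought in by importing the SQLContext instance (the Record class, app name, and data sizes are made up for illustration; only spark.shuffle.spill is touched by this patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.{SQLContext, SchemaRDD}
    import org.apache.spark.sql.catalyst.expressions.Sum

    // Illustrative input schema; any Product case class works with createSchemaRDD.
    case class Record(a: Int, b: Int)

    object SpillAggregateDemo {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[2]")
          .setAppName("spill-aggregate-demo")
          // Default is true; set to "false" to fall back to the in-memory HashMap path.
          .set("spark.shuffle.spill", "true")
        val sc = new SparkContext(conf)
        val sqlContext = new SQLContext(sc)
        import sqlContext._  // implicit RDD -> SchemaRDD conversion and DSL helpers

        val records: SchemaRDD = sc.parallelize(1 to 100000).map(i => Record(i % 100, i % 100))
        // Grouped aggregation; with spilling enabled this runs through
        // Aggregator / ExternalAppendOnlyMap via aggregateWithExternalSorting().
        records.groupBy('a)('a, Sum('b)).collect().foreach(println)
        sc.stop()
      }
    }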
diff --git a/sql/core/src/test/scala/AggregateBMSuite.scala b/sql/core/src/test/scala/AggregateBMSuite.scala
new file mode 100644
index 0000000000000..1de9d5b09fc51
--- /dev/null
+++ b/sql/core/src/test/scala/AggregateBMSuite.scala
@@ -0,0 +1,35 @@
+package org.apache.spark.sql
+
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.catalyst.expressions.Sum
+import org.apache.spark.sql.test.TestSQLContext
+
+/* Implicits */
+import TestSQLContext._
+
+/**
+ * A simple benchmark to compare the performance of group by aggregate with and
+ * without external sorting.
+ */
+class AggregateBMSuite extends QueryTest {
+
+  test("agg random 10m") {
+    val t0 = System.nanoTime()
+    val sparkAnswerIterator = testDataLarge.groupBy('a)('a, Sum('b)).collect().iterator
+    val t1 = System.nanoTime()
+    println((t1 - t0) / 1000000 + " ms")
+    var isValid = true
+    while (sparkAnswerIterator.hasNext) {
+      val group = sparkAnswerIterator.next()
+      // The sum is expected to be 10,000 times the grouping attribute.
+      if (group.getLong(1) != group.getInt(0) * 10000) {
+        isValid = false
+      }
+    }
+    if (!isValid) {
+      fail(
+        "Invalid aggregation results"
+      )
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 002b7f0adafab..a3e174627d933 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -47,6 +47,12 @@ object TestData {
       (1, null) ::
       (2, 2) :: Nil)
 
+  case class TestDataLarge(a: Int, b: Int)
+  val testDataLarge: SchemaRDD =
+    TestSQLContext.sparkContext.parallelize(
+      (1 to 10000000).map(i => TestDataLarge(i % 1000, i % 1000)))
+  testDataLarge.registerAsTable("testDataLarge")
+
   case class UpperCaseData(N: Int, L: String)
   val upperCaseData = TestSQLContext.sparkContext.parallelize(
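For reference, the invariant that AggregateBMSuite asserts follows directly from how testDataLarge is generated: with a = b = i % 1000 over 10,000,000 rows, every key occurs exactly 10,000 times, so Sum('b) per group equals the key times 10,000. A scaled-down, plain-Scala sketch of the same arithmetic (n and k are illustrative stand-ins for 10,000,000 and 1,000):

    // Scaled-down check of the benchmark's expected result, without Spark.
    val n = 10000  // stands in for 10,000,000 rows
    val k = 100    // stands in for 1,000 distinct keys
    val sums = (1 to n)
      .map(i => (i % k, i % k))     // (a, b) with a == b, mirroring TestDataLarge
      .groupBy(_._1)                // one group per key a
      .mapValues(_.map(_._2).sum)   // Sum(b) per group
    // Each key occurs n / k times, so the per-group sum is a * (n / k).
    assert(sums.forall { case (a, s) => s == a * (n / k) })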