16 changes: 10 additions & 6 deletions core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -19,6 +19,7 @@ package org.apache.spark

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
import org.apache.spark.serializer.Serializer

/**
* :: DeveloperApi ::
@@ -40,8 +41,8 @@ case class Aggregator[K, V, C] (
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] =
combineValuesByKey(iter, null)

def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]],
context: TaskContext): Iterator[(K, C)] = {
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], context: TaskContext,
serializer: Serializer = SparkEnv.get.serializer): Iterator[(K, C)] = {
if (!externalSorting) {
val combiners = new AppendOnlyMap[K,C]
var kv: Product2[K, V] = null
@@ -54,7 +55,8 @@
}
combiners.iterator
} else {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
val combiners =
new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners, serializer)
while (iter.hasNext) {
val (k, v) = iter.next()
combiners.insert(k, v)
@@ -70,7 +72,8 @@
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] =
combineCombinersByKey(iter, null)

def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = {
def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext,
serializer: Serializer = SparkEnv.get.serializer) : Iterator[(K, C)] = {
if (!externalSorting) {
val combiners = new AppendOnlyMap[K,C]
var kc: Product2[K, C] = null
@@ -83,7 +86,8 @@
}
combiners.iterator
} else {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
val combiners =
new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners, serializer)
while (iter.hasNext) {
val (k, c) = iter.next()
combiners.insert(k, c)
@@ -94,4 +98,4 @@ case class Aggregator[K, V, C] (
combiners.iterator
}
}
}
}
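
The new optional serializer parameter defaults to SparkEnv.get.serializer, so existing callers are unchanged; when supplied, it is handed straight to the spillable ExternalAppendOnlyMap. A minimal sketch of the new call shape (the sumByKey helper and the word-count style combiner functions are illustrative assumptions, not part of this patch):

import org.apache.spark.{Aggregator, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer

// Sketch only: sum integer values per key, forwarding an explicit serializer to the
// spill path. Names other than the Spark classes are hypothetical.
def sumByKey(iter: Iterator[(String, Int)],
             context: TaskContext,
             serializer: Serializer = SparkEnv.get.serializer): Iterator[(String, Int)] = {
  val aggregator = new Aggregator[String, Int, Int](
    createCombiner = v => v,              // first value seeds the combiner
    mergeValue = (sum, v) => sum + v,     // fold each further value into the running sum
    mergeCombiners = (s1, s2) => s1 + s2) // merge partial sums, e.g. from spilled maps
  aggregator.combineValuesByKey(iter, context, serializer)
}
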
121 changes: 98 additions & 23 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -21,9 +21,13 @@ import java.util.HashMap

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.MutableProjection
import org.apache.spark.Aggregator

/**
* :: DeveloperApi ::
@@ -129,32 +133,68 @@ case class Aggregate(
}
}

override def execute() = attachTree(this, "execute") {
if (groupingExpressions.isEmpty) {
child.execute().mapPartitions { iter =>
val buffer = newAggregateBuffer()
var currentRow: Row = null
while (iter.hasNext) {
currentRow = iter.next()
var i = 0
while (i < buffer.length) {
buffer(i).update(currentRow)
i += 1
}
}
val resultProjection = new Projection(resultExpressions, computedSchema)
val aggregateResults = new GenericMutableRow(computedAggregates.length)
/**
* Implementation of aggregate using external sorting.
*/
private[this] def aggregateWithExternalSorting() = {

def createCombiner(v: Row) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[Row], v: Row) = buf += v
def mergeCombiners(c1: ArrayBuffer[Row], c2: ArrayBuffer[Row]) = c1 ++ c2

child.execute().mapPartitionsWithContext {
(context, iter) =>
val aggregator =
new Aggregator[Row, Row, ArrayBuffer[Row]](createCombiner, mergeValue, mergeCombiners)

val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
val tuplesByGroups = iter.map(x => (groupingProjection(x).copy(), x.copy()))
val sortedProjectionGroups =
aggregator.combineValuesByKey(tuplesByGroups.toIterator,
context, new SparkSqlSerializer(SparkEnv.get.conf))

new Iterator[Row] {
private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
private[this] val resultProjection =
new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
private[this] val joinedRow = new JoinedRow

override final def hasNext: Boolean = sortedProjectionGroups.hasNext

var i = 0
while (i < buffer.length) {
aggregateResults(i) = buffer(i).eval(EmptyRow)
i += 1
override final def next(): Row = {
val currentEntry = sortedProjectionGroups.next()
val currentGroupKey = currentEntry._1
val currentGroupIterator = currentEntry._2.iterator

val currentBuffer = newAggregateBuffer()
while (currentGroupIterator.hasNext) {
val currentRow = currentGroupIterator.next()
var i = 0
while (i < currentBuffer.length) {
currentBuffer(i).update(currentRow)
i += 1
}
}

var i = 0
while (i < currentBuffer.length) {
// Evaluating an aggregate buffer returns the result. No row is required since we
// already added all rows in the group using update.
aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
i += 1
}
resultProjection(joinedRow(aggregateResults, currentGroupKey))
}
}
}
}

Iterator(resultProjection(aggregateResults))
}
} else {
child.execute().mapPartitions { iter =>
/**
* Implementation of aggregate without external sorting.
*/
private[this] def aggregate() = {
child.execute().mapPartitions {
iter =>
val hashTable = new HashMap[Row, Array[AggregateFunction]]
val groupingProjection = new MutableProjection(groupingExpressions, childOutput)

@@ -199,6 +239,41 @@ case class Aggregate(
resultProjection(joinedRow(aggregateResults, currentGroup))
}
}
}
}

override def execute() = attachTree(this, "execute") {
if (groupingExpressions.isEmpty) {
child.execute().mapPartitions {
iter =>
val buffer = newAggregateBuffer()
var currentRow: Row = null
while (iter.hasNext) {
currentRow = iter.next()
var i = 0
while (i < buffer.length) {
buffer(i).update(currentRow)
i += 1
}
}
val resultProjection = new Projection(resultExpressions, computedSchema)
val aggregateResults = new GenericMutableRow(computedAggregates.length)

var i = 0
while (i < buffer.length) {
aggregateResults(i) = buffer(i).eval(EmptyRow)
i += 1
}

Iterator(resultProjection(aggregateResults))
}
} else {
val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)

if (externalSorting) {
aggregateWithExternalSorting()
} else {
aggregate()
}
}
}
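
The external-sorting path above reduces group-by aggregation to the same Aggregator machinery used for shuffles: every input row is keyed by its grouping projection and collected into an ArrayBuffer combiner, which ExternalAppendOnlyMap can spill to disk, and each completed group is then folded through the aggregate buffer. A standalone sketch of that pattern, with plain Scala tuples in place of Catalyst rows (the data types and the SUM-style aggregate are illustrative assumptions):

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Aggregator, TaskContext}

// Sketch only: group (key, value) pairs with a spill-capable Aggregator, then reduce
// each finished group to a single aggregate value.
def groupThenSum(rows: Iterator[(Int, Double)],
                 context: TaskContext): Iterator[(Int, Double)] = {
  val grouper = new Aggregator[Int, Double, ArrayBuffer[Double]](
    createCombiner = v => ArrayBuffer(v),
    mergeValue = (buf, v) => buf += v,
    mergeCombiners = (b1, b2) => b1 ++ b2)

  // Each entry is (groupKey, all values seen for that key); when spark.shuffle.spill is
  // enabled the underlying ExternalAppendOnlyMap may spill these buffers to disk.
  grouper.combineValuesByKey(rows, context).map { case (key, values) =>
    (key, values.sum) // per-group aggregate, analogous to evaluating the aggregate buffer
  }
}
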
35 changes: 35 additions & 0 deletions sql/core/src/test/scala/AggregateBMSuite.scala
@@ -0,0 +1,35 @@
package org.apache.spark.sql

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Sum
import org.apache.spark.sql.test.TestSQLContext

/* Implicits */
import TestSQLContext._

/**
* A simple benchmark to compare the performance of group-by aggregation with and
* without external sorting.
*/
class AggregateBMSuite extends QueryTest {

test("agg random 10m") {
val t0 = System.nanoTime()
val sparkAnswerIterator = testDataLarge.groupBy('a)('a, Sum('b)).collect().iterator
val t1 = System.nanoTime()
println((t1 - t0)/1000000 + " ms")
var isValid = true
while (sparkAnswerIterator.hasNext) {
val group = sparkAnswerIterator.next()
// the sum is expected to be 10,000 times the grouping attribute
if (group.getLong(1) != group.getInt(0) * 10000) {
isValid = false
}
}
if (!isValid) {
fail (
"Invalid aggregation results"
)
}
}
}
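
Whether the benchmark exercises the external-sorting path or the in-memory hash aggregate is controlled by spark.shuffle.spill, which execute() reads from the SparkConf. A sketch of a context configured for the spilling path (the master, app name, and local setup are assumptions; only the configuration key comes from the patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Sketch only: flip spark.shuffle.spill to "false" to benchmark the hash-based path.
val conf = new SparkConf()
  .setMaster("local[4]")
  .setAppName("agg-spill-benchmark")
  .set("spark.shuffle.spill", "true")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
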
6 changes: 6 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -47,6 +47,12 @@ object TestData {
(1, null) ::
(2, 2) :: Nil)

case class TestDataLarge(a: Int, b:Int)
val testDataLarge: SchemaRDD =
TestSQLContext.sparkContext.parallelize(
(1 to 10000000).map(i => TestDataLarge(i%1000, i%1000)))
testDataLarge.registerAsTable("testDataLarge")

case class UpperCaseData(N: Int, L: String)
val upperCaseData =
TestSQLContext.sparkContext.parallelize(
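
testDataLarge spreads 10,000,000 rows over 1,000 distinct keys (i % 1000), giving 10,000 rows per group with b equal to a, which is why the benchmark expects each group's sum to be 10,000 times its key. Since the RDD is registered as a table, the same aggregate can also be expressed in SQL; a hypothetical usage sketch, assuming the TestData object has already been initialized as it is in the test suites:

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext._

// Sketch only: equivalent to the DSL query testDataLarge.groupBy('a)('a, Sum('b)).
val perGroupSums = sql("SELECT a, SUM(b) FROM testDataLarge GROUP BY a").collect()
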