
Commit 23bf400

SPARK-554. Add aggregateByKey.
1 parent d45e0c6 commit 23bf400

3 files changed: 67 additions and 0 deletions

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 50 additions & 0 deletions
@@ -118,6 +118,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
   }
 
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a partition,
+   * and the latter is used for merging values between partitions. To avoid memory allocation, both
+   * of these functions are allowed to modify and return their first argument instead of creating a
+   * new U.
+   */
+  def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U,
+      combOp: (U, U) => U): RDD[(K, U)] = {
+    // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+    val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+    val zeroArray = new Array[Byte](zeroBuffer.limit)
+    zeroBuffer.get(zeroArray)
+
+    lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+    def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
+
+    combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a partition,
+   * and the latter is used for merging values between partitions. To avoid memory allocation, both
+   * of these functions are allowed to modify and return their first argument instead of creating a
+   * new U.
+   */
+  def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U,
+      combOp: (U, U) => U): RDD[(K, U)] = {
+    aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp)
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a partition,
+   * and the latter is used for merging values between partitions. To avoid memory allocation, both
+   * of these functions are allowed to modify and return their first argument instead of creating a
+   * new U.
+   */
+  def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U,
+      combOp: (U, U) => U): RDD[(K, U)] = {
+    aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp)
+  }
+
   /**
    * Merge the values for each key using an associative function and a neutral "zero value" which
    * may be added to the result an arbitrary number of times, and must not change the result
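For orientation, a minimal usage sketch of the new API (not part of the commit; it assumes a live SparkContext `sc`). It computes a per-key (sum, count) pair, i.e. an aggregated type U = (Int, Int) that differs from the value type V = Int:

    val pairs = sc.parallelize(Seq(("a", 3), ("a", 5), ("b", 7)))
    // seqOp folds one value into the per-partition accumulator;
    // combOp merges accumulators produced by different partitions.
    val sumCounts = pairs.aggregateByKey((0, 0))(
      (acc, v) => (acc._1 + v, acc._2 + 1),
      (a, b) => (a._1 + b._1, a._2 + b._2))
    // sumCounts.collect().toMap == Map("a" -> (8, 2), "b" -> (7, 1))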

core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala

Lines changed: 13 additions & 0 deletions
@@ -30,6 +30,19 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.{Partitioner, SharedSparkContext}
 
 class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+  test("aggregateByKey") {
+    val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
+
+    val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect()
+    assert(sets.size === 3)
+    val valuesFor1 = sets.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1))
+    val valuesFor3 = sets.find(_._1 == 3).get._2
+    assert(valuesFor3.toList.sorted === List(2))
+    val valuesFor5 = sets.find(_._1 == 5).get._2
+    assert(valuesFor5.toList.sorted === List(1, 3))
+  }
+
   test("groupByKey") {
     val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
     val groups = pairs.groupByKey().collect()
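The new test exercises the in-place mutation allowance from the docstring: `_ += _` adds an element to a mutable HashSet and returns that same set, and `_ ++= _` folds one set into another. A standalone sketch of the same pattern with explicit lambdas (not part of the commit; it assumes a live SparkContext `sc`):

    import scala.collection.mutable.HashSet

    val pairs = sc.parallelize(Seq((1, 1), (1, 1), (5, 1), (5, 3)), 2)
    val sets = pairs.aggregateByKey(new HashSet[Int]())(
      (set, v) => set += v,    // seqOp: mutate and return the per-partition set
      (s1, s2) => s1 ++= s2)   // combOp: merge sets from different partitions
    // sets.collect().toMap == Map(1 -> HashSet(1), 5 -> HashSet(1, 3))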

docs/programming-guide.md

Lines changed: 4 additions & 0 deletions
@@ -890,6 +890,10 @@ for details.
   <td> <b>reduceByKey</b>(<i>func</i>, [<i>numTasks</i>]) </td>
   <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
 </tr>
+<tr>
+  <td> <b>aggregateByKey</b>(<i>zeroValue</i>)(<i>seqOp</i>, <i>combOp</i>, [<i>numTasks</i>]) </td>
+  <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different from the input value type, while avoiding unnecessary allocations. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
+</tr>
 <tr>
   <td> <b>sortByKey</b>([<i>ascending</i>], [<i>numTasks</i>]) </td>
   <td> When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean <code>ascending</code> argument.</td>
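A concrete Scala counterpart to this table entry (a sketch, not from the commit; it assumes a live SparkContext `sc`). Note that in the Scala API the optional partition count is passed in the first argument list rather than after combOp:

    // Collect the distinct values for each key into an immutable Set, using 4 tasks.
    val pairs = sc.parallelize(Seq(("a", 1), ("a", 1), ("b", 2)))
    val distinct = pairs.aggregateByKey(Set.empty[Int], 4)(
      (set, v) => set + v,       // seqOp: add a value within a partition
      (s1, s2) => s1 union s2)   // combOp: merge partial sets across partitions
    // distinct.collect().toMap == Map("a" -> Set(1), "b" -> Set(2))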
