Commit ce92a9c

sryza authored and pwendell committed
SPARK-554. Add aggregateByKey.
Author: Sandy Ryza <[email protected]>

Closes #705 from sryza/sandy-spark-554 and squashes the following commits:

2302b8f [Sandy Ryza] Add MIMA exclude
f52e0ad [Sandy Ryza] Fix Python tests for real
2f3afa3 [Sandy Ryza] Fix Python test
0b735e9 [Sandy Ryza] Fix line lengths
ae56746 [Sandy Ryza] Fix doc (replace T with V)
c2be415 [Sandy Ryza] Java and Python aggregateByKey
23bf400 [Sandy Ryza] SPARK-554. Add aggregateByKey.
1 parent 43d53d5 commit ce92a9c

8 files changed, +179 -2 lines changed


core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 44 additions & 0 deletions
@@ -228,6 +228,50 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   : PartialResult[java.util.Map[K, BoundedDouble]] =
     rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)

+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U](zeroValue: U, partitioner: Partitioner, seqFunc: JFunction2[U, V, U],
+      combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+    implicit val ctag: ClassTag[U] = fakeClassTag
+    fromRDD(rdd.aggregateByKey(zeroValue, partitioner)(seqFunc, combFunc))
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U](zeroValue: U, numPartitions: Int, seqFunc: JFunction2[U, V, U],
+      combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+    implicit val ctag: ClassTag[U] = fakeClassTag
+    fromRDD(rdd.aggregateByKey(zeroValue, numPartitions)(seqFunc, combFunc))
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's.
+   * The former operation is used for merging values within a partition, and the latter is used for
+   * merging values between partitions. To avoid memory allocation, both of these functions are
+   * allowed to modify and return their first argument instead of creating a new U.
+   */
+  def aggregateByKey[U](zeroValue: U, seqFunc: JFunction2[U, V, U], combFunc: JFunction2[U, U, U]):
+      JavaPairRDD[K, U] = {
+    implicit val ctag: ClassTag[U] = fakeClassTag
+    fromRDD(rdd.aggregateByKey(zeroValue)(seqFunc, combFunc))
+  }
+
   /**
    * Merge the values for each key using an associative function and a neutral "zero value" which
    * may be added to the result an arbitrary number of times, and must not change the result

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 50 additions & 0 deletions
@@ -118,6 +118,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
   }

+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U,
+      combOp: (U, U) => U): RDD[(K, U)] = {
+    // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+    val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+    val zeroArray = new Array[Byte](zeroBuffer.limit)
+    zeroBuffer.get(zeroArray)
+
+    lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+    def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
+
+    combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U,
+      combOp: (U, U) => U): RDD[(K, U)] = {
+    aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp)
+  }
+
+  /**
+   * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+   * This function can return a different result type, U, than the type of the values in this RDD,
+   * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+   * as in scala.TraversableOnce. The former operation is used for merging values within a
+   * partition, and the latter is used for merging values between partitions. To avoid memory
+   * allocation, both of these functions are allowed to modify and return their first argument
+   * instead of creating a new U.
+   */
+  def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U,
+      combOp: (U, U) => U): RDD[(K, U)] = {
+    aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp)
+  }
+
   /**
    * Merge the values for each key using an associative function and a neutral "zero value" which
    * may be added to the result an arbitrary number of times, and must not change the result
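
For orientation, a minimal Scala usage sketch of the new API (not part of the diff; it assumes an existing SparkContext named sc). It mirrors the PairRDDFunctionsSuite test added below, collecting each key's values into a mutable set and relying on the documented freedom for seqOp and combOp to mutate and return their first argument:

  import scala.collection.mutable

  val pairs = sc.parallelize(Seq((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
  // seqOp folds a value into the per-partition set; combOp merges partial sets across partitions.
  // Both mutate and return their first argument, so no new set is allocated per record.
  val setsByKey = pairs.aggregateByKey(mutable.HashSet.empty[Int])(
    (set, v) => set += v,
    (s1, s2) => s1 ++= s2
  ).collect()
  // e.g. Array((1,Set(1)), (3,Set(2)), (5,Set(1, 3))); pair order may vary

Because the zero value is serialized once and a fresh clone is deserialized each time a new per-key combiner is created (the createZero helper above), a mutable zero value such as a HashSet is safe to use here.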

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 31 additions & 0 deletions
@@ -317,6 +317,37 @@ public Integer call(Integer a, Integer b) {
     Assert.assertEquals(33, sum);
   }

+  @Test
+  public void aggregateByKey() {
+    JavaPairRDD<Integer, Integer> pairs = sc.parallelizePairs(
+      Arrays.asList(
+        new Tuple2<Integer, Integer>(1, 1),
+        new Tuple2<Integer, Integer>(1, 1),
+        new Tuple2<Integer, Integer>(3, 2),
+        new Tuple2<Integer, Integer>(5, 1),
+        new Tuple2<Integer, Integer>(5, 3)), 2);
+
+    Map<Integer, Set<Integer>> sets = pairs.aggregateByKey(new HashSet<Integer>(),
+      new Function2<Set<Integer>, Integer, Set<Integer>>() {
+        @Override
+        public Set<Integer> call(Set<Integer> a, Integer b) {
+          a.add(b);
+          return a;
+        }
+      },
+      new Function2<Set<Integer>, Set<Integer>, Set<Integer>>() {
+        @Override
+        public Set<Integer> call(Set<Integer> a, Set<Integer> b) {
+          a.addAll(b);
+          return a;
+        }
+      }).collectAsMap();
+    Assert.assertEquals(3, sets.size());
+    Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1)), sets.get(1));
+    Assert.assertEquals(new HashSet<Integer>(Arrays.asList(2)), sets.get(3));
+    Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1, 3)), sets.get(5));
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void foldByKey() {

core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala

Lines changed: 13 additions & 0 deletions
@@ -30,6 +30,19 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.{Partitioner, SharedSparkContext}

 class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+  test("aggregateByKey") {
+    val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
+
+    val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect()
+    assert(sets.size === 3)
+    val valuesFor1 = sets.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1))
+    val valuesFor3 = sets.find(_._1 == 3).get._2
+    assert(valuesFor3.toList.sorted === List(2))
+    val valuesFor5 = sets.find(_._1 == 5).get._2
+    assert(valuesFor5.toList.sorted === List(1, 3))
+  }
+
   test("groupByKey") {
     val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
     val groups = pairs.groupByKey().collect()

docs/programming-guide.md

Lines changed: 4 additions & 0 deletions
@@ -890,6 +890,10 @@ for details.
   <td> <b>reduceByKey</b>(<i>func</i>, [<i>numTasks</i>]) </td>
   <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
 </tr>
+<tr>
+  <td> <b>aggregateByKey</b>(<i>zeroValue</i>)(<i>seqOp</i>, <i>combOp</i>, [<i>numTasks</i>]) </td>
+  <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
+</tr>
 <tr>
   <td> <b>sortByKey</b>([<i>ascending</i>], [<i>numTasks</i>]) </td>
   <td> When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean <code>ascending</code> argument.</td>
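
To make the table entry concrete, a short Scala sketch (not part of this commit; sc is assumed to be an existing SparkContext) in which the aggregated type U, a running (sum, count) pair, differs from the value type V = Int:

  val scores = sc.parallelize(Seq(("a", 3), ("a", 5), ("b", 7)))
  // seqOp folds one value into the per-key accumulator; combOp merges accumulators across partitions.
  val sumAndCount = scores.aggregateByKey((0, 0))(
    (acc, v) => (acc._1 + v, acc._2 + 1),
    (a, b) => (a._1 + b._1, a._2 + b._2)
  )
  val avgByKey = sumAndCount.mapValues { case (sum, count) => sum.toDouble / count }
  // avgByKey.collect() => Array((a,4.0), (b,7.0)); pair order may vary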

project/MimaExcludes.scala

Lines changed: 4 additions & 1 deletion
@@ -52,7 +52,10 @@ object MimaExcludes {
           ProblemFilters.exclude[MissingMethodProblem](
             "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
           ProblemFilters.exclude[MissingMethodProblem](
-            "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1")
+            "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
+          ProblemFilters.exclude[MissingMethodProblem](
+            "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$"
+              + "createZero$1")
         ) ++
         Seq( // Ignore some private methods in ALS.
           ProblemFilters.exclude[MissingMethodProblem](

python/pyspark/rdd.py

Lines changed: 18 additions & 1 deletion
@@ -1178,6 +1178,20 @@ def _mergeCombiners(iterator):
                 combiners[k] = mergeCombiners(combiners[k], v)
             return combiners.iteritems()
         return shuffled.mapPartitions(_mergeCombiners)
+
+    def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
+        """
+        Aggregate the values of each key, using given combine functions and a neutral "zero value".
+        This function can return a different result type, U, than the type of the values in this
+        RDD, V. Thus, we need one operation for merging a V into a U and one operation for merging
+        two U's. The former operation is used for merging values within a partition, and the latter
+        is used for merging values between partitions. To avoid memory allocation, both of these
+        functions are allowed to modify and return their first argument instead of creating a new U.
+        """
+        def createZero():
+            return copy.deepcopy(zeroValue)
+
+        return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)

@@ -1190,7 +1204,10 @@ def foldByKey(self, zeroValue, func, numPartitions=None):
         >>> rdd.foldByKey(0, add).collect()
         [('a', 2), ('b', 1)]
         """
-        return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions)
+        def createZero():
+            return copy.deepcopy(zeroValue)
+
+        return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)


    # TODO: support variant with custom partitioner

python/pyspark/tests.py

Lines changed: 15 additions & 0 deletions
@@ -188,6 +188,21 @@ def test_deleting_input_files(self):
         os.unlink(tempFile.name)
         self.assertRaises(Exception, lambda: filtered_data.count())

+    def testAggregateByKey(self):
+        data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
+        def seqOp(x, y):
+            x.add(y)
+            return x
+
+        def combOp(x, y):
+            x |= y
+            return x
+
+        sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
+        self.assertEqual(3, len(sets))
+        self.assertEqual(set([1]), sets[1])
+        self.assertEqual(set([2]), sets[3])
+        self.assertEqual(set([1, 3]), sets[5])

 class TestIO(PySparkTestCase):

0 commit comments