Skip to content

Commit e03bc37

Browse files
holdenkmateiz
authored andcommitted
SPARK-1242 Add aggregate to python rdd
Author: Holden Karau <[email protected]> Closes #139 from holdenk/add_aggregate_to_python_api and squashes the following commits: 0f39ae3 [Holden Karau] Merge in master 4879c75 [Holden Karau] CR feedback, fix issue with empty RDDs in aggregate 70b4724 [Holden Karau] Style fixes from code review 96b047b [Holden Karau] Add aggregate to python rdd
1 parent 095b518 commit e03bc37

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

python/pyspark/rdd.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def _collect_iterator_through_file(self, iterator):
599599
def reduce(self, f):
600600
"""
601601
Reduces the elements of this RDD using the specified commutative and
602-
associative binary operator.
602+
associative binary operator. Currently reduces partitions locally.
603603
604604
>>> from operator import add
605605
>>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
@@ -641,7 +641,34 @@ def func(iterator):
641641
vals = self.mapPartitions(func).collect()
642642
return reduce(op, vals, zeroValue)
643643

644-
# TODO: aggregate
644+
def aggregate(self, zeroValue, seqOp, combOp):
645+
"""
646+
Aggregate the elements of each partition, and then the results for all
647+
the partitions, using a given combine functions and a neutral "zero
648+
value."
649+
650+
The functions C{op(t1, t2)} is allowed to modify C{t1} and return it
651+
as its result value to avoid object allocation; however, it should not
652+
modify C{t2}.
653+
654+
The first function (seqOp) can return a different result type, U, than
655+
the type of this RDD. Thus, we need one operation for merging a T into an U
656+
and one operation for merging two U
657+
658+
>>> seqOp = (lambda x, y: (x[0] + y, x[1] + 1))
659+
>>> combOp = (lambda x, y: (x[0] + y[0], x[1] + y[1]))
660+
>>> sc.parallelize([1, 2, 3, 4]).aggregate((0, 0), seqOp, combOp)
661+
(10, 4)
662+
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
663+
(0, 0)
664+
"""
665+
def func(iterator):
666+
acc = zeroValue
667+
for obj in iterator:
668+
acc = seqOp(acc, obj)
669+
yield acc
670+
671+
return self.mapPartitions(func).fold(zeroValue, combOp)
645672

646673

647674
def max(self):

0 commit comments

Comments
 (0)