
Commit db436e3

davies authored and JoshRosen committed
[SPARK-2871] [PySpark] add key argument for max(), min() and top(n)
RDD.max(key=None)

    param key: A function used to generate key for comparing

    >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
    >>> rdd.max()
    43.0
    >>> rdd.max(key=str)
    5.0

RDD.min(key=None)

    Find the minimum item in this RDD.

    param key: A function used to generate key for comparing

    >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
    >>> rdd.min()
    2.0
    >>> rdd.min(key=str)
    10.0

RDD.top(num, key=None)

    Get the top N elements from a RDD.

    Note: It returns the list sorted in descending order.

    >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
    [12]
    >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2)
    [6, 5]
    >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str)
    [4, 3, 2]

Author: Davies Liu <[email protected]>

Closes apache#2094 from davies/cmp and squashes the following commits:

ccbaf25 [Davies Liu] add `key` to top()
ad7e374 [Davies Liu] fix tests
2f63512 [Davies Liu] change `comp` to `key` in min/max
dd91e08 [Davies Liu] add `comp` argument for RDD.max() and RDD.min()
1 parent 3519b5e commit db436e3
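The key=str results above come straight from Python's built-in comparison semantics: with key=str, elements are compared by their string form, so "5.0" sorts lexicographically above "43.0" and "10.0" sorts below "2.0". A minimal local illustration using only the plain built-ins and heapq (no SparkContext needed; the values mirror the doctests above):

    >>> max([1.0, 5.0, 43.0, 10.0], key=str)
    5.0
    >>> min([2.0, 5.0, 43.0, 10.0], key=str)
    10.0
    >>> import heapq
    >>> heapq.nlargest(3, [10, 4, 2, 12, 3], key=str)
    [4, 3, 2]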

1 file changed

python/pyspark/rdd.py

Lines changed: 27 additions & 17 deletions
@@ -810,23 +810,37 @@ def func(iterator):

         return self.mapPartitions(func).fold(zeroValue, combOp)

-    def max(self):
+    def max(self, key=None):
         """
         Find the maximum item in this RDD.

-        >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max()
+        @param key: A function used to generate key for comparing
+
+        >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
+        >>> rdd.max()
         43.0
+        >>> rdd.max(key=str)
+        5.0
         """
-        return self.reduce(max)
+        if key is None:
+            return self.reduce(max)
+        return self.reduce(lambda a, b: max(a, b, key=key))

-    def min(self):
+    def min(self, key=None):
         """
         Find the minimum item in this RDD.

-        >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min()
-        1.0
+        @param key: A function used to generate key for comparing
+
+        >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
+        >>> rdd.min()
+        2.0
+        >>> rdd.min(key=str)
+        10.0
         """
-        return self.reduce(min)
+        if key is None:
+            return self.reduce(min)
+        return self.reduce(lambda a, b: min(a, b, key=key))

     def sum(self):
         """
@@ -924,7 +938,7 @@ def mergeMaps(m1, m2):
             return m1
         return self.mapPartitions(countPartition).reduce(mergeMaps)

-    def top(self, num):
+    def top(self, num, key=None):
         """
         Get the top N elements from a RDD.

@@ -933,20 +947,16 @@ def top(self, num):
         [12]
         >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2)
         [6, 5]
+        >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str)
+        [4, 3, 2]
         """
         def topIterator(iterator):
-            q = []
-            for k in iterator:
-                if len(q) < num:
-                    heapq.heappush(q, k)
-                else:
-                    heapq.heappushpop(q, k)
-            yield q
+            yield heapq.nlargest(num, iterator, key=key)

         def merge(a, b):
-            return next(topIterator(a + b))
+            return heapq.nlargest(num, a + b, key=key)

-        return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
+        return self.mapPartitions(topIterator).reduce(merge)

     def takeOrdered(self, num, key=None):
         """
