Commit c1ea3af

ScrapCodes authored and mateiz committed
SPARK-1162 Implemented takeOrdered in pyspark.

Since Python does not have a library for a max heap, and the usual tricks like inverting values do not work for all cases, we provide our own max-heap implementation.

Author: Prashant Sharma <[email protected]>

Closes #97 from ScrapCodes/SPARK-1162/pyspark-top-takeOrdered2 and squashes the following commits:

35f86ba [Prashant Sharma] code review
2b1124d [Prashant Sharma] fixed tests
e8a08e2 [Prashant Sharma] Code review comments.
49e6ba7 [Prashant Sharma] SPARK-1162 added takeOrdered to pyspark
1 parent 5d1feda commit c1ea3af
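
The reasoning in the commit message is the core of the patch: each partition keeps only the num smallest elements seen so far, which needs a max heap (so the largest retained element is known), while Python's heapq only provides a min heap and negating values to fake one only works for numeric data. Below is a minimal, standalone sketch of that per-partition-select-then-merge strategy. It is not part of the commit; the names (keep_smallest, partitions) are illustrative, and it deliberately uses the negation trick on integers, where it does work.

# Illustrative only -- not Spark code. Per-"partition" selection of the num
# smallest elements with a size-bounded heap, then a pairwise merge of the
# partial results, mirroring what takeOrdered does across RDD partitions.
import heapq

def keep_smallest(iterable, num):
    # heapq is a min heap, so negate the integers to simulate a max heap.
    # That trick only works for values that can be negated, which is why the
    # patch ships its own MaxHeapQ for arbitrary comparable elements.
    heap = []
    for value in iterable:
        if len(heap) < num:
            heapq.heappush(heap, -value)
        elif -heap[0] > value:          # -heap[0] is the largest kept element
            heapq.heapreplace(heap, -value)
    return [-v for v in heap]

partitions = [[10, 1, 2, 9], [3, 4, 5, 6, 7]]        # toy stand-in for RDD partitions
partials = [keep_smallest(p, 6) for p in partitions]
merged = keep_smallest(partials[0] + partials[1], 6)
print(sorted(merged))                                # [1, 2, 3, 4, 5, 6]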

File tree

1 file changed (+102, -5 lines)


python/pyspark/rdd.py

Lines changed: 102 additions & 5 deletions
@@ -29,7 +29,7 @@
 from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
-from heapq import heappush, heappop, heappushpop
+import heapq
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -41,9 +41,9 @@
 
 from py4j.java_collections import ListConverter, MapConverter
 
-
 __all__ = ["RDD"]
 
+
 def _extract_concise_traceback():
     """
     This function returns the traceback info for a callsite, returns a dict
@@ -91,6 +91,73 @@ def __exit__(self, type, value, tb):
         if _spark_stack_depth == 0:
             self._context._jsc.setCallSite(None)
 
+class MaxHeapQ(object):
+    """
+    An implementation of MaxHeap.
+    >>> import pyspark.rdd
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(10)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(1)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> heap.getElements()
+    [0]
+    """
+
+    def __init__(self, maxsize):
+        # we start from q[1], this makes calculating children as trivial as 2 * k
+        self.q = [0]
+        self.maxsize = maxsize
+
+    def _swim(self, k):
+        while (k > 1) and (self.q[k/2] < self.q[k]):
+            self._swap(k, k/2)
+            k = k/2
+
+    def _swap(self, i, j):
+        t = self.q[i]
+        self.q[i] = self.q[j]
+        self.q[j] = t
+
+    def _sink(self, k):
+        N = self.size()
+        while 2 * k <= N:
+            j = 2 * k
+            # Here we test if both children are greater than parent
+            # if not swap with larger one.
+            if j < N and self.q[j] < self.q[j + 1]:
+                j = j + 1
+            if(self.q[k] > self.q[j]):
+                break
+            self._swap(k, j)
+            k = j
+
+    def size(self):
+        return len(self.q) - 1
+
+    def insert(self, value):
+        if (self.size()) < self.maxsize:
+            self.q.append(value)
+            self._swim(self.size())
+        else:
+            self._replaceRoot(value)
+
+    def getElements(self):
+        return self.q[1:]
+
+    def _replaceRoot(self, value):
+        if(self.q[1] > value):
+            self.q[1] = value
+            self._sink(1)
+
 class RDD(object):
     """
     A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
@@ -696,23 +763,53 @@ def top(self, num):
         Note: It returns the list sorted in descending order.
         >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
         [12]
-        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2)
         [6, 5]
         """
         def topIterator(iterator):
             q = []
             for k in iterator:
                 if len(q) < num:
-                    heappush(q, k)
+                    heapq.heappush(q, k)
                 else:
-                    heappushpop(q, k)
+                    heapq.heappushpop(q, k)
             yield q
 
         def merge(a, b):
             return next(topIterator(a + b))
 
         return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
 
+    def takeOrdered(self, num, key=None):
+        """
+        Get the N elements from a RDD ordered in ascending order or as specified
+        by the optional key function.
+
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
+        [1, 2, 3, 4, 5, 6]
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
+        [10, 9, 7, 6, 5, 4]
+        """
+
+        def topNKeyedElems(iterator, key_=None):
+            q = MaxHeapQ(num)
+            for k in iterator:
+                if key_ != None:
+                    k = (key_(k), k)
+                q.insert(k)
+            yield q.getElements()
+
+        def unKey(x, key_=None):
+            if key_ != None:
+                x = [i[1] for i in x]
+            return x
+
+        def merge(a, b):
+            return next(topNKeyedElems(a + b))
+        result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge)
+        return sorted(unKey(result, key), key=key)
+
     def take(self, num):
         """
         Take the first num elements of the RDD.
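
A side note on the diff above: when a key function is supplied, topNKeyedElems wraps each value as a (key(value), value) pair, so ordinary tuple comparison orders elements by the key (with the value itself as tie-breaker), and unKey strips the wrapper before the final sort. A small sketch of that mechanism in plain Python, with a sorted list standing in for the MaxHeapQ-based selection; illustrative only, not Spark code:

# Sketch of the keyed ordering used by takeOrdered; no SparkContext needed.
values = [10, 1, 2, 9, 3, 4, 5, 6, 7]
key = lambda x: -x                           # "largest first", as in the doctest

wrapped = [(key(v), v) for v in values]      # tuples compare by key(v) first
wrapped.sort()                               # stand-in for the heap-based selection
smallest_six = wrapped[:6]

unkeyed = [v for (_, v) in smallest_six]     # what unKey() does
print(sorted(unkeyed, key=key))              # [10, 9, 7, 6, 5, 4]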
