|
29 | 29 | from tempfile import NamedTemporaryFile |
30 | 30 | from threading import Thread |
31 | 31 | import warnings |
32 | | -from heapq import heappush, heappop, heappushpop |
| 32 | +import heapq |
33 | 33 |
|
34 | 34 | from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ |
35 | 35 |     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long |
|
41 | 41 |
|
42 | 42 | from py4j.java_collections import ListConverter, MapConverter |
43 | 43 |
|
44 | | - |
45 | 44 | __all__ = ["RDD"] |
46 | 45 |
|
| 46 | + |
47 | 47 | def _extract_concise_traceback(): |
48 | 48 |     """ |
49 | 49 |     This function returns the traceback info for a callsite, returns a dict |
@@ -91,6 +91,73 @@ def __exit__(self, type, value, tb): |
91 | 91 |         if _spark_stack_depth == 0: |
92 | 92 |             self._context._jsc.setCallSite(None) |
93 | 93 |
|
| 94 | +class MaxHeapQ(object): |
| 95 | +    """ |
| 96 | +    An implementation of a bounded max-heap that keeps the maxsize smallest values inserted into it. |
| 97 | +    >>> import pyspark.rdd |
| 98 | +    >>> heap = pyspark.rdd.MaxHeapQ(5) |
| 99 | +    >>> [heap.insert(i) for i in range(10)] |
| 100 | +    [None, None, None, None, None, None, None, None, None, None] |
| 101 | +    >>> sorted(heap.getElements()) |
| 102 | +    [0, 1, 2, 3, 4] |
| 103 | +    >>> heap = pyspark.rdd.MaxHeapQ(5) |
| 104 | +    >>> [heap.insert(i) for i in range(9, -1, -1)] |
| 105 | +    [None, None, None, None, None, None, None, None, None, None] |
| 106 | +    >>> sorted(heap.getElements()) |
| 107 | +    [0, 1, 2, 3, 4] |
| 108 | +    >>> heap = pyspark.rdd.MaxHeapQ(1) |
| 109 | +    >>> [heap.insert(i) for i in range(9, -1, -1)] |
| 110 | +    [None, None, None, None, None, None, None, None, None, None] |
| 111 | +    >>> heap.getElements() |
| 112 | +    [0] |
| 113 | +    """ |
| 114 | + |
| 115 | +    def __init__(self, maxsize): |
| 116 | +        # The queue starts at q[1]; with 1-based indexing the children of q[k] are simply q[2 * k] and q[2 * k + 1]. |
| 117 | +        self.q = [0] |
| 118 | +        self.maxsize = maxsize |
| 119 | + |
| 120 | +    def _swim(self, k): |
| 121 | +        while (k > 1) and (self.q[k // 2] < self.q[k]): |
| 122 | +            self._swap(k, k // 2) |
| 123 | +            k = k // 2 |
| 124 | + |
| 125 | +    def _swap(self, i, j): |
| 126 | +        t = self.q[i] |
| 127 | +        self.q[i] = self.q[j] |
| 128 | +        self.q[j] = t |
| 129 | + |
| 130 | +    def _sink(self, k): |
| 131 | +        N = self.size() |
| 132 | +        while 2 * k <= N: |
| 133 | +            j = 2 * k |
| 134 | +            # Pick the larger of the two children, then stop if the parent |
| 135 | +            # is already larger; otherwise swap and continue sinking. |
| 136 | +            if j < N and self.q[j] < self.q[j + 1]: |
| 137 | +                j = j + 1 |
| 138 | +            if self.q[k] > self.q[j]: |
| 139 | +                break |
| 140 | +            self._swap(k, j) |
| 141 | +            k = j |
| 142 | + |
| 143 | +    def size(self): |
| 144 | +        return len(self.q) - 1 |
| 145 | + |
| 146 | +    def insert(self, value): |
| 147 | +        if self.size() < self.maxsize: |
| 148 | +            self.q.append(value) |
| 149 | +            self._swim(self.size()) |
| 150 | +        else: |
| 151 | +            self._replaceRoot(value) |
| 152 | + |
| 153 | +    def getElements(self): |
| 154 | +        return self.q[1:] |
| 155 | + |
| 156 | +    def _replaceRoot(self, value): |
| 157 | +        if self.q[1] > value: |
| 158 | +            self.q[1] = value |
| 159 | +            self._sink(1) |
| 160 | + |
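For intuition about what the class above computes: insert() grows the heap until it holds maxsize values, and after that _replaceRoot() overwrites the root (the current maximum) only when a smaller value arrives, so getElements() ends up holding the maxsize smallest values seen. The standard heapq module implements a min-heap, which is why a dedicated max-heap is used here to track the smallest elements. For comparison only (this is not part of the change), the same result can be sketched with heapq.nsmallest:

    import heapq

    # heapq.nsmallest keeps the n smallest items of an iterable, the same
    # set MaxHeapQ(n) retains after inserting every item.
    print(heapq.nsmallest(5, range(10)))         # [0, 1, 2, 3, 4]
    print(heapq.nsmallest(5, range(9, -1, -1)))  # [0, 1, 2, 3, 4]
    print(heapq.nsmallest(1, range(9, -1, -1)))  # [0]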
94 | 161 | class RDD(object): |
95 | 162 |     """ |
96 | 163 |     A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. |
@@ -696,23 +763,53 @@ def top(self, num): |
696 | 763 |         Note: It returns the list sorted in descending order. |
697 | 764 |         >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) |
698 | 765 |         [12] |
699 | | -        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2) |
| 766 | +        >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2) |
700 | 767 |         [6, 5] |
701 | 768 |         """ |
702 | 769 |         def topIterator(iterator): |
703 | 770 |             q = [] |
704 | 771 |             for k in iterator: |
705 | 772 |                 if len(q) < num: |
706 | | -                    heappush(q, k) |
| 773 | +                    heapq.heappush(q, k) |
707 | 774 |                 else: |
708 | | -                    heappushpop(q, k) |
| 775 | +                    heapq.heappushpop(q, k) |
709 | 776 |             yield q |
710 | 777 |
|
711 | 778 |         def merge(a, b): |
712 | 779 |             return next(topIterator(a + b)) |
713 | 780 |
|
714 | 781 |         return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True) |
715 | 782 |
|
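The per-partition pattern in top() above is the standard bounded min-heap idiom: the heap never grows beyond num elements, and heappushpop evicts the smallest current candidate whenever a larger element arrives, so each partition yields its num largest elements before merge() combines them. A standalone sketch of the idiom, using a hypothetical largest_n helper that is not part of this change:

    import heapq

    def largest_n(iterable, n):
        # Bounded min-heap: the root is the smallest of the current top-n
        # candidates, so heappushpop drops it when a bigger value arrives.
        q = []
        for x in iterable:
            if len(q) < n:
                heapq.heappush(q, x)
            else:
                heapq.heappushpop(q, x)
        return sorted(q, reverse=True)

    # largest_n([10, 4, 2, 12, 3], 2) == [12, 10]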
| 783 | +    def takeOrdered(self, num, key=None): |
| 784 | +        """ |
| 785 | +        Get the N smallest elements from an RDD, in ascending order or as specified |
| 786 | +        by the optional key function. |
| 787 | +
|
| 788 | +        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) |
| 789 | +        [1, 2, 3, 4, 5, 6] |
| 790 | +        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x) |
| 791 | +        [10, 9, 7, 6, 5, 4] |
| 792 | +        """ |
| 793 | + |
| 794 | +        def topNKeyedElems(iterator, key_=None): |
| 795 | +            q = MaxHeapQ(num) |
| 796 | +            for k in iterator: |
| 797 | +                if key_ is not None: |
| 798 | +                    k = (key_(k), k) |
| 799 | +                q.insert(k) |
| 800 | +            yield q.getElements() |
| 801 | + |
| 802 | +        def unKey(x, key_=None): |
| 803 | +            if key_ is not None: |
| 804 | +                x = [i[1] for i in x] |
| 805 | +            return x |
| 806 | + |
| 807 | +        def merge(a, b): |
| 808 | +            return next(topNKeyedElems(a + b))  # partial results are already keyed, so no key_ here |
| 809 | +        result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge) |
| 810 | +        return sorted(unKey(result, key), key=key) |
| 811 | + |
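One detail worth noting in takeOrdered(): when a key function is supplied, each element is wrapped as the pair (key(x), x) before insertion, so the heap's comparisons are driven by the key (Python compares tuples element by element) while unKey() recovers the original values at the end. A small illustration of that pairing trick on plain lists (the names here are only for illustration, not part of the change):

    data = [10, 1, 2, 9, 3]
    key = lambda x: -x

    keyed = [(key(x), x) for x in data]   # tuples compare by key first
    smallest = sorted(keyed)[:3]          # [(-10, 10), (-9, 9), (-3, 3)]
    print([v for _, v in smallest])       # [10, 9, 3]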
| 812 | + |
716 | 813 |     def take(self, num): |
717 | 814 |         """ |
718 | 815 |         Take the first num elements of the RDD. |
|