
Commit 6edbd1f

Hash based disk spilling aggregation
1 parent d0ea496

2 files changed (+154, -10 lines)

python/pyspark/rdd.py

Lines changed: 127 additions & 10 deletions
@@ -32,10 +32,12 @@
 import heapq
 from random import Random
 from math import sqrt, log
+import platform
+import resource
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
-    PickleSerializer, pack_long
+    PickleSerializer, AutoSerializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -168,6 +170,123 @@ def _replaceRoot(self, value):
             self._sink(1)
 
 
+class Merger(object):
+    """
+    External merger will dump the aggregated data to disk when memory usage
+    is above the limit, then merge the spilled files back together.
+
+    >>> combiner = lambda x, y: x + y
+    >>> merger = Merger(combiner, 10)
+    >>> N = 10000
+    >>> merger.merge(zip(xrange(N), xrange(N)) * 10)
+    >>> merger.spills
+    100
+    >>> sum(1 for k, v in merger.iteritems())
+    10000
+    """
+
+    PARTITIONS = 64
+    BATCH = 1000
+
+    def __init__(self, combiner, memory_limit=256, path="/tmp/pyspark", serializer=None):
+        self.combiner = combiner
+        self.path = os.path.join(path, str(os.getpid()))
+        self.memory_limit = memory_limit
+        self.serializer = serializer or BatchedSerializer(AutoSerializer(), 1024)
+        self.item_limit = None
+        self.data = {}
+        self.pdata = []
+        self.spills = 0
+
+    def used_memory(self):
+        rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+        if platform.system() == 'Linux':
+            rss >>= 10
+        elif platform.system() == 'Darwin':
+            rss >>= 20
+        return rss
+
+    def merge(self, iterator):
+        iterator = iter(iterator)
+        d = self.data
+        comb = self.combiner
+        c = 0
+        for k, v in iterator:
+            if k in d:
+                d[k] = comb(d[k], v)
+            else:
+                d[k] = v
+
+            if self.item_limit is not None:
+                continue
+
+            c += 1
+            if c % self.BATCH == 0 and self.used_memory() > self.memory_limit:
+                self.item_limit = c
+                self._first_spill()
+                self._partitioned_merge(iterator)
+                return
+
+    def _partitioned_merge(self, iterator):
+        comb = self.combiner
+        c = 0
+        for k, v in iterator:
+            d = self.pdata[hash(k) % self.PARTITIONS]
+            if k in d:
+                d[k] = comb(d[k], v)
+            else:
+                d[k] = v
+            c += 1
+            if c >= self.item_limit:
+                self._spill()
+                c = 0
+
+    def _first_spill(self):
+        path = os.path.join(self.path, str(self.spills))
+        if not os.path.exists(path):
+            os.makedirs(path)
+        streams = [open(os.path.join(path, str(i)), 'w')
+                   for i in range(self.PARTITIONS)]
+        for k, v in self.data.iteritems():
+            h = hash(k) % self.PARTITIONS
+            self.serializer.dump_stream([(k, v)], streams[h])
+        for s in streams:
+            s.close()
+        self.data.clear()
+        self.pdata = [{} for i in range(self.PARTITIONS)]
+        self.spills += 1
+
+    def _spill(self):
+        path = os.path.join(self.path, str(self.spills))
+        if not os.path.exists(path):
+            os.makedirs(path)
+        for i in range(self.PARTITIONS):
+            p = os.path.join(path, str(i))
+            with open(p, 'w') as f:
+                self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+            self.pdata[i].clear()
+        self.spills += 1
+
+    def iteritems(self):
+        if not self.pdata and not self.spills:
+            return self.data.iteritems()
+        return self._external_items()
+
+    def _external_items(self):
+        for i in range(self.PARTITIONS):
+            self.data = self.pdata[i]
+            for j in range(self.spills):
+                p = os.path.join(self.path, str(j), str(i))
+                self.merge(self.serializer.load_stream(open(p)))
+                os.remove(p)
+            for k, v in self.data.iteritems():
+                yield k, v
+            self.data.clear()
+        for i in range(self.spills):
+            os.rmdir(os.path.join(self.path, str(i)))
+        os.rmdir(self.path)
+
+
 class RDD(object):
 
     """
@@ -1247,15 +1366,12 @@ def combineLocally(iterator):
             return combiners.iteritems()
         locally_combined = self.mapPartitions(combineLocally)
         shuffled = locally_combined.partitionBy(numPartitions)
-
+
+        executorMemory = self.ctx._jsc.sc().executorMemory()
         def _mergeCombiners(iterator):
-            combiners = {}
-            for (k, v) in iterator:
-                if k not in combiners:
-                    combiners[k] = v
-                else:
-                    combiners[k] = mergeCombiners(combiners[k], v)
-            return combiners.iteritems()
+            merger = Merger(mergeCombiners, executorMemory * 0.7)
+            merger.merge(iterator)
+            return merger.iteritems()
         return shuffled.mapPartitions(_mergeCombiners)
 
     def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
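
In effect, each reduce task now funnels its stream through a Merger whose memory budget is 70% of the executor memory. A rough, self-contained sketch with made-up numbers (it assumes Merger can be imported from pyspark.rdd; in the diff the budget comes from executorMemory() on the JVM side):

    from pyspark.rdd import Merger                      # assumption: module-level class

    executorMemory = 512                                # MB, hypothetical executor size
    mergeCombiners = lambda a, b: a + b                 # sums per-key partial counts

    merger = Merger(mergeCombiners, executorMemory * 0.7)
    merger.merge((k % 100, 1) for k in xrange(10000))   # 100 keys, 100 ones each
    assert dict(merger.iteritems())[0] == 100           # fully combined, spilling if needed
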
@@ -1314,7 +1430,8 @@ def mergeValue(xs, x):
             return xs
 
         def mergeCombiners(a, b):
-            return a + b
+            a.extend(b)
+            return a
 
         return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                                  numPartitions).mapValues(lambda x: ResultIterable(x))

python/pyspark/serializers.py

Lines changed: 27 additions & 0 deletions
@@ -297,6 +297,33 @@ class MarshalSerializer(FramedSerializer):
     loads = marshal.loads
 
 
+class AutoSerializer(FramedSerializer):
+    """
+    Choose marshal or cPickle as serialization protocol automatically
+    """
+    def __init__(self):
+        FramedSerializer.__init__(self)
+        self._type = None
+
+    def dumps(self, obj):
+        try:
+            if self._type is not None:
+                raise TypeError("fallback")
+            return 'M' + marshal.dumps(obj)
+        except Exception:
+            self._type = 'P'
+            return 'P' + cPickle.dumps(obj, -1)
+
+    def loads(self, stream):
+        _type = stream[0]
+        if _type == 'M':
+            return marshal.loads(stream[1:])
+        elif _type == 'P':
+            return cPickle.loads(stream[1:])
+        else:
+            raise ValueError("invalid serialization type: %s" % _type)
+
+
 class UTF8Deserializer(Serializer):
     """
     Deserializes streams written by String.getBytes.
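
A quick illustration of the tagging scheme, not part of the commit (it assumes AutoSerializer is imported from pyspark.serializers): frames marshal can handle get an 'M' prefix, and the first object marshal rejects switches the serializer to cPickle ('P') for everything that follows:

    from datetime import datetime
    from pyspark.serializers import AutoSerializer

    ser = AutoSerializer()
    frame = ser.dumps((1, "a"))            # marshal handles plain tuples -> 'M' prefix
    assert frame[0] == 'M' and ser.loads(frame) == (1, "a")

    frame = ser.dumps(datetime.now())      # marshal rejects datetimes, cPickle takes over
    assert frame[0] == 'P' and isinstance(ser.loads(frame), datetime)
    assert ser.dumps((1, "a"))[0] == 'P'   # once fallen back, it stays on cPickle
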
