@@ -45,13 +45,41 @@ def get_used_memory():
 
 
 class Merger(object):
+    """
+    merge shuffled data together by combinator
+    """
+    def merge(self, iterator):
+        raise NotImplementedError
+
+    def iteritems(self):
+        raise NotImplementedError
+
+
+class MapMerger(Merger):
+    """
+    In memory merger based on map
+    """
+    def __init__(self, combiner):
+        self.combiner = combiner
+        self.data = {}
+
+    def merge(self, iterator):
+        d, comb = self.data, self.combiner
+        for k, v in iter(iterator):
+            d[k] = comb(d[k], v) if k in d else v
+
+    def iteritems(self):
+        return self.data.iteritems()
+
+
+class ExternalHashMapMerger(Merger):
 
     """
     External merger will dump the aggregated data into disks when memory usage
     is above the limit, then merge them together.
 
     >>> combiner = lambda x, y:x+y
-    >>> merger = Merger(combiner, 10)
+    >>> merger = ExternalHashMapMerger(combiner, 10)
     >>> N = 10000
     >>> merger.merge(zip(xrange(N), xrange(N)) * 10)
     >>> assert merger.spills > 0
@@ -63,16 +91,16 @@ class Merger(object):
     PARTITIONS = 64
     BATCH = 10000
 
-    def __init__(self, combiner, memory_limit=512, path="/tmp/pysparki/merge",
-                 serializer=None, batch_size=1024, scale=1):
+    def __init__(self, combiner, memory_limit=512, path="/tmp/pyspark/merge",
+                 serializer=None, scale=1):
         self.combiner = combiner
-        self.path = os.path.join(path, str(os.getpid()))
         self.memory_limit = memory_limit
-        self.serializer = serializer or BatchedSerializer(AutoSerializer(), batch_size)
+        self.path = os.path.join(path, str(os.getpid()))
+        self.serializer = serializer or BatchedSerializer(AutoSerializer(), 1024)
+        self.scale = scale
         self.data = {}
         self.pdata = []
         self.spills = 0
-        self.scale = scale
 
     @property
     def used_memory(self):
@@ -94,7 +122,7 @@ def merge(self, iterator, check=True):
                 continue
 
             c += 1
-            if c % self.BATCH == 0 and self.used_memory > self.memory_limit:
+            if c % batch == 0 and self.used_memory > self.memory_limit:
                 self._first_spill()
                 self._partitioned_merge(iterator, self.next_limit)
                 break
@@ -158,7 +186,7 @@ def _external_items(self):
             for j in range(self.spills):
                 p = os.path.join(self.path, str(j), str(i))
                 self.merge(self.serializer.load_stream(open(p)), check=False)
-
+
                 if j > 0 and self.used_memory > hard_limit and j < self.spills - 1:
                     self.data.clear() # will read from disk again
                     for v in self._recursive_merged_items(i):
@@ -178,12 +206,12 @@ def _recursive_merged_items(self, start):
             self._spill()
 
         for i in range(start, self.PARTITIONS):
-            m = Merger(self.combiner, self.memory_limit,
+            m = ExternalHashMapMerger(self.combiner, self.memory_limit,
                     os.path.join(self.path, 'merge', str(i)),
                     self.serializer, scale=self.scale * self.PARTITIONS)
-            m.pdata = [{} for x in range(self.PARTITIONS)]
+            m.pdata = [{} for _ in range(self.PARTITIONS)]
             limit = self.next_limit
-
+
             for j in range(self.spills):
                 p = os.path.join(self.path, str(j), str(i))
                 m._partitioned_merge(self.serializer.load_stream(open(p)), 0)
@@ -193,7 +221,7 @@ def _recursive_merged_items(self, start):
 
             for v in m._external_items():
                 yield v
-
+
         shutil.rmtree(self.path, True)
 
 
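
For reference, a minimal usage sketch of the two mergers this patch introduces, modeled on the doctest above (Python 2, as in the patch). It assumes ExternalHashMapMerger honours the iteritems() contract declared on Merger, and that the second constructor argument is the spill threshold used in the doctest:

    # Fold values for the same key together in memory with MapMerger.
    combiner = lambda x, y: x + y
    m = MapMerger(combiner)
    m.merge([('a', 1), ('b', 2), ('a', 3)])
    assert dict(m.iteritems()) == {'a': 4, 'b': 2}

    # ExternalHashMapMerger spills partitions to disk once used memory
    # exceeds the limit (10 here, as in the doctest), then merges the
    # spilled files back together when the results are iterated.
    # iteritems() on the external merger is assumed from the Merger interface.
    e = ExternalHashMapMerger(combiner, 10)
    e.merge(zip(xrange(10000), xrange(10000)) * 10)
    assert e.spills > 0
    total = sum(v for k, v in e.iteritems())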