     import psutil

     def get_used_memory():
-        """ return the used memory in MB """
+        """ Return the used memory in MB """
         self = psutil.Process(os.getpid())
         return self.memory_info().rss >> 20

 except ImportError:

     def get_used_memory():
-        """ return the used memory in MB """
+        """ Return the used memory in MB """
         if platform.system() == 'Linux':
             for line in open('/proc/self/status'):
                 if line.startswith('VmRSS:'):
                     return int(line.split()[1]) >> 10
         else:
-            warnings.warn("please install psutil to have better "
+            warnings.warn("Please install psutil to have better "
                           "support with spilling")
             if platform.system() == "Darwin":
                 import resource
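Note on units: psutil reports rss in bytes, so shifting right by 20 yields MB, while /proc/self/status reports VmRSS: in kB, so a shift by 10 is enough. A minimal, self-contained sketch of the same fallback (illustrative only, not part of this patch):

    # Illustrative sketch of the fallback above; psutil is optional.
    import os, platform

    def used_mb():
        try:
            import psutil
            return psutil.Process(os.getpid()).memory_info().rss >> 20  # bytes -> MB
        except ImportError:
            if platform.system() == 'Linux':
                with open('/proc/self/status') as f:
                    for line in f:
                        if line.startswith('VmRSS:'):
                            return int(line.split()[1]) >> 10  # kB -> MB
            return 0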
@@ -80,22 +80,22 @@ def __init__(self, combiner):
 class Merger(object):

     """
-    merge shuffled data together by aggregator
+    Merge shuffled data together by aggregator
     """

     def __init__(self, aggregator):
         self.agg = aggregator

     def mergeValues(self, iterator):
-        """ combine the items by creator and combiner """
+        """ Combine the items by creator and combiner """
         raise NotImplementedError

     def mergeCombiners(self, iterator):
-        """ merge the combined items by mergeCombiner """
+        """ Merge the combined items by mergeCombiner """
         raise NotImplementedError

     def iteritems(self):
-        """ return the merged items ad iterator """
+        """ Return the merged items as iterator """
         raise NotImplementedError

@@ -110,22 +110,22 @@ def __init__(self, aggregator):
         self.data = {}

     def mergeValues(self, iterator):
-        """ combine the items by creator and combiner """
+        """ Combine the items by creator and combiner """
         # speed up attributes lookup
         d, creator = self.data, self.agg.createCombiner
         comb = self.agg.mergeValue
         for k, v in iterator:
             d[k] = comb(d[k], v) if k in d else creator(v)

     def mergeCombiners(self, iterator):
-        """ merge the combined items by mergeCombiner """
+        """ Merge the combined items by mergeCombiner """
         # speed up attributes lookup
         d, comb = self.data, self.agg.mergeCombiners
         for k, v in iterator:
             d[k] = comb(d[k], v) if k in d else v

     def iteritems(self):
-        """ return the merged items ad iterator """
+        """ Return the merged items as iterator """
         return self.data.iteritems()

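The merge here is plain combine-by-key into a dict. A small standalone sketch of the same pattern (the stand-in aggregator below is illustrative; the real aggregator class used by this module is not shown in this diff):

    # Stand-in aggregator exposing the three callbacks the mergers rely on.
    class SumAggregator(object):
        createCombiner = staticmethod(lambda v: v)
        mergeValue = staticmethod(lambda c, v: c + v)
        mergeCombiners = staticmethod(lambda c1, c2: c1 + c2)

    # Equivalent of the mergeValues() loop above, for a word-count style input.
    agg, data = SumAggregator(), {}
    for k, v in [("a", 1), ("b", 2), ("a", 3)]:
        data[k] = agg.mergeValue(data[k], v) if k in data else agg.createCombiner(v)
    # data == {"a": 4, "b": 2}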
@@ -182,6 +182,8 @@ class ExternalMerger(Merger):
     499950000
     """

+    TOTAL_PARTITIONS = 4096
+
     def __init__(self, aggregator, memory_limit=512, serializer=None,
                  localdirs=None, scale=1, partitions=64, batch=10000):
         Merger.__init__(self, aggregator)
@@ -198,32 +200,32 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
         self.scale = scale
         # unpartitioned merged data
         self.data = {}
-        # partitioned merged data
+        # partitioned merged data, list of dicts
         self.pdata = []
         # number of chunks dumped into disks
         self.spills = 0

     def _get_dirs(self):
-        """ get all the directories """
-        path = os.environ.get("SPARK_LOCAL_DIR", "/tmp/spark")
+        """ Get all the directories """
+        path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
         dirs = path.split(",")
         return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
                 for d in dirs]

     def _get_spill_dir(self, n):
-        """ choose one directory for spill by number n """
+        """ Choose one directory for spill by number n """
         return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))

     def next_limit(self):
         """
-        return the next memory limit. If the memory is not released
+        Return the next memory limit. If the memory is not released
         after spilling, it will dump the data only when the used memory
         starts to increase.
         """
         return max(self.memory_limit, get_used_memory() * 1.05)

     def mergeValues(self, iterator):
-        """ combine the items by creator and combiner """
+        """ Combine the items by creator and combiner """
         iterator = iter(iterator)
         # speedup attribute lookup
         creator, comb = self.agg.createCombiner, self.agg.mergeValue
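For reference, with SPARK_LOCAL_DIR="/disk1,/disk2" the code above produces localdirs of /disk1/python/<pid>/<id(self)> and /disk2/python/<pid>/<id(self)>, and spill number n is written under localdirs[n % 2]/<n>, so consecutive spills alternate between the configured disks.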
@@ -239,11 +241,11 @@ def mergeValues(self, iterator):
                 break

     def _partition(self, key):
-        """ return the partition for key """
+        """ Return the partition for key """
         return (hash(key) / self.scale) % self.partitions

     def _partitioned_mergeValues(self, iterator, limit=0):
-        """ partition the items by key, then combine them """
+        """ Partition the items by key, then combine them """
         # speedup attribute lookup
         creator, comb = self.agg.createCombiner, self.agg.mergeValue
         c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch
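A worked example of the partitioning above under Python 2 integer division: with the default 64 partitions, a key whose hash is 12345 lands in (12345 / 1) % 64 = 57 at scale 1 but in (12345 / 64) % 64 = 0 at scale 64, so re-merging one partition at a larger scale (as a recursive merge presumably does) spreads its keys across different buckets at the next level.

    # Python 2 semantics: "/" on ints is floor division.
    partition = lambda h, scale, partitions=64: (h / scale) % partitions
    assert partition(12345, 1) == 57 and partition(12345, 64) == 0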
@@ -260,7 +262,7 @@ def _partitioned_mergeValues(self, iterator, limit=0):
                 limit = self.next_limit()

     def mergeCombiners(self, iterator, check=True):
-        """ merge (K,V) pair by mergeCombiner """
+        """ Merge (K,V) pair by mergeCombiner """
         iterator = iter(iterator)
         # speedup attribute lookup
         d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
@@ -277,7 +279,7 @@ def mergeCombiners(self, iterator, check=True):
                 break

     def _partitioned_mergeCombiners(self, iterator, limit=0):
-        """ partition the items by key, then merge them """
+        """ Partition the items by key, then merge them """
         comb, pdata = self.agg.mergeCombiners, self.pdata
         c, hfun = 0, self._partition
         for k, v in iterator:
@@ -293,7 +295,7 @@ def _partitioned_mergeCombiners(self, iterator, limit=0):
 
     def _first_spill(self):
         """
-        dump all the data into disks partition by partition.
+        Dump all the data into disks partition by partition.

         The data has not been partitioned, it will iterate the
         dataset once, writing them into different files, using no
@@ -337,13 +339,13 @@ def _spill(self):
         self.spills += 1

     def iteritems(self):
-        """ return all merged items as iterator """
+        """ Return all merged items as iterator """
         if not self.pdata and not self.spills:
             return self.data.iteritems()
         return self._external_items()

     def _external_items(self):
-        """ return all partitioned items as iterator """
+        """ Return all partitioned items as iterator """
         assert not self.data
         if any(self.pdata):
             self._spill()
@@ -359,7 +361,10 @@ def _external_items(self):
                     self.mergeCombiners(self.serializer.load_stream(open(p)),
                                         False)

-                    if get_used_memory() > hard_limit and j < self.spills - 1:
+                    # limit the total partitions
+                    if (self.scale * self.partitions < self.TOTAL_PARTITIONS
+                            and j < self.spills - 1
+                            and get_used_memory() > hard_limit):
                         self.data.clear()  # will read from disk again
                         for v in self._recursive_merged_items(i):
                             yield v
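The new guard bounds recursion depth. Assuming the recursive merger runs with scale multiplied by the partition count (that constructor call is outside this hunk), the defaults give 1 * 64 = 64 < 4096 at the first level, while a second level at scale 64 gives 64 * 64 = 4096, which fails the check, so at most 4096 buckets are ever created and any remaining spills are merged with whatever memory is available rather than recursing further.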
@@ -368,11 +373,17 @@ def _external_items(self):
             for v in self.data.iteritems():
                 yield v
             self.data.clear()
+
+            # remove the merged partition
+            for j in range(self.spills):
+                path = self._get_spill_dir(j)
+                os.remove(os.path.join(path, str(i)))
+
         finally:
             self._cleanup()

     def _cleanup(self):
-        """ clean up all the files in disks """
+        """ Clean up all the files in disks """
         for d in self.localdirs:
             shutil.rmtree(d, True)

@@ -410,6 +421,11 @@ def _recursive_merged_items(self, start):
             for v in m._external_items():
                 yield v

+            # remove the merged partition
+            for j in range(self.spills):
+                path = self._get_spill_dir(j)
+                os.remove(os.path.join(path, str(i)))
+

 if __name__ == "__main__":
     import doctest
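For context, a hedged usage sketch of the external path (ExternalMerger and its constructor arguments appear above; the Aggregator helper assumed here is not part of this diff): a tiny memory_limit forces spilling, and iteritems() then streams the merged result back partition by partition.

    # Hedged sketch; assumes an Aggregator(createCombiner, mergeValue, mergeCombiners)
    # helper alongside ExternalMerger in pyspark.shuffle.
    from pyspark.shuffle import Aggregator, ExternalMerger

    agg = Aggregator(lambda v: v, lambda c, v: c + v, lambda x, y: x + y)
    merger = ExternalMerger(agg, memory_limit=10, partitions=8)  # small limit to force spills
    merger.mergeValues((i % 1000, i) for i in xrange(100000))
    print sum(v for _, v in merger.iteritems())  # sum(xrange(100000)) == 4999950000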