Commit cad91bf

call gc.collect() after data.clear() to release memory as much as possible.

1 parent 37d71f7 · commit cad91bf
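The diff below removes the separate _first_spill() path, folds it into _spill(), and adds an explicit gc.collect() after each data.clear(). As a rough sketch of the intent (standalone Python, not PySpark code; the helper name and arguments are made up for illustration), clearing a big dict drops its references, and an explicit collection pass then reclaims any cyclic garbage immediately instead of waiting for the collector's own schedule:

    import gc

    def spill_and_release(data, dump_item):
        # Hypothetical helper, not part of pyspark.shuffle: write every item
        # out, drop the in-memory dict, then force a full collection so any
        # cyclic garbage is reclaimed right away rather than at some later point.
        for key, value in data.items():
            dump_item(key, value)   # e.g. serialize to a per-partition file
        data.clear()                # release the references held by the dict
        gc.collect()                # reference counting already freed the acyclic
                                    # objects; this sweeps up the cycles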

File tree: 1 file changed

python/pyspark/shuffle.py (36 additions, 35 deletions)
@@ -20,6 +20,7 @@
 import platform
 import shutil
 import warnings
+import gc

 from pyspark.serializers import BatchedSerializer, PickleSerializer

@@ -242,7 +243,7 @@ def mergeValues(self, iterator):

             c += 1
             if c % batch == 0 and get_used_memory() > self.memory_limit:
-                self._first_spill()
+                self._spill()
                 self._partitioned_mergeValues(iterator, self._next_limit())
                 break

@@ -280,7 +281,7 @@ def mergeCombiners(self, iterator, check=True):

             c += 1
             if c % batch == 0 and get_used_memory() > self.memory_limit:
-                self._first_spill()
+                self._spill()
                 self._partitioned_mergeCombiners(iterator, self._next_limit())
                 break

@@ -299,33 +300,6 @@ def _partitioned_mergeCombiners(self, iterator, limit=0):
                 self._spill()
                 limit = self._next_limit()

-    def _first_spill(self):
-        """
-        Dump all the data into disks partition by partition.
-
-        The data has not been partitioned, it will iterator the
-        dataset once, write them into different files, has no
-        additional memory. It only called when the memory goes
-        above limit at the first time.
-        """
-        path = self._get_spill_dir(self.spills)
-        if not os.path.exists(path):
-            os.makedirs(path)
-        # open all the files for writing
-        streams = [open(os.path.join(path, str(i)), 'w')
-                   for i in range(self.partitions)]
-
-        for k, v in self.data.iteritems():
-            h = self._partition(k)
-            # put one item in batch, make it compatitable with load_stream
-            # it will increase the memory if dump them in batch
-            self.serializer.dump_stream([(k, v)], streams[h])
-        for s in streams:
-            s.close()
-        self.data.clear()
-        self.pdata = [{} for i in range(self.partitions)]
-        self.spills += 1
-
     def _spill(self):
         """
         dump already partitioned data into disks.

@@ -336,13 +310,38 @@ def _spill(self):
         if not os.path.exists(path):
             os.makedirs(path)

-        for i in range(self.partitions):
-            p = os.path.join(path, str(i))
-            with open(p, "w") as f:
-                # dump items in batch
-                self.serializer.dump_stream(self.pdata[i].iteritems(), f)
-            self.pdata[i].clear()
+        if not self.pdata:
+            # The data has not been partitioned yet: iterate over the dataset
+            # once and write the items into per-partition files, using no
+            # additional memory. This branch is only taken the first time
+            # memory goes above the limit.
+
+            # open all the files for writing
+            streams = [open(os.path.join(path, str(i)), 'w')
+                       for i in range(self.partitions)]
+
+            for k, v in self.data.iteritems():
+                h = self._partition(k)
+                # dump one item at a time to stay compatible with load_stream;
+                # dumping them in batches would increase memory usage
+                self.serializer.dump_stream([(k, v)], streams[h])
+
+            for s in streams:
+                s.close()
+
+            self.data.clear()
+            self.pdata = [{} for i in range(self.partitions)]
+
+        else:
+            for i in range(self.partitions):
+                p = os.path.join(path, str(i))
+                with open(p, "w") as f:
+                    # dump items in batch
+                    self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+                self.pdata[i].clear()
+
         self.spills += 1
+        gc.collect()  # release the memory as much as possible

     def iteritems(self):
         """ Return all merged items as iterator """

@@ -372,13 +371,15 @@ def _external_items(self):
                         and j < self.spills - 1
                         and get_used_memory() > hard_limit):
                     self.data.clear()  # will read from disk again
+                    gc.collect()  # release the memory as much as possible
                     for v in self._recursive_merged_items(i):
                         yield v
                     return

             for v in self.data.iteritems():
                 yield v
             self.data.clear()
+            gc.collect()

             # remove the merged partition
             for j in range(self.spills):
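A quick way to observe the effect of this clear-then-collect pattern is sketched below. It is illustrative only and not part of the commit; it assumes psutil is installed, while the module's own get_used_memory() (compared against memory_limit to decide when to spill) has platform-specific fallbacks.

    import gc
    import os
    import psutil  # assumed available for this sketch

    def rss_mb():
        # resident set size of the current process, in MB
        return psutil.Process(os.getpid()).memory_info().rss >> 20

    data = {i: list(range(50)) for i in range(200000)}  # build a sizeable dict
    print("before clear:", rss_mb(), "MB")
    data.clear()
    gc.collect()  # same pattern as the commit: clear, then collect
    print("after clear + collect:", rss_mb(), "MB")

The reported drop is usually partial: CPython's allocator can keep freed arenas around for reuse, so "release memory as much as possible" is best-effort rather than a guarantee that everything goes back to the OS.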
