
Commit 400be01

address all the comments
1 parent 6178844 commit 400be01

File tree

2 files changed: +57, -36 lines


python/pyspark/rdd.py

Lines changed: 16 additions & 11 deletions
@@ -1225,14 +1225,17 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        # Transferring O(n) objects to Java is too expensive. Instead, we'll
-        # form the hash buckets in Python, transferring O(numPartitions) objects
-        # to Java. Each object is a (splitNumber, [objects]) pair.
-        # In order to void too huge objects, the objects are grouped into chunks.
+        # Transferring O(n) objects to Java is too expensive.
+        # Instead, we'll form the hash buckets in Python,
+        # transferring O(numPartitions) objects to Java.
+        # Each object is a (splitNumber, [objects]) pair.
+        # In order to avoid too huge objects, the objects are
+        # grouped into chunks.
         outputSerializer = self.ctx._unbatched_serializer
 
-        limit = (_parse_memory(self.ctx._conf.get("spark.python.worker.memory")
-                               or "512m") / 2)
+        limit = (_parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m")) / 2)
+
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
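
The reworded comment above describes the strategy: instead of shipping every record to Java, the Python side groups records into hash buckets and ships one (splitNumber, [objects]) pair per bucket, splitting each bucket into chunks so no single pair grows too large. Below is a minimal standalone sketch of that chunking idea; the function name, the chunk-size constant, and the use of plain hash() instead of the configured partitionFunc are illustrative, not the committed add_shuffle_key.

    from collections import defaultdict

    def bucket_and_chunk(records, num_partitions, chunk_size=1000):
        """Group (key, value) records into (split, [records]) chunks."""
        buckets = defaultdict(list)
        for k, v in records:
            split = hash(k) % num_partitions          # pick the hash bucket
            buckets[split].append((k, v))
            if len(buckets[split]) >= chunk_size:     # avoid too huge objects
                yield split, buckets[split]
                buckets[split] = []
        for split, items in buckets.items():          # flush the remainders
            if items:
                yield split, items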
@@ -1274,8 +1277,8 @@ def add_shuffle_key(split, iterator):
                                                       id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
-        # This is required so that id(partitionFunc) remains unique, even if
-        # partitionFunc is a lambda:
+        # This is required so that id(partitionFunc) remains unique,
+        # even if partitionFunc is a lambda:
         rdd._partitionFunc = partitionFunc
         return rdd
 
@@ -1310,8 +1313,10 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
             numPartitions = self._defaultReducePartitions()
 
         serializer = self.ctx.serializer
-        spill = (self.ctx._conf.get("spark.shuffle.spill") or 'True').lower() == 'true'
-        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory") or "512m")
+        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
+                 == 'true')
+        memory = _parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m"))
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
         def combineLocally(iterator):
@@ -1322,7 +1327,7 @@ def combineLocally(iterator):
 
         locally_combined = self.mapPartitions(combineLocally)
         shuffled = locally_combined.partitionBy(numPartitions)
-
+
         def _mergeCombiners(iterator):
             merger = ExternalMerger(agg, memory, serializer) \
                 if spill else InMemoryMerger(agg)
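
The two configuration reads in the combineByKey hunk decide how merging happens: spark.shuffle.spill (default 'true') selects ExternalMerger over InMemoryMerger, and spark.python.worker.memory (default '512m') sets the memory budget passed to it. A hedged usage sketch, assuming a plain SparkConf/SparkContext setup; the values and the toy dataset are only examples.

    from pyspark import SparkConf, SparkContext

    conf = (SparkConf()
            .set("spark.shuffle.spill", "true")            # allow spilling to disk
            .set("spark.python.worker.memory", "512m"))    # per-worker merge budget
    sc = SparkContext(conf=conf)

    pairs = sc.parallelize([("a", 1), ("b", 2), ("a", 3)])
    sums = pairs.combineByKey(lambda v: v,                 # createCombiner
                              lambda c, v: c + v,          # mergeValue
                              lambda c1, c2: c1 + c2)      # mergeCombiners
    print(sums.collect())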

python/pyspark/shuffle.py

Lines changed: 41 additions & 25 deletions
@@ -27,20 +27,20 @@
     import psutil
 
     def get_used_memory():
-        """ return the used memory in MB """
+        """ Return the used memory in MB """
         self = psutil.Process(os.getpid())
         return self.memory_info().rss >> 20
 
 except ImportError:
 
     def get_used_memory():
-        """ return the used memory in MB """
+        """ Return the used memory in MB """
         if platform.system() == 'Linux':
             for line in open('/proc/self/status'):
                 if line.startswith('VmRSS:'):
                     return int(line.split()[1]) >> 10
         else:
-            warnings.warn("please install psutil to have better "
+            warnings.warn("Please install psutil to have better "
                           "support with spilling")
             if platform.system() == "Darwin":
                 import resource
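
Both branches of get_used_memory() report resident set size in MB: psutil's memory_info().rss is in bytes (hence >> 20), while VmRSS in /proc/self/status is in kB (hence >> 10). A small hedged sketch of the same fallback pattern outside Spark, for reference only; the function name is illustrative.

    import os
    import platform

    def resident_mb():
        """Best-effort resident memory of this process, in MB."""
        try:
            import psutil
            return psutil.Process(os.getpid()).memory_info().rss >> 20  # bytes -> MB
        except ImportError:
            if platform.system() == 'Linux':
                for line in open('/proc/self/status'):
                    if line.startswith('VmRSS:'):
                        return int(line.split()[1]) >> 10                # kB -> MB
            return 0  # unknown platform without psutil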
@@ -80,22 +80,22 @@ def __init__(self, combiner):
 class Merger(object):
 
     """
-    merge shuffled data together by aggregator
+    Merge shuffled data together by aggregator
     """
 
     def __init__(self, aggregator):
         self.agg = aggregator
 
     def mergeValues(self, iterator):
-        """ combine the items by creator and combiner """
+        """ Combine the items by creator and combiner """
         raise NotImplementedError
 
     def mergeCombiners(self, iterator):
-        """ merge the combined items by mergeCombiner """
+        """ Merge the combined items by mergeCombiner """
         raise NotImplementedError
 
     def iteritems(self):
-        """ return the merged items ad iterator """
+        """ Return the merged items as iterator """
         raise NotImplementedError
 
 
@@ -110,22 +110,22 @@ def __init__(self, aggregator):
         self.data = {}
 
     def mergeValues(self, iterator):
-        """ combine the items by creator and combiner """
+        """ Combine the items by creator and combiner """
         # speed up attributes lookup
         d, creator = self.data, self.agg.createCombiner
         comb = self.agg.mergeValue
         for k, v in iterator:
             d[k] = comb(d[k], v) if k in d else creator(v)
 
     def mergeCombiners(self, iterator):
-        """ merge the combined items by mergeCombiner """
+        """ Merge the combined items by mergeCombiner """
         # speed up attributes lookup
         d, comb = self.data, self.agg.mergeCombiners
         for k, v in iterator:
             d[k] = comb(d[k], v) if k in d else v
 
     def iteritems(self):
-        """ return the merged items ad iterator """
+        """ Return the merged items as iterator """
         return self.data.iteritems()
 
 
@@ -182,6 +182,8 @@ class ExternalMerger(Merger):
     499950000
     """
 
+    TOTAL_PARTITIONS = 4096
+
     def __init__(self, aggregator, memory_limit=512, serializer=None,
                  localdirs=None, scale=1, partitions=64, batch=10000):
         Merger.__init__(self, aggregator)
@@ -198,32 +200,32 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
         self.scale = scale
         # unpartitioned merged data
         self.data = {}
-        # partitioned merged data
+        # partitioned merged data, list of dicts
         self.pdata = []
         # number of chunks dumped into disks
         self.spills = 0
 
     def _get_dirs(self):
-        """ get all the directories """
-        path = os.environ.get("SPARK_LOCAL_DIR", "/tmp/spark")
+        """ Get all the directories """
+        path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
         dirs = path.split(",")
         return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
                 for d in dirs]
 
     def _get_spill_dir(self, n):
-        """ choose one directory for spill by number n """
+        """ Choose one directory for spill by number n """
         return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
 
     def next_limit(self):
         """
-        return the next memory limit. If the memory is not released
+        Return the next memory limit. If the memory is not released
         after spilling, it will dump the data only when the used memory
         starts to increase.
         """
         return max(self.memory_limit, get_used_memory() * 1.05)
 
     def mergeValues(self, iterator):
-        """ combine the items by creator and combiner """
+        """ Combine the items by creator and combiner """
         iterator = iter(iterator)
         # speedup attribute lookup
         creator, comb = self.agg.createCombiner, self.agg.mergeValue
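
Only the next_limit() docstring changes in this hunk; the formula itself is unchanged: max(self.memory_limit, get_used_memory() * 1.05). In other words, if memory is not actually released after a spill, the threshold ratchets up to 5% above current usage, so the merger spills again only once usage grows further. A tiny worked example with illustrative numbers:

    memory_limit = 512                  # MB, e.g. from spark.python.worker.memory
    used_after_spill = 600              # MB still resident after a spill

    next_limit = max(memory_limit, used_after_spill * 1.05)
    print(next_limit)                   # 630.0 -> spill again only above ~630 MB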
@@ -239,11 +241,11 @@ def mergeValues(self, iterator):
                 break
 
     def _partition(self, key):
-        """ return the partition for key """
+        """ Return the partition for key """
         return (hash(key) / self.scale) % self.partitions
 
     def _partitioned_mergeValues(self, iterator, limit=0):
-        """ partition the items by key, then combine them """
+        """ Partition the items by key, then combine them """
         # speedup attribute lookup
         creator, comb = self.agg.createCombiner, self.agg.mergeValue
         c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch
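
_partition() maps a key to one of self.partitions buckets after dividing the hash by self.scale. The apparent intent (hedged, since the recursive merger's constructor call is outside this diff) is that a re-merge built with a larger scale spreads the same keys over different sub-partitions instead of collapsing them back into one bucket. Illustrative arithmetic, using // where the Python 2 source relies on integer /:

    key_hash, partitions = 123456789, 64

    top_level = (key_hash // 1) % partitions           # scale = 1 in the first pass
    recursive = (key_hash // partitions) % partitions  # assumed scale = 64 on re-merge
    print(top_level, recursive)                        # 21 52: two different buckets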
@@ -260,7 +262,7 @@ def _partitioned_mergeValues(self, iterator, limit=0):
                 limit = self.next_limit()
 
     def mergeCombiners(self, iterator, check=True):
-        """ merge (K,V) pair by mergeCombiner """
+        """ Merge (K,V) pair by mergeCombiner """
         iterator = iter(iterator)
         # speedup attribute lookup
         d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
@@ -277,7 +279,7 @@ def mergeCombiners(self, iterator, check=True):
                 break
 
     def _partitioned_mergeCombiners(self, iterator, limit=0):
-        """ partition the items by key, then merge them """
+        """ Partition the items by key, then merge them """
         comb, pdata = self.agg.mergeCombiners, self.pdata
         c, hfun = 0, self._partition
         for k, v in iterator:
@@ -293,7 +295,7 @@ def _partitioned_mergeCombiners(self, iterator, limit=0):
 
     def _first_spill(self):
         """
-        dump all the data into disks partition by partition.
+        Dump all the data into disks partition by partition.
 
         The data has not been partitioned, it will iterator the
         dataset once, write them into different files, has no
@@ -337,13 +339,13 @@ def _spill(self):
         self.spills += 1
 
     def iteritems(self):
-        """ return all merged items as iterator """
+        """ Return all merged items as iterator """
         if not self.pdata and not self.spills:
             return self.data.iteritems()
         return self._external_items()
 
     def _external_items(self):
-        """ return all partitioned items as iterator """
+        """ Return all partitioned items as iterator """
         assert not self.data
         if any(self.pdata):
             self._spill()
@@ -359,7 +361,10 @@ def _external_items(self):
                     self.mergeCombiners(self.serializer.load_stream(open(p)),
                                         False)
 
-                    if get_used_memory() > hard_limit and j < self.spills - 1:
+                    # limit the total partitions
+                    if (self.scale * self.partitions < self.TOTAL_PARTITIONS
+                            and j < self.spills - 1
+                            and get_used_memory() > hard_limit):
                         self.data.clear() # will read from disk again
                         for v in self._recursive_merged_items(i):
                             yield v
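
The new guard caps recursion: a recursive merge is attempted only while self.scale * self.partitions stays below TOTAL_PARTITIONS (4096), there is more than one spill left to read, and memory is already above the hard limit. With the default partitions=64 and an initial scale of 1, one level of recursion already reaches 64 * 64 = 4096, so the depth is effectively bounded. A toy sketch of that arithmetic, assuming each recursive level multiplies the scale by the partition count (that growth happens outside this hunk):

    TOTAL_PARTITIONS = 4096           # the new class attribute
    scale, partitions = 1, 64         # ExternalMerger defaults

    depth = 0
    while scale * partitions < TOTAL_PARTITIONS:
        scale *= partitions           # assumed growth per recursion level
        depth += 1
    print(depth, scale)               # 1 64: at most one extra level with defaults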
@@ -368,11 +373,17 @@ def _external_items(self):
                 for v in self.data.iteritems():
                     yield v
                 self.data.clear()
+
+                # remove the merged partition
+                for j in range(self.spills):
+                    path = self._get_spill_dir(j)
+                    os.remove(os.path.join(path, str(i)))
+
         finally:
             self._cleanup()
 
     def _cleanup(self):
-        """ clean up all the files in disks """
+        """ Clean up all the files in disks """
         for d in self.localdirs:
             shutil.rmtree(d, True)

@@ -410,6 +421,11 @@ def _recursive_merged_items(self, start):
             for v in m._external_items():
                 yield v
 
+            # remove the merged partition
+            for j in range(self.spills):
+                path = self._get_spill_dir(j)
+                os.remove(os.path.join(path, str(i)))
+
 
 if __name__ == "__main__":
     import doctest
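
Both new blocks delete the per-partition spill files as soon as a partition has been fully merged and yielded. Given _get_spill_dir(j), spill j of partition i lives under <localdir>/<j>/<i>, so removing os.path.join(path, str(i)) frees disk space incrementally instead of waiting for _cleanup() to remove whole directories at the end. A hedged sketch of that layout and the incremental delete; the paths and counts are illustrative only.

    import os

    localdir = "/tmp/python/1234/5678"     # illustrative, see _get_dirs()
    spills, partition = 3, 7

    for j in range(spills):
        spill_file = os.path.join(localdir, str(j), str(partition))
        if os.path.exists(spill_file):     # this partition is merged, safe to drop
            os.remove(spill_file)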
