
Commit 286aaff

Let spilled aggregation in Python be configurable

Add spark.python.worker.memory to set the memory used by each Python worker during aggregation. The default is 512m.
1 parent e9a40f6 commit 286aaff

File tree: 4 files changed (+80 -22 lines)


docs/configuration.md

Lines changed: 9 additions & 0 deletions

@@ -195,6 +195,15 @@ Apart from these, the following properties are also available, and may be useful
     Spark's dependencies and user dependencies. It is currently an experimental feature.
   </td>
 </tr>
+<tr>
+  <td><code>spark.python.worker.memory</code></td>
+  <td>512m</td>
+  <td>
+    Amount of memory to use per python worker process during aggregation, in the same
+    format as JVM memory strings (e.g. <code>512m</code>, <code>2g</code>). If the memory
+    used during aggregation goes above this amount, it will spill the data into disks.
+  </td>
+</tr>
 </table>

 #### Shuffle Behavior
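
For context, a minimal sketch of how this property could be set from a PySpark application. The property names come from this commit; the app name, memory value, and example job are illustrative:

    from pyspark import SparkConf, SparkContext

    # Cap each Python worker at 1g during aggregation; beyond that it spills to disk.
    conf = (SparkConf()
            .setAppName("spill-demo")                   # illustrative app name
            .set("spark.python.worker.memory", "1g")    # same format as JVM memory strings
            .set("spark.shuffle.spill", "true"))        # keep spilling enabled (the default)
    sc = SparkContext(conf=conf)

    # Any key aggregation in Python (e.g. reduceByKey) is now bounded by the limit above.
    pairs = sc.parallelize(range(100000)).map(lambda x: (x % 1000, 1))
    counts = pairs.reduceByKey(lambda a, b: a + b).collect()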

python/pyspark/rdd.py

Lines changed: 20 additions & 4 deletions

@@ -42,7 +42,7 @@
 from pyspark.rddsampler import RDDSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
-from pyspark.shuffle import Merger
+from pyspark.shuffle import MapMerger, ExternalHashMapMerger

 from py4j.java_collections import ListConverter, MapConverter

@@ -169,6 +169,18 @@ def _replaceRoot(self, value):
         self._sink(1)


+def _parse_memory(s):
+    """
+    >>> _parse_memory("256m")
+    256
+    >>> _parse_memory("2g")
+    2048
+    """
+    units = {'g': 1024, 'm': 1, 't': 1<<20, 'k': 1.0/1024}
+    if s[-1] not in units:
+        raise ValueError("invalid format: " + s)
+    return int(float(s[:-1]) * units[s[-1].lower()])
+
 class RDD(object):

     """
@@ -1249,10 +1261,14 @@ def combineLocally(iterator):
         locally_combined = self.mapPartitions(combineLocally)
         shuffled = locally_combined.partitionBy(numPartitions)

-        executorMemory = self.ctx._jsc.sc().executorMemory()
+        serializer = self.ctx.serializer
+        spill = ((self.ctx._conf.get("spark.shuffle.spill") or 'True').lower()
+                 in ('true', '1', 'yes'))
+        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory") or "512m")
         def _mergeCombiners(iterator):
-            # TODO: workdir and serializer
-            merger = Merger(mergeCombiners, executorMemory)
+            # TODO: workdir
+            merger = ExternalHashMapMerger(mergeCombiners, memory, serializer=serializer)\
+                if spill else MapMerger(mergeCombiners)
             merger.merge(iterator)
             return merger.iteritems()
         return shuffled.mapPartitions(_mergeCombiners)
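
To spell out the unit handling in _parse_memory above: the suffix is one of k/m/g/t and the result is in megabytes, so "512m" parses to 512 and "2g" to 2048. A small standalone sketch of the same conversion (not the committed code; the helper name is made up for illustration):

    # Standalone illustration of the memory-string parsing used above (result in megabytes).
    UNITS_MB = {'k': 1.0 / 1024, 'm': 1, 'g': 1024, 't': 1 << 20}

    def parse_memory_mb(s):
        """Convert a JVM-style memory string such as '512m' or '2g' to megabytes."""
        unit = s[-1].lower()
        if unit not in UNITS_MB:
            raise ValueError("invalid format: " + s)
        return int(float(s[:-1]) * UNITS_MB[unit])

    assert parse_memory_mb("256m") == 256
    assert parse_memory_mb("2g") == 2048
    assert parse_memory_mb("1t") == 1 << 20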

python/pyspark/shuffle.py

Lines changed: 40 additions & 12 deletions

@@ -45,13 +45,41 @@ def get_used_memory():


 class Merger(object):
+    """
+    merge shuffled data together by combinator
+    """
+    def merge(self, iterator):
+        raise NotImplementedError
+
+    def iteritems(self):
+        raise NotImplementedError
+
+
+class MapMerger(Merger):
+    """
+    In memory merger based on map
+    """
+    def __init__(self, combiner):
+        self.combiner = combiner
+        self.data = {}
+
+    def merge(self, iterator):
+        d, comb = self.data, self.combiner
+        for k,v in iter(iterator):
+            d[k] = comb(d[k], v) if k in d else v
+
+    def iteritems(self):
+        return self.data.iteritems()
+
+
+class ExternalHashMapMerger(Merger):

     """
     External merger will dump the aggregated data into disks when memory usage
     is above the limit, then merge them together.

     >>> combiner = lambda x, y:x+y
-    >>> merger = Merger(combiner, 10)
+    >>> merger = ExternalHashMapMerger(combiner, 10)
     >>> N = 10000
     >>> merger.merge(zip(xrange(N), xrange(N)) * 10)
     >>> assert merger.spills > 0
@@ -63,16 +91,16 @@ class Merger(object):
     PARTITIONS = 64
     BATCH = 10000

-    def __init__(self, combiner, memory_limit=512, path="/tmp/pysparki/merge",
-                 serializer=None, batch_size=1024, scale=1):
+    def __init__(self, combiner, memory_limit=512, path="/tmp/pyspark/merge",
+                 serializer=None, scale=1):
         self.combiner = combiner
-        self.path = os.path.join(path, str(os.getpid()))
         self.memory_limit = memory_limit
-        self.serializer = serializer or BatchedSerializer(AutoSerializer(), batch_size)
+        self.path = os.path.join(path, str(os.getpid()))
+        self.serializer = serializer or BatchedSerializer(AutoSerializer(), 1024)
+        self.scale = scale
         self.data = {}
         self.pdata = []
         self.spills = 0
-        self.scale = scale

     @property
     def used_memory(self):
@@ -94,7 +122,7 @@ def merge(self, iterator, check=True):
                 continue

             c += 1
-            if c % self.BATCH == 0 and self.used_memory > self.memory_limit:
+            if c % batch == 0 and self.used_memory > self.memory_limit:
                 self._first_spill()
                 self._partitioned_merge(iterator, self.next_limit)
                 break
@@ -158,7 +186,7 @@ def _external_items(self):
                 for j in range(self.spills):
                     p = os.path.join(self.path, str(j), str(i))
                     self.merge(self.serializer.load_stream(open(p)), check=False)
-
+
                     if j > 0 and self.used_memory > hard_limit and j < self.spills - 1:
                         self.data.clear() # will read from disk again
                         for v in self._recursive_merged_items(i):
@@ -178,12 +206,12 @@ def _recursive_merged_items(self, start):
             self._spill()

         for i in range(start, self.PARTITIONS):
-            m = Merger(self.combiner, self.memory_limit,
+            m = ExternalHashMapMerger(self.combiner, self.memory_limit,
                     os.path.join(self.path, 'merge', str(i)),
                     self.serializer, scale=self.scale * self.PARTITIONS)
-            m.pdata = [{} for x in range(self.PARTITIONS)]
+            m.pdata = [{} for _ in range(self.PARTITIONS)]
             limit = self.next_limit
-
+
             for j in range(self.spills):
                 p = os.path.join(self.path, str(j), str(i))
                 m._partitioned_merge(self.serializer.load_stream(open(p)), 0)
@@ -193,7 +221,7 @@ def _recursive_merged_items(self, start):

             for v in m._external_items():
                 yield v
-
+
         shutil.rmtree(self.path, True)

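The doctest in the ExternalHashMapMerger hunk above already shows the intended usage; the sketch below expands it slightly, contrasting the in-memory and spilling paths. It assumes both classes are importable from pyspark.shuffle after this change; the tiny limit of 10 (megabytes, per the parsing above) is only there to force spilling:

    from pyspark.shuffle import MapMerger, ExternalHashMapMerger

    combiner = lambda x, y: x + y   # sum values that share a key
    N = 10000

    # In-memory path: everything is aggregated in a plain dict.
    m = MapMerger(combiner)
    m.merge(zip(xrange(N), xrange(N)))
    assert sum(v for k, v in m.iteritems()) == sum(xrange(N))

    # Spilling path: each key appears 10 times, and the tiny limit forces at least one spill.
    em = ExternalHashMapMerger(combiner, 10)
    em.merge(zip(xrange(N), xrange(N)) * 10)
    assert em.spills > 0
    assert sum(v for k, v in em.iteritems()) == sum(xrange(N)) * 10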

python/pyspark/tests.py

Lines changed: 11 additions & 6 deletions

@@ -34,7 +34,7 @@
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int
-from pyspark.shuffle import Merger
+from pyspark.shuffle import MapMerger, ExternalHashMapMerger

 _have_scipy = False
 try:
@@ -54,23 +54,28 @@ def setUp(self):
         self.N = 1<<18
         self.l = [i for i in xrange(self.N)]
         self.data = zip(self.l, self.l)
-        Merger.PARTITIONS = 8
-        Merger.BATCH = 1<<14
+        ExternalHashMapMerger.PARTITIONS = 8
+        ExternalHashMapMerger.BATCH = 1<<14
+
+    def test_in_memory(self):
+        m = MapMerger(lambda x,y: x+y)
+        m.merge(self.data)
+        self.assertEqual(sum(v for k,v in m.iteritems()), sum(xrange(self.N)))

     def test_small_dataset(self):
-        m = Merger(lambda x,y: x+y, 1000)
+        m = ExternalHashMapMerger(lambda x,y: x+y, 1000)
         m.merge(self.data)
         self.assertEqual(m.spills, 0)
         self.assertEqual(sum(v for k,v in m.iteritems()), sum(xrange(self.N)))

     def test_medium_dataset(self):
-        m = Merger(lambda x,y: x+y, 10)
+        m = ExternalHashMapMerger(lambda x,y: x+y, 10)
         m.merge(self.data * 3)
         self.assertTrue(m.spills >= 1)
         self.assertEqual(sum(v for k,v in m.iteritems()), sum(xrange(self.N)) * 3)

     def test_huge_dataset(self):
-        m = Merger(lambda x,y: x + y, 10)
+        m = ExternalHashMapMerger(lambda x,y: x + y, 10)
         m.merge(map(lambda (k,v): (k, [str(v)]), self.data) * 10)
         self.assertTrue(m.spills >= 1)
         self.assertEqual(sum(len(v) for k,v in m._recursive_merged_items(0)), self.N * 10)
