
Commit 3652583

address comments

fix code style and add docs and comments
use ExternalMerger for map-side aggregation
check memory usage during partitionBy()


6 files changed: +226 −92 lines changed


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 2 additions & 1 deletion
@@ -57,7 +57,8 @@ private[spark] class PythonRDD[T: ClassTag](
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
     val env = SparkEnv.get
-    val localdir = env.conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(
+      f => f.getPath()).mkString(",")
     val worker: Socket = env.createPythonWorker(pythonExec,
       envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))

core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
   /* Create one local directory for each path mentioned in spark.local.dir; then, inside this
    * directory, create multiple subdirectories that we will hash files into, in order to avoid
    * having really large inodes at the top level. */
-  private val localDirs: Array[File] = createLocalDirs()
+  val localDirs: Array[File] = createLocalDirs()
   if (localDirs.isEmpty) {
     logError("Failed to create any local dir.")
     System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)

docs/configuration.md

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ Apart from these, the following properties are also available, and may be useful
   <td>
     Amount of memory to use per python worker process during aggregation, in the same
     format as JVM memory strings (e.g. <code>512m</code>, <code>2g</code>). If the memory
-    used during aggregation go above this amount, it will spill the data into disks.
+    used during aggregation goes above this amount, it will spill the data into disks.
   </td>
 </tr>
 </table>
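For illustration (not part of this commit), the property documented above is set like any other Spark configuration before the context is created:

from pyspark import SparkConf, SparkContext

# Let each Python worker use up to 1 GB before spilling aggregation data to
# disk; spilling itself is controlled by spark.shuffle.spill (on by default).
conf = SparkConf() \
    .setAppName("python-worker-memory-example") \
    .set("spark.python.worker.memory", "1g")
sc = SparkContext(conf=conf)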

python/pyspark/rdd.py

Lines changed: 30 additions & 18 deletions
@@ -42,7 +42,8 @@
 from pyspark.rddsampler import RDDSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
-from pyspark.shuffle import MapMerger, ExternalHashMapMerger
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
+    get_used_memory
 
 from py4j.java_collections import ListConverter, MapConverter
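The Aggregator, InMemoryMerger and ExternalMerger classes imported here live in pyspark/shuffle.py, which is not shown in this diff. As a rough sketch of the contract the rest of this file relies on (simplified, not the actual implementation): an Aggregator bundles the three combiner callbacks, and a merger exposes combine() for map-side input of raw (key, value) pairs, merge() for reduce-side input of already-combined values, and iteritems() to read the result.

class Aggregator(object):
    # bundle of the three callbacks passed to combineByKey (sketch)
    def __init__(self, createCombiner, mergeValue, mergeCombiners):
        self.createCombiner = createCombiner
        self.mergeValue = mergeValue
        self.mergeCombiners = mergeCombiners


class InMemoryMerger(object):
    # plain dict-based merger, used when spark.shuffle.spill is off (sketch)
    def __init__(self, aggregator):
        self.agg = aggregator
        self.data = {}

    def combine(self, iterator):
        # map side: values are raw, so use createCombiner / mergeValue
        d, agg = self.data, self.agg
        for k, v in iterator:
            d[k] = agg.mergeValue(d[k], v) if k in d else agg.createCombiner(v)

    def merge(self, iterator):
        # reduce side: values are already combiners, so use mergeCombiners
        d, agg = self.data, self.agg
        for k, c in iterator:
            d[k] = agg.mergeCombiners(d[k], c) if k in d else c

    def iteritems(self):
        return self.data.iteritems()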

@@ -171,18 +172,20 @@ def _replaceRoot(self, value):
 
 def _parse_memory(s):
     """
-    It returns a number in MB
+    Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
+    return the value in MB
 
     >>> _parse_memory("256m")
     256
     >>> _parse_memory("2g")
     2048
     """
-    units = {'g': 1024, 'm': 1, 't': 1<<20, 'k':1.0/1024}
+    units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
     if s[-1] not in units:
         raise ValueError("invalid format: " + s)
     return int(float(s[:-1]) * units[s[-1].lower()])
 
+
 class RDD(object):
 
     """
@@ -1198,15 +1201,25 @@ def partitionBy(self, numPartitions, partitionFunc=None):
         # to Java. Each object is a (splitNumber, [objects]) pair.
         outputSerializer = self.ctx._unbatched_serializer
 
+        limit = _parse_memory(self.ctx._conf.get("spark.python.worker.memory")
+                              or "512m")
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
-
+            c, batch = 0, 1000
             for (k, v) in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
+                c += 1
+                if c % batch == 0 and get_used_memory() > limit:
+                    for split in buckets.keys():
+                        yield pack_long(split)
+                        yield outputSerializer.dumps(buckets[split])
+                        del buckets[split]
+
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)
+
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         with _JavaStackTrace(self.context) as st:
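The get_used_memory helper used in the loop above comes from pyspark/shuffle.py and is not part of this diff. A minimal sketch of such a check using only the standard library (illustrative, not the actual pyspark code):

import platform
import resource

def get_used_memory():
    # resident set size of the current process, in MB; ru_maxrss is reported
    # in kilobytes on Linux and in bytes on OS X, so normalise accordingly
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if platform.system() == 'Darwin':
        rss >>= 10  # bytes -> KB
    return rss >> 10  # KB -> MB
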
@@ -1251,27 +1264,26 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
+        serializer = self.ctx.serializer
+        spill = (self.ctx._conf.get("spark.shuffle.spill") or 'True').lower() == 'true'
+        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory") or "512m")
+        agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
         def combineLocally(iterator):
-            combiners = {}
-            for x in iterator:
-                (k, v) = x
-                if k not in combiners:
-                    combiners[k] = createCombiner(v)
-                else:
-                    combiners[k] = mergeValue(combiners[k], v)
-            return combiners.iteritems()
+            merger = ExternalMerger(agg, memory, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.combine(iterator)
+            return merger.iteritems()
+
         locally_combined = self.mapPartitions(combineLocally)
         shuffled = locally_combined.partitionBy(numPartitions)
 
-        serializer = self.ctx.serializer
-        spill = ((self.ctx._conf.get("spark.shuffle.spill") or 'True').lower()
-                 in ('true', '1', 'yes'))
-        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory") or "512m")
         def _mergeCombiners(iterator):
-            merger = ExternalHashMapMerger(mergeCombiners, memory, serializer)\
-                if spill else MapMerger(mergeCombiners)
+            merger = ExternalMerger(agg, memory, serializer) \
+                if spill else InMemoryMerger(agg)
             merger.merge(iterator)
             return merger.iteritems()
+
         return shuffled.mapPartitions(_mergeCombiners)
 
     def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
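As a usage illustration (not part of the commit), a plain combineByKey call exercises both paths above: combineLocally on the map side, then _mergeCombiners after the shuffle, spilling to disk whenever the worker exceeds spark.python.worker.memory:

from pyspark import SparkContext

sc = SparkContext("local", "combineByKey-example")

pairs = sc.parallelize([("a", 1), ("b", 2), ("a", 3)])
sums = pairs.combineByKey(lambda v: v,             # createCombiner
                          lambda c, v: c + v,      # mergeValue
                          lambda c1, c2: c1 + c2)  # mergeCombiners
print(sums.collectAsMap())  # {'a': 4, 'b': 2}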
