[SPARK-2538] [PySpark] Hash based disk spilling aggregation #1460
File: python/pyspark/rdd.py
@@ -42,6 +42,8 @@
 from pyspark.rddsampler import RDDSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
+    get_used_memory

 from py4j.java_collections import ListConverter, MapConverter

@@ -197,6 +199,22 @@ def _replaceRoot(self, value):
         self._sink(1)


+def _parse_memory(s):
+    """
+    Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
+    return the value in MB
+
+    >>> _parse_memory("256m")
+    256
+    >>> _parse_memory("2g")
+    2048
+    """
+    units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
+    if s[-1] not in units:
+        raise ValueError("invalid format: " + s)
+    return int(float(s[:-1]) * units[s[-1].lower()])
+
+
 class RDD(object):

     """

@@ -1207,20 +1225,49 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()

-        # Transferring O(n) objects to Java is too expensive. Instead, we'll
-        # form the hash buckets in Python, transferring O(numPartitions) objects
-        # to Java. Each object is a (splitNumber, [objects]) pair.
+        # Transferring O(n) objects to Java is too expensive.
+        # Instead, we'll form the hash buckets in Python,
+        # transferring O(numPartitions) objects to Java.
+        # Each object is a (splitNumber, [objects]) pair.
+        # In order to avoid too huge objects, the objects are
+        # grouped into chunks.
         outputSerializer = self.ctx._unbatched_serializer

+        limit = (_parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m")) / 2)

         def add_shuffle_key(split, iterator):

             buckets = defaultdict(list)
+            c, batch = 0, min(10 * numPartitions, 1000)

             for (k, v) in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
+                c += 1

+                # check used memory and avg size of chunk of objects
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
+                    for split in buckets.keys():
+                        yield pack_long(split)
+                        d = outputSerializer.dumps(buckets[split])
+                        del buckets[split]
+                        yield d
+                        size += len(d)

+                    avg = (size / n) >> 20
+                    # let 1M < avg < 10M
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0

             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)

         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         with _JavaStackTrace(self.context) as st:

@@ -1230,8 +1277,8 @@ def add_shuffle_key(split, iterator):
                                                           id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
-        # This is required so that id(partitionFunc) remains unique, even if
-        # partitionFunc is a lambda:
+        # This is required so that id(partitionFunc) remains unique,
+        # even if partitionFunc is a lambda:
         rdd._partitionFunc = partitionFunc
         return rdd

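To see the chunking idea from add_shuffle_key in isolation, here is a minimal, self-contained sketch. It uses pickle and Python's built-in hash as stand-ins for PySpark's unbatched serializer and portable_hash, and it omits the get_used_memory() check against the spark.python.worker.memory limit; only the adaptive batch size that keeps the average serialized chunk between roughly 1MB and 10MB is reproduced.

```python
from collections import defaultdict
import pickle


def chunked_shuffle(pairs, num_partitions):
    """Toy version of add_shuffle_key: bucket records by partition and flush
    every bucket as a serialized chunk once the batch threshold is reached,
    so no single serialized object grows too large."""
    buckets = defaultdict(list)
    count, batch = 0, min(10 * num_partitions, 1000)
    for k, v in pairs:
        buckets[hash(k) % num_partitions].append((k, v))
        count += 1
        if count > batch:
            n, size = len(buckets), 0
            for split in list(buckets):
                chunk = pickle.dumps(buckets.pop(split))
                size += len(chunk)
                yield split, chunk
            # adapt the batch size so the average chunk stays in a sane range
            avg_mb = (size / n) / (1 << 20)
            if avg_mb < 1:
                batch *= 1.5
            elif avg_mb > 10:
                batch = max(batch / 1.5, 1)
            count = 0
    # flush whatever is left at the end of the iterator
    for split, items in buckets.items():
        yield split, pickle.dumps(items)


# 5000 key/value pairs into 10 partitions: expect many small chunks,
# not one giant list per partition.
chunks = list(chunked_shuffle(((i, i * i) for i in range(5000)), 10))
print(len(chunks))
```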
@@ -1265,26 +1312,28 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()

+        serializer = self.ctx.serializer
+        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
+                 == 'true')
+        memory = _parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m"))
+        agg = Aggregator(createCombiner, mergeValue, mergeCombiners)

         def combineLocally(iterator):
-            combiners = {}
-            for x in iterator:
-                (k, v) = x
-                if k not in combiners:
-                    combiners[k] = createCombiner(v)
-                else:
-                    combiners[k] = mergeValue(combiners[k], v)
-            return combiners.iteritems()
+            merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.mergeValues(iterator)
+            return merger.iteritems()

         locally_combined = self.mapPartitions(combineLocally)
         shuffled = locally_combined.partitionBy(numPartitions)

         def _mergeCombiners(iterator):
-            combiners = {}
-            for (k, v) in iterator:
-                if k not in combiners:
-                    combiners[k] = v
-                else:
-                    combiners[k] = mergeCombiners(combiners[k], v)
-            return combiners.iteritems()
+            merger = ExternalMerger(agg, memory, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.mergeCombiners(iterator)
+            return merger.iteritems()

         return shuffled.mapPartitions(_mergeCombiners)
Contributor: This only implements external merging in the reduce tasks, but we need it in the map tasks too. For that you'll need to modify the Merger interface to take …
     def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
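For readers who have not looked at the new pyspark.shuffle module yet, the combineByKey change above only depends on a small merger contract: mergeValues() folds raw (k, v) pairs using createCombiner/mergeValue, mergeCombiners() folds pre-combined values, and iteritems() yields the result. Below is a rough, purely in-memory sketch of that contract; it is not the actual InMemoryMerger or ExternalMerger code, which additionally tracks get_used_memory() and spills partitions to disk when the memory limit is exceeded.

```python
class Aggregator(object):
    """Bundles the three combine functions, mirroring how combineByKey
    constructs Aggregator(createCombiner, mergeValue, mergeCombiners)."""
    def __init__(self, createCombiner, mergeValue, mergeCombiners):
        self.createCombiner = createCombiner
        self.mergeValue = mergeValue
        self.mergeCombiners = mergeCombiners


class SimpleInMemoryMerger(object):
    """Dict-backed merger sketch exposing the interface used above."""
    def __init__(self, aggregator):
        self.agg = aggregator
        self.data = {}

    def mergeValues(self, iterator):
        # map side: fold raw values into combiners
        for k, v in iterator:
            if k in self.data:
                self.data[k] = self.agg.mergeValue(self.data[k], v)
            else:
                self.data[k] = self.agg.createCombiner(v)

    def mergeCombiners(self, iterator):
        # reduce side: fold combiners produced by other partitions
        for k, c in iterator:
            if k in self.data:
                self.data[k] = self.agg.mergeCombiners(self.data[k], c)
            else:
                self.data[k] = c

    def iteritems(self):
        return iter(self.data.items())


# Word-count style usage of the sketch:
agg = Aggregator(lambda v: v, lambda c, v: c + v, lambda c1, c2: c1 + c2)
merger = SimpleInMemoryMerger(agg)
merger.mergeValues([("a", 1), ("b", 1), ("a", 1)])
print(dict(merger.iteritems()))   # {'a': 2, 'b': 1}
```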
@@ -1343,7 +1392,8 @@ def mergeValue(xs, x):
             return xs

         def mergeCombiners(a, b):
-            return a + b
+            a.extend(b)
+            return a

         return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                                  numPartitions).mapValues(lambda x: ResultIterable(x))
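The mergeCombiners tweak in the hunk above is a small but deliberate change: list.extend() mutates the left combiner in place, while a + b allocates a brand-new list on every merge, which turns repeated merging of partial lists for a hot key into quadratic copying. A quick, illustrative comparison (timings are machine-dependent):

```python
import timeit

parts = [[0] * 100 for _ in range(500)]   # 500 partial lists of 100 items


def merge_with_plus(partials):
    acc = []
    for p in partials:
        acc = acc + p        # builds a new list each time: O(n^2) copying overall
    return acc


def merge_with_extend(partials):
    acc = []
    for p in partials:
        acc.extend(p)        # appends in place: amortized O(n) overall
    return acc


print("a + b      :", timeit.timeit(lambda: merge_with_plus(parts), number=5))
print("a.extend(b):", timeit.timeit(lambda: merge_with_extend(parts), number=5))
```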
File: python/pyspark/serializers.py
@@ -193,7 +193,7 @@ def load_stream(self, stream):
         return chain.from_iterable(self._load_stream_without_unbatching(stream))

     def _load_stream_without_unbatching(self, stream):
-            return self.serializer.load_stream(stream)
+        return self.serializer.load_stream(stream)

     def __eq__(self, other):
         return (isinstance(other, BatchedSerializer) and

@@ -302,6 +302,33 @@ class MarshalSerializer(FramedSerializer):
     loads = marshal.loads


+class AutoSerializer(FramedSerializer):
Contributor: Can we actually use this by default yet, or will it fail for NumPy arrays? If it won't work by default, we should use PickleSerializer instead and wait to fix this.

Contributor (author): It will fail in some cases, such as array, so it's not safe to make it the default. I will improve it later and try to make it the default. Currently, it's still useful, because people can use it in most cases.
| """ | ||
| Choose marshal or cPickle as serialization protocol autumatically | ||
| """ | ||
| def __init__(self): | ||
| FramedSerializer.__init__(self) | ||
| self._type = None | ||
|
|
||
| def dumps(self, obj): | ||
| if self._type is not None: | ||
| return 'P' + cPickle.dumps(obj, -1) | ||
| try: | ||
| return 'M' + marshal.dumps(obj) | ||
| except Exception: | ||
| self._type = 'P' | ||
| return 'P' + cPickle.dumps(obj, -1) | ||
|
Contributor: If the objects are not marshal-able but are pickle-able, is there a big performance cost to throwing an exception on each write? It would be good to test this, because if not, we can make this serializer our default where we now use Pickle. Even if there is a cost, maybe we can do something where, if 10% of the objects written fail to marshal, we switch to always using pickle.

Contributor (author): I added a fast path for it, so there is no exception cost any more.
+    def loads(self, obj):
+        _type = obj[0]
+        if _type == 'M':
+            return marshal.loads(obj[1:])
+        elif _type == 'P':
+            return cPickle.loads(obj[1:])
+        else:
+            raise ValueError("invalid serialization type: %s" % _type)
+
+
 class UTF8Deserializer(Serializer):
     """
     Deserializes streams written by String.getBytes.

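The idea behind AutoSerializer is to tag every record with the protocol that produced it, try the faster marshal first, and switch permanently to cPickle once any object fails to marshal, which is the "fast path" mentioned in the thread above (later writes skip the try/except entirely). The same idea in a standalone, Python 3 flavoured sketch (bytes tags and pickle instead of str tags and cPickle; this is an illustration, not the PySpark class itself):

```python
import marshal
import pickle


class AutoLikeSerializer(object):
    """Marshal-first serializer: after the first object that cannot be
    marshalled, permanently switch to pickle so later dumps() calls take
    the cheap branch instead of raising and catching an exception."""

    def __init__(self):
        self._type = None            # None = still trying marshal, b'P' = pickle only

    def dumps(self, obj):
        if self._type is not None:   # fast path once we have given up on marshal
            return b'P' + pickle.dumps(obj, -1)
        try:
            return b'M' + marshal.dumps(obj)
        except Exception:
            self._type = b'P'
            return b'P' + pickle.dumps(obj, -1)

    def loads(self, data):
        tag, body = data[:1], data[1:]
        if tag == b'M':
            return marshal.loads(body)
        if tag == b'P':
            return pickle.loads(body)
        raise ValueError("invalid serialization type: %r" % tag)


class Point(object):                 # marshal cannot handle instances, pickle can
    def __init__(self, x, y):
        self.x, self.y = x, y


ser = AutoLikeSerializer()
print(ser.loads(ser.dumps({"a": 1})))     # marshal path
p = ser.loads(ser.dumps(Point(1, 2)))     # falls back and switches to pickle
print(p.x, p.y)
print(ser.loads(ser.dumps({"a": 1})))     # now served by the pickle fast path
```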
Review comment: Instead of passing spark.local.dir, we should figure out which directories the DiskBlockManager created (you can get it from env.blockManager) and pass a comma-separated list of those. This way the data for this Spark application is all in one directory, and Java can make sure we clean it all up at the end. Otherwise, the way you have things set up now, those directories are never cleared if the Python worker crashes.

Reply: DiskBlockManager.localDirs is private; make it public?
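If the spill directories end up being handed to the Python worker as a comma-separated list (whether that list comes from spark.local.dir or, as suggested above, from the directories the JVM's DiskBlockManager already created), the worker side could consume it roughly like this. This is only a sketch of the approach; the environment variable name and helper are assumptions for illustration, not what the PR necessarily ships.

```python
import os
import random


def get_spill_dirs(sub_dir, env_var="SPARK_LOCAL_DIRS"):
    """Hypothetical helper: read a comma-separated list of local directories
    and create a per-worker sub-directory inside each one for spill files."""
    raw = os.environ.get(env_var, "/tmp")
    dirs = []
    for d in raw.split(","):
        path = os.path.join(d.strip(), "pyspark", sub_dir)
        os.makedirs(path, exist_ok=True)
        dirs.append(path)
    return dirs


# Spread spill files for this worker across all configured local directories,
# picking one at random per spill to balance disk I/O.
spill_dirs = get_spill_dirs("spill-%d" % os.getpid())
print("spilling under", random.choice(spill_dirs))
```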