From f6e23650fcaf943f058494ddd13282f37f9452f5 Mon Sep 17 00:00:00 2001
From: Anthony Truchet
Date: Thu, 17 Nov 2016 16:20:19 +0100
Subject: [PATCH] [SPARK-18471][CORE] New treeAggregate overload for large aggregators

The zero value for the aggregation used to be shipped inside a closure,
which is highly problematic when this zero is big (100s of MB is typical
for ML). This change introduces a new overload of treeAggregate that only
ships a function able to generate this zero.
---
 .../main/scala/org/apache/spark/rdd/RDD.scala | 21 ++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 8e673447581c..9c3c246551d4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1111,21 +1111,24 @@ abstract class RDD[T: ClassTag](
   /**
    * Aggregates the elements of this RDD in a multi-level tree pattern.
    *
+   * This variant takes a function that generates the zero value, which allows
+   * efficient aggregation of large structures such as big dense vectors.
+   *
    * @param depth suggested depth of the tree (default: 2)
    * @see [[org.apache.spark.rdd.RDD#aggregate]]
    */
-  def treeAggregate[U: ClassTag](zeroValue: U)(
+  def treeAggregateWithZeroGenerator[U: ClassTag](zeroValueGenerator: () => U)(
       seqOp: (U, T) => U,
       combOp: (U, U) => U,
       depth: Int = 2): U = withScope {
     require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     if (partitions.length == 0) {
-      Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
+      Utils.clone(zeroValueGenerator(), context.env.closureSerializer.newInstance())
     } else {
       val cleanSeqOp = context.clean(seqOp)
       val cleanCombOp = context.clean(combOp)
       val aggregatePartition =
-        (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+        (it: Iterator[T]) => it.aggregate(zeroValueGenerator())(cleanSeqOp, cleanCombOp)
       var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
       var numPartitions = partiallyAggregated.partitions.length
       val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
@@ -1144,6 +1147,18 @@ abstract class RDD[T: ClassTag](
     }
   }
 
+  /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int = 2): U =
+    treeAggregateWithZeroGenerator(() => zeroValue)(seqOp, combOp, depth)
+
   /**
    * Return the number of elements in the RDD.
    */
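
The sketch below is not part of the patch; it only illustrates how a driver program might call the new overload once this change is applied. The RDD contents, `dim`, and the bucket-counting logic are illustrative assumptions. The point is that the large zero array is produced by the generator on the executors, instead of being captured in the per-partition closure as with the existing treeAggregate(zeroValue) signature.

import org.apache.spark.{SparkConf, SparkContext}

object ZeroGeneratorExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("zero-generator-example").setMaster("local[*]"))

    // Stand-in for a very large aggregation buffer (in ML this is often 100s of MB).
    val dim = 1 << 20
    val points = sc.parallelize(1 to 100000, 8)

    // The zero is built lazily per partition by the generator; only the small
    // () => Array[Double] function is shipped, not the array itself.
    val summed: Array[Double] = points.treeAggregateWithZeroGenerator(
      () => new Array[Double](dim))(
      (acc: Array[Double], x: Int) => { acc(x % dim) += 1.0; acc },
      (a: Array[Double], b: Array[Double]) => {
        var i = 0
        while (i < dim) { a(i) += b(i); i += 1 }
        a
      })

    println(s"non-zero buckets: ${summed.count(_ != 0.0)}")
    sc.stop()
  }
}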