@@ -54,7 +54,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
    * @see [[org.apache.spark.rdd.RDD#reduce]]
    */
   def treeReduce(f: (T, T) => T, depth: Int): T = {
-    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     val cleanF = self.context.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
       if (iter.hasNext) {
@@ -63,7 +63,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
         None
       }
     }
-    val local = self.mapPartitions(it => Iterator(reducePartition(it)))
+    val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
     val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
       if (c.isDefined && x.isDefined) {
         Some(cleanF(c.get, x.get))
@@ -75,7 +75,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
         None
       }
     }
-    RDDFunctions.fromRDD(local).treeAggregate(Option.empty[T])(op, op, depth)
+    RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
       .getOrElse(throw new UnsupportedOperationException("empty collection"))
   }

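For orientation, a minimal usage sketch of the path these hunks touch (hypothetical names; it assumes a live SparkContext `sc` and that this is MLlib's `org.apache.spark.mllib.rdd.RDDFunctions`). `treeReduce` wraps each partition's partial result in an `Option[T]` and merges the options through `treeAggregate`, so an empty RDD stays `None` all the way up and triggers the `UnsupportedOperationException` above:

    import org.apache.spark.mllib.rdd.RDDFunctions

    // 64 partial sums are combined through a tree of depth 2 instead of
    // all being shipped straight to the driver.
    val rdd = sc.parallelize(1 to 100000, numSlices = 64)
    val sum = RDDFunctions.fromRDD(rdd).treeReduce((a, b) => a + b, depth = 2)
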
@@ -85,26 +85,28 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
    * @see [[org.apache.spark.rdd.RDD#aggregate]]
    */
   def treeAggregate[U: ClassTag](zeroValue: U)(
-    seqOp: (U, T) => U,
-    combOp: (U, U) => U,
-    depth: Int): U = {
-    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
     if (self.partitions.size == 0) {
       return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
     }
     val cleanSeqOp = self.context.clean(seqOp)
     val cleanCombOp = self.context.clean(combOp)
     val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
-    var local = self.mapPartitions(it => Iterator(aggregatePartition(it)))
-    var numPartitions = local.partitions.size
+    var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
     val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
     while (numPartitions > scale + numPartitions / scale) {
       numPartitions /= scale
-      local = local.mapPartitionsWithIndex { (i, iter) =>
-        iter.map((i % numPartitions, _))
-      }.reduceByKey(new HashPartitioner(numPartitions), cleanCombOp).values
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
     }
-    local.reduce(cleanCombOp)
+    partiallyAggregated.reduce(cleanCombOp)
   }
 }

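The `curNumPartitions` snapshot is the substantive fix in this hunk: RDD transformations are lazy, so each closure handed to `mapPartitionsWithIndex` runs only when a job is submitted, by which time the while loop has kept dividing the captured `numPartitions` var; copying it into a val per iteration freezes the value each tree level should use. (With 64 input partitions and depth = 2, scale = ceil(64^(1/2)) = 8, so the loop collapses 64 partitions to 8 before the final reduce.) A minimal sketch of the underlying pitfall, with hypothetical names and assuming a live SparkContext `sc`:

    var n = 8
    // A Scala closure captures the var itself (a mutable ref cell), not a
    // snapshot of its value, and Spark serializes the closure only when a
    // job actually runs, so later writes to n are what the executors see.
    val stale = sc.parallelize(0 until 16).map(i => i % n)
    n = 2
    stale.collect()    // evaluated now: uses n == 2, not 8

    val frozenN = n    // copy the var into a val before building the closure
    val stable = sc.parallelize(0 until 16).map(i => i % frozenN)  // sees a fixed value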