
Commit 20424da

mengxr authored and rxin committed
[SPARK-2174][MLLIB] treeReduce and treeAggregate
In `reduce` and `aggregate`, the driver node spends linear time on the number of partitions. It becomes a bottleneck when there are many partitions and the data from each partition is big. SPARK-1485 (apache#506) tracks the progress of implementing AllReduce on Spark. I did several implementations including butterfly, reduce + broadcast, and treeReduce + broadcast. treeReduce + BT broadcast seems to be right way to go for Spark. Using binary tree may introduce some overhead in communication, because the driver still need to coordinate on data shuffling. In my experiments, n -> sqrt(n) -> 1 gives the best performance in general, which is why I set "depth = 2" in MLlib algorithms. But it certainly needs more testing. I left `treeReduce` and `treeAggregate` public for easy testing. Some numbers from a test on 32-node m3.2xlarge cluster. code: ~~~ import breeze.linalg._ import org.apache.log4j._ Logger.getRootLogger.setLevel(Level.OFF) for (n <- Seq(1, 10, 100, 1000, 10000, 100000, 1000000)) { val vv = sc.parallelize(0 until 1024, 1024).map(i => DenseVector.zeros[Double](n)) var start = System.nanoTime(); vv.treeReduce(_ + _, 2); println((System.nanoTime() - start) / 1e9) start = System.nanoTime(); vv.reduce(_ + _); println((System.nanoTime() - start) / 1e9) } ~~~ out: | n | treeReduce(,2) | reduce | |---|---------------------|-----------| | 10 | 0.215538731 | 0.204206899 | | 100 | 0.278405907 | 0.205732582 | | 1000 | 0.208972182 | 0.214298272 | | 10000 | 0.194792071 | 0.349353687 | | 100000 | 0.347683285 | 6.086671892 | | 1000000 | 2.589350682 | 66.572906702 | CC: @pwendell This is clearly more scalable than the default implementation. My question is whether we should use this implementation in `reduce` and `aggregate` or put them as separate methods. The concern is that users may use `reduce` and `aggregate` as collect, where having multiple stages doesn't reduce the data size. However, in this case, `collect` is more appropriate. Author: Xiangrui Meng <[email protected]> Closes apache#1110 from mengxr/tree and squashes the following commits: c6cd267 [Xiangrui Meng] make depth default to 2 b04b96a [Xiangrui Meng] address comments 9bcc5d3 [Xiangrui Meng] add depth for readability 7495681 [Xiangrui Meng] fix compile error 142a857 [Xiangrui Meng] merge master d58a087 [Xiangrui Meng] move treeReduce and treeAggregate to mllib 8a2a59c [Xiangrui Meng] Merge branch 'master' into tree be6a88a [Xiangrui Meng] use treeAggregate in mllib 0f94490 [Xiangrui Meng] add docs eb71c33 [Xiangrui Meng] add treeReduce fe42a5e [Xiangrui Meng] add treeAggregate
1 parent 96ba04b commit 20424da
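
The `depth = 2` choice corresponds to the n -> sqrt(n) -> 1 schedule described in the commit message: with a suggested depth d and P partition-level results, each extra level of the tree combines results in groups of roughly scale = max(ceil(P^(1/d)), 2), until adding another level would no longer pay off. A minimal local sketch of that schedule (plain Scala, no Spark; `levelSchedule` is an illustrative helper, not part of the patch):

~~~scala
// Mirrors the shrink rule used by treeAggregate in this patch:
// scale = max(ceil(P^(1/depth)), 2); stop when another level no longer helps.
def levelSchedule(numPartitions: Int, depth: Int = 2): List[Int] = {
  require(depth >= 1)
  val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
  var p = numPartitions
  var levels = List(p)
  while (p > scale + p / scale) {
    p /= scale
    levels = levels :+ p
  }
  levels  // the driver reduces the last level's partial results
}

println(levelSchedule(1024, depth = 2))  // List(1024, 32): 1024 -> 32 -> driver
~~~

For the 1024-partition benchmark above, this gives 1024 -> 32 -> driver, so the driver combines 32 partial results instead of 1024.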

File tree

5 files changed: +98 −15 lines


mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 10 additions & 13 deletions
~~~diff
@@ -28,6 +28,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
+import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}

 /**
@@ -79,7 +80,7 @@ class RowMatrix(
   private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = {
     val n = numCols().toInt
     val vbr = rows.context.broadcast(v)
-    rows.aggregate(BDV.zeros[Double](n))(
+    rows.treeAggregate(BDV.zeros[Double](n))(
       seqOp = (U, r) => {
         val rBrz = r.toBreeze
         val a = rBrz.dot(vbr.value)
@@ -91,9 +92,7 @@ class RowMatrix(
             s"Do not support vector operation from type ${rBrz.getClass.getName}.")
         }
         U
-      },
-      combOp = (U1, U2) => U1 += U2
-    )
+      }, combOp = (U1, U2) => U1 += U2)
   }

   /**
@@ -104,13 +103,11 @@ class RowMatrix(
     val nt: Int = n * (n + 1) / 2

     // Compute the upper triangular part of the gram matrix.
-    val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))(
+    val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
       seqOp = (U, v) => {
         RowMatrix.dspr(1.0, v, U.data)
         U
-      },
-      combOp = (U1, U2) => U1 += U2
-    )
+      }, combOp = (U1, U2) => U1 += U2)

     RowMatrix.triuToFull(n, GU.data)
   }
@@ -290,9 +287,10 @@ class RowMatrix(
         s"We need at least $mem bytes of memory.")
     }

-    val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
+    val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
       seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),
-      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
+      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) =>
+        (s1._1 + s2._1, s1._2 += s2._2)
     )

     updateNumRows(m)
@@ -353,10 +351,9 @@ class RowMatrix(
    * Computes column-wise summary statistics.
    */
   def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
-    val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)(
+    val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-    )
+      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
     updateNumRows(summary.count)
     summary
   }
~~~

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 2 additions & 1 deletion
~~~diff
@@ -25,6 +25,7 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.rdd.RDDFunctions._

 /**
  * Class used to solve an optimization problem using Gradient Descent.
@@ -177,7 +178,7 @@ object GradientDescent extends Logging {
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
       val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](n), 0.0))(
+        .treeAggregate((BDV.zeros[Double](n), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
             val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
             (grad, loss + l)
~~~

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 2 additions & 1 deletion
~~~diff
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.rdd.RDDFunctions._

 /**
  * :: DeveloperApi ::
@@ -199,7 +200,7 @@ object LBFGS extends Logging {
     val n = weights.length
     val bcWeights = data.context.broadcast(weights)

-    val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))(
+    val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))(
       seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
         val l = localGradient.compute(
           features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
~~~

mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala

Lines changed: 66 additions & 0 deletions
~~~diff
@@ -20,7 +20,10 @@ package org.apache.spark.mllib.rdd
 import scala.language.implicitConversions
 import scala.reflect.ClassTag

+import org.apache.spark.HashPartitioner
+import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils

 /**
  * Machine learning specific RDD functions.
@@ -44,6 +47,69 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
       new SlidingRDD[T](self, windowSize)
     }
   }
+
+  /**
+   * Reduces the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    val cleanF = self.context.clean(f)
+    val reducePartition: Iterator[T] => Option[T] = iter => {
+      if (iter.hasNext) {
+        Some(iter.reduceLeft(cleanF))
+      } else {
+        None
+      }
+    }
+    val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
+    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+      if (c.isDefined && x.isDefined) {
+        Some(cleanF(c.get, x.get))
+      } else if (c.isDefined) {
+        c
+      } else if (x.isDefined) {
+        x
+      } else {
+        None
+      }
+    }
+    RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
+      .getOrElse(throw new UnsupportedOperationException("empty collection"))
+  }
+
+  /**
+   * Aggregates the elements of this RDD in a multi-level tree pattern.
+   *
+   * @param depth suggested depth of the tree (default: 2)
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int = 2): U = {
+    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+    if (self.partitions.size == 0) {
+      return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
+    }
+    val cleanSeqOp = self.context.clean(seqOp)
+    val cleanCombOp = self.context.clean(combOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = partiallyAggregated.partitions.size
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+    while (numPartitions > scale + numPartitions / scale) {
+      numPartitions /= scale
+      val curNumPartitions = numPartitions
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % curNumPartitions, _))
+      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+    }
+    partiallyAggregated.reduce(cleanCombOp)
+  }
 }

 private[mllib]
~~~
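
For reference, a minimal usage sketch of the new methods (not part of the patch). It assumes a live `SparkContext` named `sc`; note that the implicit conversion in the `RDDFunctions` companion object is `private[mllib]`, so code like this has to live inside the `org.apache.spark.mllib` package tree, as the callers above and the test suite below do:

~~~scala
// Hypothetical example; `data`, `total`, `count`, and `sum` are illustrative names.
import org.apache.spark.mllib.rdd.RDDFunctions._

val data = sc.parallelize(1 to 1000000, numSlices = 1024)

// Drop-in replacements for reduce/aggregate, with an optional tree depth (default 2).
val total = data.treeReduce(_ + _, depth = 2)
val (count, sum) = data.treeAggregate((0L, 0L))(
  seqOp = (c, x) => (c._1 + 1L, c._2 + x),
  combOp = (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2),
  depth = 2)
println(s"count = $count, mean = ${sum.toDouble / count}")
~~~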

mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala

Lines changed: 18 additions & 0 deletions
~~~diff
@@ -46,4 +46,22 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
     val expected = data.flatMap(x => x).sliding(3).toList
     assert(sliding.collect().toList === expected)
   }
+
+  test("treeAggregate") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    def seqOp = (c: Long, x: Int) => c + x
+    def combOp = (c1: Long, c2: Long) => c1 + c2
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
+      assert(sum === -1000L)
+    }
+  }
+
+  test("treeReduce") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    for (depth <- 1 until 10) {
+      val sum = rdd.treeReduce(_ + _, depth)
+      assert(sum === -1000)
+    }
+  }
 }
~~~
