
Commit be6a88a: use treeAggregate in mllib

Parent: 0f94490

3 files changed (+14, -15)

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 10 additions & 11 deletions
@@ -208,13 +208,11 @@ class RowMatrix(
     val nt: Int = n * (n + 1) / 2

     // Compute the upper triangular part of the gram matrix.
-    val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))(
+    val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
       seqOp = (U, v) => {
         RowMatrix.dspr(1.0, v, U.data)
         U
-      },
-      combOp = (U1, U2) => U1 += U2
-    )
+      }, combOp = (U1, U2) => U1 += U2, 2)

     RowMatrix.triuToFull(n, GU.data)
   }

@@ -309,10 +307,11 @@ class RowMatrix(
         s"We need at least $mem bytes of memory.")
     }

-    val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
+    val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
       seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),
-      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
-    )
+      combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) =>
+        (s1._1 + s2._1, s1._2 += s2._2),
+      2)

     updateNumRows(m)

@@ -371,10 +370,10 @@ class RowMatrix(
    */
   def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
     val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
-    val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
-      (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-    )
+    val summary = rows.map(_.toBreeze).treeAggregate[ColumnStatisticsAggregator](zeroValue)(
+      seqOp = (aggregator, data) => aggregator.add(data),
+      combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2),
+      2)
     updateNumRows(summary.count)
     summary
   }
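For context: `treeAggregate` returns the same result as `aggregate`, but combines partial results in a multi-level tree on the executors instead of shipping every partition's partial straight to the driver; the trailing `2` in these calls is the tree depth. Below is a minimal, self-contained sketch of the pattern using plain arrays instead of Breeze's `BDV` (the object name and data are made up; the call matches core `RDD.treeAggregate`, which in later Spark versions has the same `seqOp`/`combOp`/`depth` shape used here, while at the time of this commit the method came in via an implicit from `org.apache.spark.mllib.rdd.RDDFunctions`):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object TreeAggregateSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("treeAggregate-sketch").setMaster("local[4]"))

    // One Array[Double] per row of a small dense matrix.
    val rows = sc.parallelize(Seq(
      Array(1.0, 2.0), Array(3.0, 4.0), Array(5.0, 6.0)), numSlices = 4)

    // Column sums: seqOp folds one row into a partial sum on each partition,
    // combOp merges two partials; depth = 2 adds one executor-side merge level.
    val colSums = rows.treeAggregate(new Array[Double](2))(
      seqOp = (acc, row) => {
        for (i <- row.indices) acc(i) += row(i)
        acc
      },
      combOp = (a, b) => {
        for (i <- a.indices) a(i) += b(i)
        a
      },
      depth = 2)

    println(colSums.mkString(", "))  // prints: 9.0, 12.0
    sc.stop()
  }
}
```

With plain `aggregate`, the driver merges every partial itself, which is costly when the accumulator is a large buffer like the `nt = n * (n + 1) / 2`-element Gram array above.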

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 2 additions & 2 deletions
@@ -175,14 +175,14 @@ object GradientDescent extends Logging {
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
       val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+        .treeAggregate((BDV.zeros[Double](weights.size), 0.0))(
           seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
             val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
             (grad, loss + l)
           },
           combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
             (grad1 += grad2, loss1 + loss2)
-          })
+          }, 2)

       /**
        * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
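The tuple accumulator above threads a `(gradient, loss)` pair through both closures with a pattern match, so one pass over the sampled data produces both sums. A hypothetical, self-contained rendering of the same shape, with a hand-rolled least-squares gradient standing in for MLlib's `Gradient` class (everything here is illustrative, not code from the commit):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object MiniBatchSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("minibatch-sketch").setMaster("local[4]"))
    val weights = Array(0.5, -0.25)

    // (label, features) pairs, as in GradientDescent's input RDD.
    val data = sc.parallelize(Seq(
      (1.0, Array(1.0, 2.0)),
      (0.0, Array(3.0, 1.0)),
      (2.0, Array(0.5, 4.0))))

    val (gradientSum, lossSum) = data
      .sample(withReplacement = false, fraction = 1.0, seed = 42L)
      .treeAggregate((new Array[Double](weights.length), 0.0))(
        seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
          // Least-squares: err = w.x - y, gradient = err * x, loss = err^2 / 2.
          val err = features.zip(weights).map { case (x, w) => x * w }.sum - label
          for (i <- grad.indices) grad(i) += err * features(i)
          (grad, loss + 0.5 * err * err)
        },
        combOp = (c1, c2) => (c1, c2) match { case ((g1, l1), (g2, l2)) =>
          for (i <- g1.indices) g1(i) += g2(i)
          (g1, l1 + l2)
        },
        depth = 2)

    println(s"gradientSum = ${gradientSum.mkString(", ")}, lossSum = $lossSum")
    sc.stop()
  }
}
```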

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 2 additions & 2 deletions
@@ -198,15 +198,15 @@ object LBFGS extends Logging {
       val localData = data
       val localGradient = gradient

-      val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
+      val (gradientSum, lossSum) = localData.treeAggregate((BDV.zeros[Double](weights.size), 0.0))(
         seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
           val l = localGradient.compute(
             features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
           (grad, loss + l)
         },
         combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
           (grad1 += grad2, loss1 + loss2)
-        })
+        }, 2)

       /**
        * regVal is sum of weight squares if it's L2 updater;
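The motivation for the whole commit in one comparison: `aggregate` makes the driver merge every partition's partial gradient, an O(numPartitions) sequence of combines over potentially large vectors per iteration, whereas `treeAggregate` with depth 2 adds an executor-side reduce level so the driver sees roughly the square root of that many partials. A small sketch checking that the two agree (numbers and partition count are arbitrary):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object DepthSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("depth-sketch").setMaster("local[8]"))
    val xs = sc.parallelize(1 to 1000000, numSlices = 64)

    // Plain aggregate: the driver merges all 64 partition partials itself.
    val flat = xs.aggregate(0L)((acc, x) => acc + x, _ + _)

    // treeAggregate: with 64 partitions and depth 2, one executor-side level
    // shrinks the set of partials roughly 8x before the driver's final merge.
    val tree = xs.treeAggregate(0L)((acc, x) => acc + x, _ + _, 2)

    assert(flat == tree)  // same answer, different merge topology
    println(tree)         // 500000500000
    sc.stop()
  }
}
```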
