Skip to content

Commit f0f563a

Browse files
committed
[SPARK-10354] [MLLIB] fix some apparent memory issues in k-means|| initialization
* do not cache the first cost RDD
* change the cache level of the following cost RDDs to MEMORY_AND_DISK
* remove the Vector wrapper to save an object per instance

Further improvements will be addressed in SPARK-10329.

cc: yu-iskw HuJiayin

Author: Xiangrui Meng <[email protected]>

Closes #8526 from mengxr/SPARK-10354.
1 parent 8694c3a commit f0f563a

File tree

1 file changed

+14
-7
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/clustering

1 file changed

+14
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ class KMeans private (
369369
: Array[Array[VectorWithNorm]] = {
370370
// Initialize empty centers and point costs.
371371
val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
372-
var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
372+
var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
373373

374374
// Initialize each run's first center to a random point.
375375
val seed = new XORShiftRandom(this.seed).nextInt()
@@ -394,21 +394,28 @@ class KMeans private (
394394
val bcNewCenters = data.context.broadcast(newCenters)
395395
val preCosts = costs
396396
costs = data.zip(preCosts).map { case (point, cost) =>
397-
Vectors.dense(
398397
Array.tabulate(runs) { r =>
399398
math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
400-
})
401-
}.cache()
399+
}
400+
}.persist(StorageLevel.MEMORY_AND_DISK)
402401
val sumCosts = costs
403-
.aggregate(Vectors.zeros(runs))(
402+
.aggregate(new Array[Double](runs))(
404403
seqOp = (s, v) => {
405404
// s += v
406-
axpy(1.0, v, s)
405+
var r = 0
406+
while (r < runs) {
407+
s(r) += v(r)
408+
r += 1
409+
}
407410
s
408411
},
409412
combOp = (s0, s1) => {
410413
// s0 += s1
411-
axpy(1.0, s1, s0)
414+
var r = 0
415+
while (r < runs) {
416+
s0(r) += s1(r)
417+
r += 1
418+
}
412419
s0
413420
}
414421
)

0 commit comments

Comments
 (0)