65 changes: 50 additions & 15 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -279,45 +279,80 @@ class KMeans private (
    */
   private def initKMeansParallel(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
-    // Initialize each run's center to a random point
+    // Initialize empty centers and point costs.
+    val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
+    var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
+
+    // Initialize each run's first center to a random point.
     val seed = new XORShiftRandom(this.seed).nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
-    val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+
+    /** Merges new centers to centers. */
+    def mergeNewCenters(): Unit = {
+      var r = 0
+      while (r < runs) {
+        centers(r) ++= newCenters(r)
+        newCenters(r).clear()
+        r += 1
+      }
+    }
 
     // On each step, sample 2 * k points on average for each run with probability proportional
-    // to their squared distance from that run's current centers
+    // to their squared distance from that run's centers. Note that only distances between points
+    // and new centers are computed in each iteration.
     var step = 0
     while (step < initializationSteps) {
-      val bcCenters = data.context.broadcast(centers)
-      val sumCosts = data.flatMap { point =>
-        (0 until runs).map { r =>
-          (r, KMeans.pointCost(bcCenters.value(r), point))
-        }
-      }.reduceByKey(_ + _).collectAsMap()
-      val chosen = data.mapPartitionsWithIndex { (index, points) =>
+      val bcNewCenters = data.context.broadcast(newCenters)
+      val preCosts = costs
+      costs = data.zip(preCosts).map { case (point, cost) =>
+        Vectors.dense(
+          Array.tabulate(runs) { r =>
+            math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
+          })
+      }.cache()
+      val sumCosts = costs
+        .aggregate(Vectors.zeros(runs))(
+          seqOp = (s, v) => {
+            // s += v
+            axpy(1.0, v, s)
+            s
+          },
+          combOp = (s0, s1) => {
+            // s0 += s1
+            axpy(1.0, s1, s0)
+            s0
+          }
+        )
+      preCosts.unpersist(blocking = false)
+      val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
         val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
-        points.flatMap { p =>
+        pointsWithCosts.flatMap { case (p, c) =>
           (0 until runs).filter { r =>
-            rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
+            rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
           }.map((_, p))
         }
       }.collect()
+      mergeNewCenters()
       chosen.foreach { case (r, p) =>
-        centers(r) += p.toDense
+        newCenters(r) += p.toDense
       }
       step += 1
     }
 
+    mergeNewCenters()
+    costs.unpersist(blocking = false)
+
     // Finally, we might have a set of more than k candidate centers for each run; weigh each
     // candidate by the number of points in the dataset mapping to it and run a local k-means++
     // on the weighted centers to pick just k of them
     val bcCenters = data.context.broadcast(centers)
     val weightMap = data.flatMap { p =>
-      (0 until runs).map { r =>
+      Iterator.tabulate(runs) { r =>
         ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
       }
     }.reduceByKey(_ + _).collectAsMap()
-    val finalCenters = (0 until runs).map { r =>
+    val finalCenters = (0 until runs).par.map { r =>
       val myCenters = centers(r).toArray
       val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
       LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
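
For reference, the pattern this diff implements: instead of recomputing every point's distance to all centers at each initialization step, each point caches its cost (squared distance to the closest center found so far) and refreshes it against only the centers added in the previous round. Below is a minimal single-run sketch of that idea on plain Scala collections; the names (Point, dist2, init, KMeansParallelSketch) are hypothetical stand-ins for the RDD machinery and KMeans.pointCost.

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

// Hypothetical single-run sketch of the cost-caching pattern in this diff,
// on plain Scala collections instead of RDDs. dist2 stands in for
// KMeans.pointCost; the costs array plays the role of the persisted costs RDD.
object KMeansParallelSketch {
  type Point = Array[Double]

  // Squared Euclidean distance.
  def dist2(a: Point, b: Point): Double = {
    var s = 0.0
    var i = 0
    while (i < a.length) {
      val d = a(i) - b(i)
      s += d * d
      i += 1
    }
    s
  }

  def init(points: IndexedSeq[Point], k: Int, steps: Int, rng: Random): Seq[Point] = {
    val centers = ArrayBuffer.empty[Point]
    // First center: one random point; each later round only compares points
    // against the centers added in the previous round.
    var newCenters: Seq[Point] = Seq(points(rng.nextInt(points.length)))
    // Cached cost per point: squared distance to the closest center so far.
    val costs = Array.fill(points.length)(Double.PositiveInfinity)
    var step = 0
    while (step < steps) {
      // Key point of the patch: refresh each cached cost against the new
      // centers only, instead of re-scanning the full center set.
      for (i <- points.indices; c <- newCenters) {
        costs(i) = math.min(costs(i), dist2(points(i), c))
      }
      val sumCost = costs.sum
      centers ++= newCenters
      // Oversample: keep each point with probability min(1, 2k * cost / sum),
      // so about 2k points are added per step in expectation.
      newCenters = points.indices
        .filter(i => rng.nextDouble() < 2.0 * k * costs(i) / sumCost)
        .map(points)
      step += 1
    }
    centers ++= newCenters
    centers.toSeq // usually more than k candidates; weighted k-means++ picks k
  }
}

In the diff itself the cached costs live in a persisted costs RDD holding one vector per point (one entry per run), so the per-run totals come from a single aggregate that sums the vectors with axpy, and the previous RDD is unpersisted once the new one is built.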
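
The tail of the hunk keeps more than k candidates per run, weights each candidate by the number of points closest to it, and hands the weighted set to LocalKMeans.kMeansPlusPlus to pick the final k. A hedged sketch of that weighting for a single run, reusing the hypothetical dist2 above:

// Hypothetical sketch of the candidate weighting at the end of the hunk:
// each candidate's weight is the number of points for which it is the
// closest candidate (the role of KMeans.findClosest + reduceByKey above).
def weigh(points: Seq[Array[Double]], candidates: IndexedSeq[Array[Double]]): Array[Double] = {
  val weights = Array.fill(candidates.length)(0.0)
  points.foreach { p =>
    val closest = candidates.indices.minBy(i => KMeansParallelSketch.dist2(p, candidates(i)))
    weights(closest) += 1.0
  }
  weights
}

A candidate that is closest to many points thus counts proportionally more when the local k-means++ reduces the candidate set to k, keeping the final centers faithful to the full dataset.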