Commit 0ede08b

zhengruifeng authored and srowen committed
[SPARK-31007][ML] KMeans optimization based on triangle-inequality
### What changes were proposed in this pull request?
Apply Lemma 1 from [Using the Triangle Inequality to Accelerate K-Means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf):

> Let x be a point, and let b and c be centers. If d(b,c) >= 2d(x,b), then d(x,c) >= d(x,b).

The lemma can be applied directly to EuclideanDistance, but not to CosineDistance. For CosineDistance, however, a variant holds in the space of radians/angles.

### Why are the changes needed?
It helps improve the performance of prediction and training (mostly).

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Existing test suites.

Closes #27758 from zhengruifeng/km_triangle.

Authored-by: zhengruifeng <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
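The quoted lemma follows in one step from the triangle inequality. The derivation below is a sketch: the symbols $\theta$ (angle between two centers) and $r$ (the stored bound) are notation introduced here, and the squared and half-angle forms are read off the two `computeStatistics` overrides in the diff rather than taken from the paper.

```latex
\begin{align*}
d(b,c) &\le d(b,x) + d(x,c) \\
\Rightarrow\quad d(x,c) &\ge d(b,c) - d(x,b) \;\ge\; 2\,d(x,b) - d(x,b) \;=\; d(x,b)
  && \text{(using } d(b,c) \ge 2\,d(x,b)\text{)} \\[4pt]
\text{Euclidean, squared form:}\quad
d(x,b)^2 &\le \tfrac{1}{4}\, d(b,c)^2 \;\Longrightarrow\; d(x,c)^2 \ge d(x,b)^2 \\[4pt]
\text{Cosine, } d = 1 - \cos\theta:\quad
r &= 1 - \cos\tfrac{\theta}{2}
   = 1 - \sqrt{\tfrac{1 + \cos\theta}{2}}
   = 1 - \sqrt{1 - \tfrac{d}{2}}
\end{align*}
```

The cosine case works because the angle between vectors satisfies the triangle inequality and $1 - \cos\theta$ is increasing on $[0, \pi]$, so a bound on half the angle between two centers converts back into a bound on cosine distance.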
1 parent b10263b commit 0ede08b

6 files changed: 390 additions and 45 deletions

mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala

Lines changed: 52 additions & 1 deletion
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.impl


-private[ml] object Utils {
+private[spark] object Utils {

   lazy val EPSILON = {
     var eps = 1.0
@@ -27,4 +27,55 @@ private[ml] object Utils {
     }
     eps
   }
+
+  /**
+   * Convert an n * (n + 1) / 2 dimension array representing the upper triangular part of a matrix
+   * into an n * n array representing the full symmetric matrix (column major).
+   *
+   * @param n The order of the n by n matrix.
+   * @param triangularValues The upper triangular part of the matrix packed in an array
+   *                         (column major).
+   * @return A dense matrix which represents the symmetric matrix in column major.
+   */
+  def unpackUpperTriangular(
+      n: Int,
+      triangularValues: Array[Double]): Array[Double] = {
+    val symmetricValues = new Array[Double](n * n)
+    var r = 0
+    var i = 0
+    while (i < n) {
+      var j = 0
+      while (j <= i) {
+        symmetricValues(i * n + j) = triangularValues(r)
+        symmetricValues(j * n + i) = triangularValues(r)
+        r += 1
+        j += 1
+      }
+      i += 1
+    }
+    symmetricValues
+  }
+
+  /**
+   * Indexing in an array representing the upper triangular part of a matrix
+   * into an n * n array representing the full symmetric matrix (column major).
+   *    val symmetricValues = unpackUpperTriangularMatrix(n, triangularValues)
+   *    val matrix = new DenseMatrix(n, n, symmetricValues)
+   *    val index = indexUpperTriangularMatrix(n, i, j)
+   *    then: symmetricValues(index) == matrix(i, j)
+   *
+   * @param n The order of the n by n matrix.
+   */
+  def indexUpperTriangular(
+      n: Int,
+      i: Int,
+      j: Int): Int = {
+    require(i >= 0 && i < n, s"Expected 0 <= i < $n, got i = $i.")
+    require(j >= 0 && j < n, s"Expected 0 <= j < $n, got j = $j.")
+    if (i <= j) {
+      j * (j + 1) / 2 + i
+    } else {
+      i * (i + 1) / 2 + j
+    }
+  }
 }
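To make the packed layout concrete, here is a minimal standalone sketch that mirrors the two helpers above without depending on Spark. The object name `PackedUpperTriangularSketch` and the method names `indexUpper` and `unpack` are made up for this illustration. Note that the value returned by the index function is a position in the packed array (length n * (n + 1) / 2), so the check below compares `packed(index)` with the unpacked entry; the Scaladoc example in the diff writes `symmetricValues(index)`, which does not appear to match what the function body computes.

```scala
object PackedUpperTriangularSketch {
  // Same formula as indexUpperTriangular in the diff: position of entry (i, j)
  // in a column-major packed upper triangle of an n x n symmetric matrix.
  def indexUpper(n: Int, i: Int, j: Int): Int = {
    require(i >= 0 && i < n && j >= 0 && j < n, s"indices out of range for n = $n")
    if (i <= j) j * (j + 1) / 2 + i else i * (i + 1) / 2 + j
  }

  // Same traversal as unpackUpperTriangular: expand the packed values into a
  // full n * n array. The matrix is symmetric, so row-major and column-major
  // layouts of the result coincide.
  def unpack(n: Int, packed: Array[Double]): Array[Double] = {
    val full = new Array[Double](n * n)
    var r = 0
    for (i <- 0 until n; j <- 0 to i) {
      full(i * n + j) = packed(r)
      full(j * n + i) = packed(r)
      r += 1
    }
    full
  }

  def main(args: Array[String]): Unit = {
    val n = 3
    // Packed by columns: (0,0) | (0,1) (1,1) | (0,2) (1,2) (2,2)
    val packed = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
    val full = unpack(n, packed)
    for (i <- 0 until n; j <- 0 until n) {
      // The packed lookup agrees with the unpacked matrix entry (i, j).
      assert(packed(indexUpper(n, i, j)) == full(j * n + i))
    }
    println(full.mkString(", "))
  }
}
```

Storing only the packed triangle roughly halves the memory needed for the k x k center statistics, at the cost of one index computation per lookup.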

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 2 additions & 14 deletions
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -583,19 +583,7 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
   private[clustering] def unpackUpperTriangularMatrix(
       n: Int,
       triangularValues: Array[Double]): DenseMatrix = {
-    val symmetricValues = new Array[Double](n * n)
-    var r = 0
-    var i = 0
-    while (i < n) {
-      var j = 0
-      while (j <= i) {
-        symmetricValues(i * n + j) = triangularValues(r)
-        symmetricValues(j * n + i) = triangularValues(r)
-        r += 1
-        j += 1
-      }
-      i += 1
-    }
+    val symmetricValues = unpackUpperTriangular(n, triangularValues)
     new DenseMatrix(n, n, symmetricValues)
   }

mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala

Lines changed: 217 additions & 6 deletions
@@ -17,23 +17,125 @@

 package org.apache.spark.mllib.clustering

+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.impl.Utils.indexUpperTriangular
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
 import org.apache.spark.mllib.util.MLUtils

 private[spark] abstract class DistanceMeasure extends Serializable {

+  /**
+   * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
+   * @param distance distance between two centers
+   */
+  def computeStatistics(distance: Double): Double
+
+  /**
+   * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
+   *
+   * @return The packed upper triangular part of a symmetric matrix containing statistics,
+   *         matrix(i,j) represents:
+   *         1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance
+   *         computation. Given point x, let i be current closest center, and d be current best
+   *         distance, if d < f(r), then we no longer need to compute the distance to center j;
+   *         2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k)|k!=i}. If distance
+   *         between point x and center i is less than f(r), then center i is the closest center
+   *         to point x.
+   */
+  def computeStatistics(centers: Array[VectorWithNorm]): Array[Double] = {
+    val k = centers.length
+    if (k == 1) return Array(Double.NaN)
+
+    val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
+    val diagValues = Array.fill(k)(Double.PositiveInfinity)
+    var i = 0
+    while (i < k) {
+      var j = i + 1
+      while (j < k) {
+        val d = distance(centers(i), centers(j))
+        val s = computeStatistics(d)
+        val index = indexUpperTriangular(k, i, j)
+        packedValues(index) = s
+        if (s < diagValues(i)) diagValues(i) = s
+        if (s < diagValues(j)) diagValues(j) = s
+        j += 1
+      }
+      i += 1
+    }
+
+    i = 0
+    while (i < k) {
+      val index = indexUpperTriangular(k, i, i)
+      packedValues(index) = diagValues(i)
+      i += 1
+    }
+    packedValues
+  }
+
+  /**
+   * Compute distance between centers in a distributed way.
+   */
+  def computeStatisticsDistributedly(
+      sc: SparkContext,
+      bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Double] = {
+    val k = bcCenters.value.length
+    if (k == 1) return Array(Double.NaN)
+
+    val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
+    val diagValues = Array.fill(k)(Double.PositiveInfinity)
+
+    val numParts = math.min(k, 1024)
+    sc.range(0, numParts, 1, numParts)
+      .mapPartitionsWithIndex { case (pid, _) =>
+        val centers = bcCenters.value
+        Iterator.range(0, k).flatMap { i =>
+          Iterator.range(i + 1, k).flatMap { j =>
+            val hash = (i, j).hashCode.abs
+            if (hash % numParts == pid) {
+              val d = distance(centers(i), centers(j))
+              val s = computeStatistics(d)
+              Iterator.single((i, j, s))
+            } else Iterator.empty
+          }
+        }
+      }.collect.foreach { case (i, j, s) =>
+        val index = indexUpperTriangular(k, i, j)
+        packedValues(index) = s
+        if (s < diagValues(i)) diagValues(i) = s
+        if (s < diagValues(j)) diagValues(j) = s
+      }
+
+    var i = 0
+    while (i < k) {
+      val index = indexUpperTriangular(k, i, i)
+      packedValues(index) = diagValues(i)
+      i += 1
+    }
+    packedValues
+  }
+
   /**
    * @return the index of the closest center to the given point, as well as the cost.
    */
   def findClosest(
-      centers: TraversableOnce[VectorWithNorm],
+      centers: Array[VectorWithNorm],
+      statistics: Array[Double],
+      point: VectorWithNorm): (Int, Double)
+
+  /**
+   * @return the index of the closest center to the given point, as well as the cost.
+   */
+  def findClosest(
+      centers: Array[VectorWithNorm],
       point: VectorWithNorm): (Int, Double) = {
     var bestDistance = Double.PositiveInfinity
     var bestIndex = 0
     var i = 0
-    centers.foreach { center =>
+    while (i < centers.length) {
+      val center = centers(i)
       val currentDistance = distance(center, point)
       if (currentDistance < bestDistance) {
         bestDistance = currentDistance
@@ -48,7 +150,7 @@ private[spark] abstract class DistanceMeasure extends Serializable {
    * @return the K-means cost of a given point against the given cluster centers.
    */
   def pointCost(
-      centers: TraversableOnce[VectorWithNorm],
+      centers: Array[VectorWithNorm],
       point: VectorWithNorm): Double = {
     findClosest(centers, point)._2
   }
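The `computeStatistics(centers)` method added in the first hunk of this file fills a packed matrix whose off-diagonal entries are per-pair bounds and whose diagonal entries are the row minima. A small Spark-free sketch of that idea, assuming plain `Array[Double]` centers and the Euclidean statistic `0.25 * d * d`, might look like the following; `CenterStatisticsSketch`, `indexUpper`, `euclidean`, and `statistics` are names chosen for this example only.

```scala
object CenterStatisticsSketch {
  // Packed, column-major upper-triangular index (same formula as indexUpperTriangular).
  def indexUpper(i: Int, j: Int): Int =
    if (i <= j) j * (j + 1) / 2 + i else i * (i + 1) / 2 + j

  def euclidean(a: Array[Double], b: Array[Double]): Double =
    math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)

  // Off-diagonal (i, j): squared half-distance between centers i and j.
  // Diagonal (i, i): minimum of the off-diagonal entries involving center i.
  def statistics(centers: Array[Array[Double]]): Array[Double] = {
    val k = centers.length
    if (k == 1) {
      Array(Double.NaN) // mirrors the single-center special case in the diff
    } else {
      val packed = new Array[Double](k * (k + 1) / 2)
      val diag = Array.fill(k)(Double.PositiveInfinity)
      for (i <- 0 until k; j <- i + 1 until k) {
        val d = euclidean(centers(i), centers(j))
        val s = 0.25 * d * d
        packed(indexUpper(i, j)) = s
        diag(i) = math.min(diag(i), s)
        diag(j) = math.min(diag(j), s)
      }
      for (i <- 0 until k) packed(indexUpper(i, i)) = diag(i)
      packed
    }
  }
}
```

For centers (0,0), (6,0) and (0,8), the pairwise distances are 6, 8 and 10, so the off-diagonal statistics are 9, 16 and 25 and each diagonal entry is the smaller of the two bounds touching that center.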
@@ -154,22 +256,79 @@ object DistanceMeasure {
 }

 private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
+
+  /**
+   * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
+   * @see <a href="https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf">Charles Elkan,
+   *      Using the Triangle Inequality to Accelerate k-Means</a>
+   *
+   * @return One element used in statistics matrix to make matrix(i,j) represents:
+   *         1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance
+   *         computation. Given point x, let i be current closest center, and d be current best
+   *         squared distance, if d < r, then we no longer need to compute the distance to center
+   *         j. matrix(i,j) equals to squared of half of Euclidean distance between centers i
+   *         and j;
+   *         2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k)|k!=i}. If squared
+   *         distance between point x and center i is less than r, then center i is the closest
+   *         center to point x.
+   */
+  override def computeStatistics(distance: Double): Double = {
+    0.25 * distance * distance
+  }
+
+  /**
+   * @return the index of the closest center to the given point, as well as the cost.
+   */
+  override def findClosest(
+      centers: Array[VectorWithNorm],
+      statistics: Array[Double],
+      point: VectorWithNorm): (Int, Double) = {
+    var bestDistance = EuclideanDistanceMeasure.fastSquaredDistance(centers(0), point)
+    if (bestDistance < statistics(0)) return (0, bestDistance)
+
+    val k = centers.length
+    var bestIndex = 0
+    var i = 1
+    while (i < k) {
+      val center = centers(i)
+      // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
+      // distance computation.
+      val normDiff = center.norm - point.norm
+      val lowerBound = normDiff * normDiff
+      if (lowerBound < bestDistance) {
+        val index1 = indexUpperTriangular(k, i, bestIndex)
+        if (statistics(index1) < bestDistance) {
+          val d = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
+          val index2 = indexUpperTriangular(k, i, i)
+          if (d < statistics(index2)) return (i, d)
+          if (d < bestDistance) {
+            bestDistance = d
+            bestIndex = i
+          }
+        }
+      }
+      i += 1
+    }
+    (bestIndex, bestDistance)
+  }
+
   /**
    * @return the index of the closest center to the given point, as well as the squared distance.
    */
   override def findClosest(
-      centers: TraversableOnce[VectorWithNorm],
+      centers: Array[VectorWithNorm],
       point: VectorWithNorm): (Int, Double) = {
     var bestDistance = Double.PositiveInfinity
     var bestIndex = 0
     var i = 0
-    centers.foreach { center =>
+    while (i < centers.length) {
+      val center = centers(i)
       // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
       // distance computation.
       var lowerBoundOfSqDist = center.norm - point.norm
       lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
       if (lowerBoundOfSqDist < bestDistance) {
-        val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
+        val distance = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
         if (distance < bestDistance) {
           bestDistance = distance
           bestIndex = i
@@ -234,6 +393,58 @@ private[spark] object EuclideanDistanceMeasure {
 }

 private[spark] class CosineDistanceMeasure extends DistanceMeasure {
+
+  /**
+   * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
+   *
+   * @return One element used in statistics matrix to make matrix(i,j) represents:
+   *         1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance
+   *         computation. Given point x, let i be current closest center, and d be current best
+   *         squared distance, if d < r, then we no longer need to compute the distance to center
+   *         j. For Cosine distance, it is similar to Euclidean distance. However, radian/angle
+   *         is used instead of Cosine distance to compute matrix(i,j): for centers i and j,
+   *         compute the radian/angle between them, halving it, and converting it back to Cosine
+   *         distance at the end;
+   *         2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k)|k!=i}. If Cosine
+   *         distance between point x and center i is less than r, then center i is the closest
+   *         center to point x.
+   */
+  override def computeStatistics(distance: Double): Double = {
+    // d = 1 - cos(x)
+    // r = 1 - cos(x/2) = 1 - sqrt((cos(x) + 1) / 2) = 1 - sqrt(1 - d/2)
+    1 - math.sqrt(1 - distance / 2)
+  }
+
+  /**
+   * @return the index of the closest center to the given point, as well as the cost.
+   */
+  def findClosest(
+      centers: Array[VectorWithNorm],
+      statistics: Array[Double],
+      point: VectorWithNorm): (Int, Double) = {
+    var bestDistance = distance(centers(0), point)
+    if (bestDistance < statistics(0)) return (0, bestDistance)
+
+    val k = centers.length
+    var bestIndex = 0
+    var i = 1
+    while (i < k) {
+      val index1 = indexUpperTriangular(k, i, bestIndex)
+      if (statistics(index1) < bestDistance) {
+        val center = centers(i)
+        val d = distance(center, point)
+        val index2 = indexUpperTriangular(k, i, i)
+        if (d < statistics(index2)) return (i, d)
+        if (d < bestDistance) {
+          bestDistance = d
+          bestIndex = i
+        }
+      }
+      i += 1
+    }
+    (bestIndex, bestDistance)
+  }
+
   /**
    * @param v1: first vector
    * @param v2: second vector
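
The pruned `findClosest` overloads in this file can be exercised outside Spark with a simplified sketch. The version below assumes plain `Array[Double]` vectors and squared Euclidean distance, leaves out the norm-difference lower bound that the real Euclidean implementation also checks, and returns an extra counter of distance computations purely for illustration. All names (`PrunedClosestCenterSketch`, `sqDist`, `indexUpper`) are hypothetical. The `statistics` argument is the packed array described in the Scaladoc: off-diagonal entries hold the squared half-distance between two centers, diagonal entries hold the row minimum.

```scala
object PrunedClosestCenterSketch {
  // Packed, column-major upper-triangular index (same formula as indexUpperTriangular).
  def indexUpper(i: Int, j: Int): Int =
    if (i <= j) j * (j + 1) / 2 + i else i * (i + 1) / 2 + j

  def sqDist(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  // Returns (index of closest center, squared distance, distance computations performed).
  def findClosest(
      centers: Array[Array[Double]],
      statistics: Array[Double],
      point: Array[Double]): (Int, Double, Int) = {
    var best = sqDist(centers(0), point)
    var computed = 1
    // Diagonal test: the point is within half the gap to center 0's nearest neighbour.
    if (best < statistics(0)) return (0, best, computed)
    var bestIndex = 0
    var i = 1
    while (i < centers.length) {
      // Lemma 1 in squared form: skip center i when best <= (d(bestIndex, i) / 2)^2.
      if (statistics(indexUpper(i, bestIndex)) < best) {
        val d = sqDist(centers(i), point)
        computed += 1
        // Early exit: no other center can be closer than center i.
        if (d < statistics(indexUpper(i, i))) return (i, d, computed)
        if (d < best) { best = d; bestIndex = i }
      }
      i += 1
    }
    (bestIndex, best, computed)
  }
}
```

With centers (0,0), (6,0), (0,8) and the hand-computed statistics `Array(9.0, 9.0, 9.0, 16.0, 25.0, 16.0)`, the point (2.5, 2.5) has squared distance 12.5 to center 0. Center 1 is still evaluated because its bound 9 is below 12.5, while center 2 is skipped because its bound 16 is not, so only two of the three distances are computed.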
