1717
1818package org .apache .spark .mllib .clustering
1919
20+ import org .apache .spark .SparkContext
2021import org .apache .spark .annotation .Since
22+ import org .apache .spark .broadcast .Broadcast
23+ import org .apache .spark .ml .impl .Utils .indexUpperTriangular
2124import org .apache .spark .mllib .linalg .{Vector , Vectors }
2225import org .apache .spark .mllib .linalg .BLAS .{axpy , dot , scal }
2326import org .apache .spark .mllib .util .MLUtils
2427
2528private [spark] abstract class DistanceMeasure extends Serializable {
2629
30+ /**
31+ * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
32+ * @param distance distance between two centers
33+ */
34+ def computeStatistics (distance : Double ): Double
35+
36+ /**
37+ * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
38+ *
39+ * @return The packed upper triangular part of a symmetric matrix containing statistics,
40+ * matrix(i,j) represents:
41+ * 1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance
42+ * computation. Given point x, let i be current closest center, and d be current best
43+ * distance, if d < f(r), then we no longer need to compute the distance to center j;
44+ * 2, if i == j: a bound r = matrix(i,i) = min_k{maxtrix(i,k)|k!=i}. If distance
45+ * between point x and center i is less than f(r), then center i is the closest center
46+ * to point x.
47+ */
48+ def computeStatistics (centers : Array [VectorWithNorm ]): Array [Double ] = {
49+ val k = centers.length
50+ if (k == 1 ) return Array (Double .NaN )
51+
52+ val packedValues = Array .ofDim[Double ](k * (k + 1 ) / 2 )
53+ val diagValues = Array .fill(k)(Double .PositiveInfinity )
54+ var i = 0
55+ while (i < k) {
56+ var j = i + 1
57+ while (j < k) {
58+ val d = distance(centers(i), centers(j))
59+ val s = computeStatistics(d)
60+ val index = indexUpperTriangular(k, i, j)
61+ packedValues(index) = s
62+ if (s < diagValues(i)) diagValues(i) = s
63+ if (s < diagValues(j)) diagValues(j) = s
64+ j += 1
65+ }
66+ i += 1
67+ }
68+
69+ i = 0
70+ while (i < k) {
71+ val index = indexUpperTriangular(k, i, i)
72+ packedValues(index) = diagValues(i)
73+ i += 1
74+ }
75+ packedValues
76+ }
77+
78+ /**
79+ * Compute distance between centers in a distributed way.
80+ */
81+ def computeStatisticsDistributedly (
82+ sc : SparkContext ,
83+ bcCenters : Broadcast [Array [VectorWithNorm ]]): Array [Double ] = {
84+ val k = bcCenters.value.length
85+ if (k == 1 ) return Array (Double .NaN )
86+
87+ val packedValues = Array .ofDim[Double ](k * (k + 1 ) / 2 )
88+ val diagValues = Array .fill(k)(Double .PositiveInfinity )
89+
90+ val numParts = math.min(k, 1024 )
91+ sc.range(0 , numParts, 1 , numParts)
92+ .mapPartitionsWithIndex { case (pid, _) =>
93+ val centers = bcCenters.value
94+ Iterator .range(0 , k).flatMap { i =>
95+ Iterator .range(i + 1 , k).flatMap { j =>
96+ val hash = (i, j).hashCode.abs
97+ if (hash % numParts == pid) {
98+ val d = distance(centers(i), centers(j))
99+ val s = computeStatistics(d)
100+ Iterator .single((i, j, s))
101+ } else Iterator .empty
102+ }
103+ }
104+ }.collect.foreach { case (i, j, s) =>
105+ val index = indexUpperTriangular(k, i, j)
106+ packedValues(index) = s
107+ if (s < diagValues(i)) diagValues(i) = s
108+ if (s < diagValues(j)) diagValues(j) = s
109+ }
110+
111+ var i = 0
112+ while (i < k) {
113+ val index = indexUpperTriangular(k, i, i)
114+ packedValues(index) = diagValues(i)
115+ i += 1
116+ }
117+ packedValues
118+ }
119+
27120 /**
28121 * @return the index of the closest center to the given point, as well as the cost.
29122 */
30123 def findClosest (
31- centers : TraversableOnce [VectorWithNorm ],
124+ centers : Array [VectorWithNorm ],
125+ statistics : Array [Double ],
126+ point : VectorWithNorm ): (Int , Double )
127+
128+ /**
129+ * @return the index of the closest center to the given point, as well as the cost.
130+ */
131+ def findClosest (
132+ centers : Array [VectorWithNorm ],
32133 point : VectorWithNorm ): (Int , Double ) = {
33134 var bestDistance = Double .PositiveInfinity
34135 var bestIndex = 0
35136 var i = 0
36- centers.foreach { center =>
137+ while (i < centers.length) {
138+ val center = centers(i)
37139 val currentDistance = distance(center, point)
38140 if (currentDistance < bestDistance) {
39141 bestDistance = currentDistance
@@ -48,7 +150,7 @@ private[spark] abstract class DistanceMeasure extends Serializable {
48150 * @return the K-means cost of a given point against the given cluster centers.
49151 */
50152 def pointCost (
51- centers : TraversableOnce [VectorWithNorm ],
153+ centers : Array [VectorWithNorm ],
52154 point : VectorWithNorm ): Double = {
53155 findClosest(centers, point)._2
54156 }
@@ -154,22 +256,79 @@ object DistanceMeasure {
154256}
155257
156258private [spark] class EuclideanDistanceMeasure extends DistanceMeasure {
259+
260+ /**
261+ * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
262+ * @see <a href="https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf">Charles Elkan,
263+ * Using the Triangle Inequality to Accelerate k-Means</a>
264+ *
265+ * @return One element used in statistics matrix to make matrix(i,j) represents:
266+ * 1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance
267+ * computation. Given point x, let i be current closest center, and d be current best
268+ * squared distance, if d < r, then we no longer need to compute the distance to center
269+ * j. matrix(i,j) equals to squared of half of Euclidean distance between centers i
270+ * and j;
271+ * 2, if i == j: a bound r = matrix(i,i) = min_k{maxtrix(i,k)|k!=i}. If squared
272+ * distance between point x and center i is less than r, then center i is the closest
273+ * center to point x.
274+ */
275+ override def computeStatistics (distance : Double ): Double = {
276+ 0.25 * distance * distance
277+ }
278+
279+ /**
280+ * @return the index of the closest center to the given point, as well as the cost.
281+ */
282+ override def findClosest (
283+ centers : Array [VectorWithNorm ],
284+ statistics : Array [Double ],
285+ point : VectorWithNorm ): (Int , Double ) = {
286+ var bestDistance = EuclideanDistanceMeasure .fastSquaredDistance(centers(0 ), point)
287+ if (bestDistance < statistics(0 )) return (0 , bestDistance)
288+
289+ val k = centers.length
290+ var bestIndex = 0
291+ var i = 1
292+ while (i < k) {
293+ val center = centers(i)
294+ // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
295+ // distance computation.
296+ val normDiff = center.norm - point.norm
297+ val lowerBound = normDiff * normDiff
298+ if (lowerBound < bestDistance) {
299+ val index1 = indexUpperTriangular(k, i, bestIndex)
300+ if (statistics(index1) < bestDistance) {
301+ val d = EuclideanDistanceMeasure .fastSquaredDistance(center, point)
302+ val index2 = indexUpperTriangular(k, i, i)
303+ if (d < statistics(index2)) return (i, d)
304+ if (d < bestDistance) {
305+ bestDistance = d
306+ bestIndex = i
307+ }
308+ }
309+ }
310+ i += 1
311+ }
312+ (bestIndex, bestDistance)
313+ }
314+
157315 /**
158316 * @return the index of the closest center to the given point, as well as the squared distance.
159317 */
160318 override def findClosest (
161- centers : TraversableOnce [VectorWithNorm ],
319+ centers : Array [VectorWithNorm ],
162320 point : VectorWithNorm ): (Int , Double ) = {
163321 var bestDistance = Double .PositiveInfinity
164322 var bestIndex = 0
165323 var i = 0
166- centers.foreach { center =>
324+ while (i < centers.length) {
325+ val center = centers(i)
167326 // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
168327 // distance computation.
169328 var lowerBoundOfSqDist = center.norm - point.norm
170329 lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
171330 if (lowerBoundOfSqDist < bestDistance) {
172- val distance : Double = EuclideanDistanceMeasure .fastSquaredDistance(center, point)
331+ val distance = EuclideanDistanceMeasure .fastSquaredDistance(center, point)
173332 if (distance < bestDistance) {
174333 bestDistance = distance
175334 bestIndex = i
@@ -234,6 +393,58 @@ private[spark] object EuclideanDistanceMeasure {
234393}
235394
236395private [spark] class CosineDistanceMeasure extends DistanceMeasure {
396+
397+ /**
398+ * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
399+ *
400+ * @return One element used in statistics matrix to make matrix(i,j) represents:
401+ * 1, if i != j: a bound r = matrix(i,j) to help avoiding unnecessary distance
402+ * computation. Given point x, let i be current closest center, and d be current best
403+ * squared distance, if d < r, then we no longer need to compute the distance to center
404+ * j. For Cosine distance, it is similar to Euclidean distance. However, radian/angle
405+ * is used instead of Cosine distance to compute matrix(i,j): for centers i and j,
406+ * compute the radian/angle between them, halving it, and converting it back to Cosine
407+ * distance at the end;
408+ * 2, if i == j: a bound r = matrix(i,i) = min_k{maxtrix(i,k)|k!=i}. If Cosine
409+ * distance between point x and center i is less than r, then center i is the closest
410+ * center to point x.
411+ */
412+ override def computeStatistics (distance : Double ): Double = {
413+ // d = 1 - cos(x)
414+ // r = 1 - cos(x/2) = 1 - sqrt((cos(x) + 1) / 2) = 1 - sqrt(1 - d/2)
415+ 1 - math.sqrt(1 - distance / 2 )
416+ }
417+
418+ /**
419+ * @return the index of the closest center to the given point, as well as the cost.
420+ */
421+ def findClosest (
422+ centers : Array [VectorWithNorm ],
423+ statistics : Array [Double ],
424+ point : VectorWithNorm ): (Int , Double ) = {
425+ var bestDistance = distance(centers(0 ), point)
426+ if (bestDistance < statistics(0 )) return (0 , bestDistance)
427+
428+ val k = centers.length
429+ var bestIndex = 0
430+ var i = 1
431+ while (i < k) {
432+ val index1 = indexUpperTriangular(k, i, bestIndex)
433+ if (statistics(index1) < bestDistance) {
434+ val center = centers(i)
435+ val d = distance(center, point)
436+ val index2 = indexUpperTriangular(k, i, i)
437+ if (d < statistics(index2)) return (i, d)
438+ if (d < bestDistance) {
439+ bestDistance = d
440+ bestIndex = i
441+ }
442+ }
443+ i += 1
444+ }
445+ (bestIndex, bestDistance)
446+ }
447+
237448 /**
238449 * @param v1: first vector
239450 * @param v2: second vector
0 commit comments