@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
2020import org .apache .spark .SparkContext
2121import org .apache .spark .annotation .Since
2222import org .apache .spark .broadcast .Broadcast
23+ import org .apache .spark .ml .impl .Utils .indexUpperTriangular
2324import org .apache .spark .mllib .linalg .{Vector , Vectors }
2425import org .apache .spark .mllib .linalg .BLAS .{axpy , dot , scal }
2526import org .apache .spark .mllib .util .MLUtils
@@ -35,51 +36,58 @@ private[spark] abstract class DistanceMeasure extends Serializable {
/**
 * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
 *
 * The result is the upper triangular part of a symmetric k x k matrix, packed into a single
 * array of length k * (k + 1) / 2; the entry for the pair (i, j) is stored at
 * `indexUpperTriangular(k, i, j)`.
 *
 * @param centers the cluster centers
 * @return the packed statistics, where the entry for (i, j) represents:
 *         1, if i == j: a lower bound r of the center i, taken as the smallest statistic over
 *         all pairs involving center i. If distance between point x and center i is less
 *         than f(r), then center i is the closest center to point x.
 *         2, if i != j: a lower bound r to help avoiding unnecessary distance computation.
 *         Given point x, let i be current closest center, and d be current best distance,
 *         if d < f(r), then we no longer need to compute the distance to center j.
 *         For a single center, a one-element array containing NaN is returned.
 */
def computeStatistics(centers: Array[VectorWithNorm]): Array[Double] = {
  val numCenters = centers.length
  if (numCenters == 1) {
    Array(Double.NaN)
  } else {
    val packed = Array.ofDim[Double](numCenters * (numCenters + 1) / 2)
    // Running minimum of the pairwise statistics seen for each center.
    val rowMin = Array.fill(numCenters)(Double.PositiveInfinity)

    // Off-diagonal entries: one statistic per unordered pair of centers.
    for (row <- 0 until numCenters; col <- row + 1 until numCenters) {
      val stat = computeStatistics(distance(centers(row), centers(col)))
      packed(indexUpperTriangular(numCenters, row, col)) = stat
      // Plain comparisons (not math.min) so that a NaN statistic never replaces a minimum.
      if (stat < rowMin(row)) rowMin(row) = stat
      if (stat < rowMin(col)) rowMin(col) = stat
    }

    // Diagonal entries: the per-center minimum collected above.
    for (c <- 0 until numCenters) {
      packed(indexUpperTriangular(numCenters, c, c)) = rowMin(c)
    }
    packed
  }
}
7176
/**
 * Compute distance between centers in a distributed way.
 *
 * The pairwise statistics are evaluated inside a Spark job, then gathered on the driver and
 * packed the same way as in `computeStatistics`: the entry for the pair (i, j) is stored at
 * `indexUpperTriangular(k, i, j)`, and each diagonal entry (i, i) holds the smallest
 * statistic over all pairs involving center i.
 *
 * @param sc the SparkContext used to launch the job
 * @param bcCenters the broadcast cluster centers
 * @return the packed upper triangular statistics, of length k * (k + 1) / 2; for a single
 *         center, a one-element array containing NaN.
 */
def computeStatisticsDistributedly(
    sc: SparkContext,
    bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Double] = {
  val k = bcCenters.value.length
  if (k == 1) return Array(Double.NaN)

  val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
  val diagValues = Array.fill(k)(Double.PositiveInfinity)

  val numParts = math.min(k, 1024)
  val collected = sc.range(0, numParts, 1, numParts)
    .mapPartitionsWithIndex { case (pid, _) =>
      val centers = bcCenters.value
      Iterator.range(0, k).flatMap { i =>
        Iterator.range(i + 1, k).flatMap { j =>
          // NOTE(review): this pair-to-partition hash was not visible in the reviewed
          // excerpt; confirm it matches the existing file.
          val hash = (i, j).hashCode.abs
          if (hash % numParts == pid) {
            val d = distance(centers(i), centers(j))
            val s = computeStatistics(d)
            Iterator.single((i, j, s))
          } else Iterator.empty
        }
      }
      // No filtering of zero statistics here: every pair must reach the driver so that the
      // per-center minimum below also accounts for zeros.
    }.collect()

  // The arrays are filled on the driver from the collected triples. Calling `foreach`
  // directly on the RDD would run in the tasks against serialized copies of the arrays,
  // leaving the driver-side `packedValues` and `diagValues` untouched.
  collected.foreach { case (i, j, s) =>
    val index = indexUpperTriangular(k, i, j)
    packedValues(index) = s
    if (s < diagValues(i)) diagValues(i) = s
    if (s < diagValues(j)) diagValues(j) = s
  }

  var i = 0
  while (i < k) {
    val index = indexUpperTriangular(k, i, i)
    packedValues(index) = diagValues(i)
    i += 1
  }
  packedValues
}
118118
/**
 * Find the center closest to the given point.
 *
 * @param centers the candidate cluster centers
 * @param statistics the statistics as a single flat array; implementations look up the
 *                   entry for a pair of centers (i, j) at `indexUpperTriangular(k, i, j)`
 *                   with k = centers.length (presumably the array produced by
 *                   `computeStatistics` -- confirm against callers)
 * @param point the point to assign
 * @return the index of the closest center to the given point, as well as the cost.
 */
def findClosest(
    centers: Array[VectorWithNorm],
    statistics: Array[Double],
    point: VectorWithNorm): (Int, Double)
126126
127127 /**
@@ -279,28 +279,33 @@ private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
279279 */
280280 override def findClosest (
281281 centers : Array [VectorWithNorm ],
282- statistics : Array [Array [ Double ] ],
282+ statistics : Array [Double ],
283283 point : VectorWithNorm ): (Int , Double ) = {
284284 var bestDistance = EuclideanDistanceMeasure .fastSquaredDistance(centers(0 ), point)
285- if (bestDistance < statistics(0 )( 0 ) ) {
285+ if (bestDistance < statistics(0 )) {
286286 return (0 , bestDistance)
287287 }
288288
289+ val k = centers.length
289290 var bestIndex = 0
290291 var i = 1
291- while (i < centers.length ) {
292+ while (i < k ) {
292293 val center = centers(i)
293294 // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
294295 // distance computation.
295296 val normDiff = center.norm - point.norm
296297 val lowerBound = normDiff * normDiff
297- if (lowerBound < bestDistance && statistics(i)(bestIndex) < bestDistance) {
298- val d = EuclideanDistanceMeasure .fastSquaredDistance(center, point)
299- if (d < statistics(i)(i)) {
300- return (i, d)
301- } else if (d < bestDistance) {
302- bestDistance = d
303- bestIndex = i
298+ if (lowerBound < bestDistance) {
299+ val index1 = indexUpperTriangular(k, i, bestIndex)
300+ if (statistics(index1) < bestDistance) {
301+ val d = EuclideanDistanceMeasure .fastSquaredDistance(center, point)
302+ val index2 = indexUpperTriangular(k, i, i)
303+ if (d < statistics(index2)) {
304+ return (i, d)
305+ } else if (d < bestDistance) {
306+ bestDistance = d
307+ bestIndex = i
308+ }
304309 }
305310 }
306311 i += 1
@@ -415,20 +420,23 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
415420 */
416421 def findClosest (
417422 centers : Array [VectorWithNorm ],
418- statistics : Array [Array [ Double ] ],
423+ statistics : Array [Double ],
419424 point : VectorWithNorm ): (Int , Double ) = {
420425 var bestDistance = distance(centers(0 ), point)
421- if (bestDistance < statistics(0 )( 0 ) ) {
426+ if (bestDistance < statistics(0 )) {
422427 return (0 , bestDistance)
423428 }
424429
430+ val k = centers.length
425431 var bestIndex = 0
426432 var i = 1
427- while (i < centers.length) {
428- if (statistics(i)(bestIndex) < bestDistance) {
433+ while (i < k) {
434+ val index1 = indexUpperTriangular(k, i, bestIndex)
435+ if (statistics(index1) < bestDistance) {
429436 val center = centers(i)
430437 val d = distance(center, point)
431- if (d < statistics(i)(i)) {
438+ val index2 = indexUpperTriangular(k, i, i)
439+ if (d < statistics(index2)) {
432440 return (i, d)
433441 } else if (d < bestDistance) {
434442 bestDistance = d
0 commit comments