Commit 2b10eb2

package upper tri

1 parent b4fabb1 commit 2b10eb2

3 files changed, +113 -66 lines changed

mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala

Lines changed: 52 additions & 1 deletion
@@ -18,7 +18,7 @@
package org.apache.spark.ml.impl


-private[ml] object Utils {
+private[spark] object Utils {

  lazy val EPSILON = {
    var eps = 1.0
@@ -27,4 +27,55 @@ private[ml] object Utils {
    }
    eps
  }
+
+  /**
+   * Convert an n * (n + 1) / 2 dimension array representing the upper triangular part of a matrix
+   * into an n * n array representing the full symmetric matrix (column major).
+   *
+   * @param n The order of the n by n matrix.
+   * @param triangularValues The upper triangular part of the matrix packed in an array
+   *                         (column major).
+   * @return A dense matrix which represents the symmetric matrix in column major.
+   */
+  def unpackUpperTriangular(
+      n: Int,
+      triangularValues: Array[Double]): Array[Double] = {
+    val symmetricValues = new Array[Double](n * n)
+    var r = 0
+    var i = 0
+    while (i < n) {
+      var j = 0
+      while (j <= i) {
+        symmetricValues(i * n + j) = triangularValues(r)
+        symmetricValues(j * n + i) = triangularValues(r)
+        r += 1
+        j += 1
+      }
+      i += 1
+    }
+    symmetricValues
+  }
+
+  /**
+   * Indexing in an array representing the upper triangular part of a matrix
+   * into an n * n array representing the full symmetric matrix (column major).
+   * val symmetricValues = unpackUpperTriangularMatrix(n, triangularValues)
+   * val matrix = new DenseMatrix(n, n, symmetricValues)
+   * val index = indexUpperTriangularMatrix(n, i, j)
+   * then: symmetricValues(index) == matrix(i, j)
+   *
+   * @param n The order of the n by n matrix.
+   */
+  def indexUpperTriangular(
+      n: Int,
+      i: Int,
+      j: Int): Int = {
+    require(i >= 0 && i < n, s"Expected 0 <= i < $n, got i = $i.")
+    require(j >= 0 && j < n, s"Expected 0 <= j < $n, got j = $j.")
+    if (i <= j) {
+      j * (j + 1) / 2 + i
+    } else {
+      i * (i + 1) / 2 + j
+    }
+  }
}
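
A minimal usage sketch of the two new helpers (illustrative only, not part of the commit; Utils is private[spark], so code like this only compiles from inside Spark's own packages): indexUpperTriangular(n, i, j) returns the offset of element (i, j) in the packed column-major upper triangle, and unpackUpperTriangular expands the packed array back into a full n * n column-major symmetric matrix.

  // Illustrative values: the packed upper triangle of a 3 x 3 symmetric matrix
  //   [1.0  2.0  4.0]
  //   [2.0  3.0  5.0]
  //   [4.0  5.0  6.0]
  // stored column by column: (0,0), (0,1), (1,1), (0,2), (1,2), (2,2).
  import org.apache.spark.ml.impl.Utils.{indexUpperTriangular, unpackUpperTriangular}
  import org.apache.spark.ml.linalg.DenseMatrix

  val n = 3
  val packed = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)

  val full = unpackUpperTriangular(n, packed)   // n * n = 9 values, both triangles filled
  val matrix = new DenseMatrix(n, n, full)      // column-major dense view

  // The packed index and the unpacked matrix agree for every (i, j), in either order.
  assert(packed(indexUpperTriangular(n, 0, 2)) == matrix(0, 2))  // 4.0
  assert(packed(indexUpperTriangular(n, 2, 0)) == matrix(2, 0))  // symmetric, also 4.0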

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 2 additions & 14 deletions
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -583,19 +583,7 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
  private[clustering] def unpackUpperTriangularMatrix(
      n: Int,
      triangularValues: Array[Double]): DenseMatrix = {
-    val symmetricValues = new Array[Double](n * n)
-    var r = 0
-    var i = 0
-    while (i < n) {
-      var j = 0
-      while (j <= i) {
-        symmetricValues(i * n + j) = triangularValues(r)
-        symmetricValues(j * n + i) = triangularValues(r)
-        r += 1
-        j += 1
-      }
-      i += 1
-    }
+    val symmetricValues = unpackUpperTriangular(n, triangularValues)
    new DenseMatrix(n, n, symmetricValues)
  }
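
After this change the clustering-side helper above is only a thin wrapper; a sketch of what it now amounts to (the name toSymmetricMatrix is made up here, and unpackUpperTriangularMatrix itself stays private[clustering]):

  // Expand a packed upper triangle into a full symmetric DenseMatrix,
  // which is what the refactored method body now delegates to.
  import org.apache.spark.ml.impl.Utils.unpackUpperTriangular
  import org.apache.spark.ml.linalg.DenseMatrix

  def toSymmetricMatrix(n: Int, packed: Array[Double]): DenseMatrix =
    new DenseMatrix(n, n, unpackUpperTriangular(n, packed))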

mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala

Lines changed: 59 additions & 51 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.impl.Utils.indexUpperTriangular
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
import org.apache.spark.mllib.util.MLUtils
@@ -35,51 +36,58 @@ private[spark] abstract class DistanceMeasure extends Serializable {
  /**
   * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
   *
-   * @return A symmetric matrix containing statistics, matrix(i)(j) represents:
+   * @return The upper triangular part of a symmetric matrix containing statistics, matrix(i)(j)
+   *         represents:
   * 1, a lower bound r of the center i, if i==j. If distance between point x and center i
   * is less than f(r), then center i is the closest center to point x.
   * 2, a lower bound r=matrix(i)(j) to help avoiding unnecessary distance computation.
   * Given point x, let i be current closest center, and d be current best distance,
   * if d < f(r), then we no longer need to compute the distance to center j.
   */
-  def computeStatistics(centers: Array[VectorWithNorm]): Array[Array[Double]] = {
+  def computeStatistics(centers: Array[VectorWithNorm]): Array[Double] = {
    val k = centers.length
-    if (k == 1) return Array(Array(Double.NaN))
+    if (k == 1) return Array(Double.NaN)

-    val stats = Array.ofDim[Double](k, k)
+    val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
+    val diagValues = Array.fill(k)(Double.PositiveInfinity)
    var i = 0
-    while (i < k) {
-      stats(i)(i) = Double.PositiveInfinity
-      i += 1
-    }
-    i = 0
    while (i < k) {
      var j = i + 1
      while (j < k) {
        val d = distance(centers(i), centers(j))
        val s = computeStatistics(d)
-        stats(i)(j) = s
-        stats(j)(i) = s
-        if (s < stats(i)(i)) stats(i)(i) = s
-        if (s < stats(j)(j)) stats(j)(j) = s
+        val index = indexUpperTriangular(k, i, j)
+        packedValues(index) = s
+        if (s < diagValues(i)) diagValues(i) = s
+        if (s < diagValues(j)) diagValues(j) = s
        j += 1
      }
      i += 1
    }
-    stats
+
+    i = 0
+    while (i < k) {
+      val index = indexUpperTriangular(k, i, i)
+      packedValues(index) = diagValues(i)
+      i += 1
+    }
+    packedValues
  }

  /**
   * Compute distance between centers in a distributed way.
   */
  def computeStatisticsDistributedly(
      sc: SparkContext,
-      bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Array[Double]] = {
+      bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Double] = {
    val k = bcCenters.value.length
-    if (k == 1) return Array(Array(Double.NaN))
+    if (k == 1) return Array(Double.NaN)
+
+    val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
+    val diagValues = Array.fill(k)(Double.PositiveInfinity)

    val numParts = math.min(k, 1024)
-    val collected = sc.range(0, numParts, 1, numParts)
+    sc.range(0, numParts, 1, numParts)
      .mapPartitionsWithIndex { case (pid, _) =>
        val centers = bcCenters.value
        Iterator.range(0, k).flatMap { i =>
@@ -88,40 +96,32 @@ private[spark] abstract class DistanceMeasure extends Serializable {
            if (hash % numParts == pid) {
              val d = distance(centers(i), centers(j))
              val s = computeStatistics(d)
-              Iterator.single(((i, j), s))
+              Iterator.single((i, j, s))
            } else Iterator.empty
          }
        }.filterNot(_._2 == 0)
-      }.collectAsMap()
+      }.foreach { case (i, j, s) =>
+        val index = indexUpperTriangular(k, i, j)
+        packedValues(index) = s
+        if (s < diagValues(i)) diagValues(i) = s
+        if (s < diagValues(j)) diagValues(j) = s
+      }

-    val stats = Array.ofDim[Double](k, k)
    var i = 0
    while (i < k) {
-      stats(i)(i) = Double.PositiveInfinity
-      i += 1
-    }
-    i = 0
-    while (i < k) {
-      var j = i + 1
-      while (j < k) {
-        val s = collected.getOrElse((i, j), 0.0)
-        stats(i)(j) = s
-        stats(j)(i) = s
-        if (s < stats(i)(i)) stats(i)(i) = s
-        if (s < stats(j)(j)) stats(j)(j) = s
-        j += 1
-      }
+      val index = indexUpperTriangular(k, i, i)
+      packedValues(index) = diagValues(i)
      i += 1
    }
-    stats
+    packedValues
  }

  /**
   * @return the index of the closest center to the given point, as well as the cost.
   */
  def findClosest(
      centers: Array[VectorWithNorm],
-      statistics: Array[Array[Double]],
+      statistics: Array[Double],
      point: VectorWithNorm): (Int, Double)

  /**
@@ -279,28 +279,33 @@ private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
   */
  override def findClosest(
      centers: Array[VectorWithNorm],
-      statistics: Array[Array[Double]],
+      statistics: Array[Double],
      point: VectorWithNorm): (Int, Double) = {
    var bestDistance = EuclideanDistanceMeasure.fastSquaredDistance(centers(0), point)
-    if (bestDistance < statistics(0)(0)) {
+    if (bestDistance < statistics(0)) {
      return (0, bestDistance)
    }

+    val k = centers.length
    var bestIndex = 0
    var i = 1
-    while (i < centers.length) {
+    while (i < k) {
      val center = centers(i)
      // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
      // distance computation.
      val normDiff = center.norm - point.norm
      val lowerBound = normDiff * normDiff
-      if (lowerBound < bestDistance && statistics(i)(bestIndex) < bestDistance) {
-        val d = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
-        if (d < statistics(i)(i)) {
-          return (i, d)
-        } else if (d < bestDistance) {
-          bestDistance = d
-          bestIndex = i
+      if (lowerBound < bestDistance) {
+        val index1 = indexUpperTriangular(k, i, bestIndex)
+        if (statistics(index1) < bestDistance) {
+          val d = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
+          val index2 = indexUpperTriangular(k, i, i)
+          if (d < statistics(index2)) {
+            return (i, d)
+          } else if (d < bestDistance) {
+            bestDistance = d
+            bestIndex = i
+          }
        }
      }
      i += 1
@@ -415,20 +420,23 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
   */
  def findClosest(
      centers: Array[VectorWithNorm],
-      statistics: Array[Array[Double]],
+      statistics: Array[Double],
      point: VectorWithNorm): (Int, Double) = {
    var bestDistance = distance(centers(0), point)
-    if (bestDistance < statistics(0)(0)) {
+    if (bestDistance < statistics(0)) {
      return (0, bestDistance)
    }

+    val k = centers.length
    var bestIndex = 0
    var i = 1
-    while (i < centers.length) {
-      if (statistics(i)(bestIndex) < bestDistance) {
+    while (i < k) {
+      val index1 = indexUpperTriangular(k, i, bestIndex)
+      if (statistics(index1) < bestDistance) {
        val center = centers(i)
        val d = distance(center, point)
-        if (d < statistics(i)(i)) {
+        val index2 = indexUpperTriangular(k, i, i)
+        if (d < statistics(index2)) {
          return (i, d)
        } else if (d < bestDistance) {
          bestDistance = d
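
The shape change in a nutshell: callers that previously indexed a k x k Array[Array[Double]] of statistics now index a packed array of length k * (k + 1) / 2 through indexUpperTriangular, which roughly halves the memory held for each statistics table. A minimal sketch of the old versus new access pattern (variable values are illustrative, not from the patch):

  import org.apache.spark.ml.impl.Utils.indexUpperTriangular

  val k = 4                                          // number of cluster centers
  // Before: Array.ofDim[Double](k, k) held k * k = 16 doubles, half of them duplicates.
  // After: the packed upper triangle holds k * (k + 1) / 2 = 10 doubles.
  val statistics = Array.ofDim[Double](k * (k + 1) / 2)

  // Old lookup: statistics(i)(j). New lookup: one extra index computation,
  // and (i, j) and (j, i) resolve to the same slot, so symmetry comes for free.
  val s = statistics(indexUpperTriangular(k, 2, 1))
  assert(indexUpperTriangular(k, 2, 1) == indexUpperTriangular(k, 1, 2))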
