Skip to content

Commit 473e5a2

Browse files
committed
init
init init init init init init init init init init use nativeBLAS for dense input add py refactor refactor refactor nit revert BLAS.ger revert BLAS.ger revert BLAS.ger nit nit simplify
1 parent ebdf41d commit 473e5a2

File tree

6 files changed

+312
-66
lines changed

6 files changed

+312
-66
lines changed

mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ private[spark] object BLAS extends Serializable {
271271
}
272272

273273
/**
274-
* Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR.
274+
* Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
275275
*
276276
* @param U the upper triangular part of the matrix packed in an array (column major)
277277
*/

mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class MultivariateGaussian @Since("2.0.0") (
5555
*/
5656
@transient private lazy val tuple = {
5757
val (rootSigmaInv, u) = calculateCovarianceConstants
58-
val rootSigmaInvMat = Matrices.fromBreeze(rootSigmaInv)
58+
val rootSigmaInvMat = Matrices.fromBreeze(rootSigmaInv).toDense
5959
val rootSigmaInvMulMu = rootSigmaInvMat.multiply(mean)
6060
(rootSigmaInvMat, u, rootSigmaInvMulMu)
6161
}
@@ -81,6 +81,43 @@ class MultivariateGaussian @Since("2.0.0") (
8181
u - 0.5 * BLAS.dot(v, v)
8282
}
8383

84+
private[ml] def pdf(X: Matrix): Vector = {
85+
val m = X.numRows
86+
val n = X.numCols
87+
val mat = new DenseMatrix(m, n, Array.ofDim[Double](m * n))
88+
pdf(X, mat)
89+
}
90+
91+
private[ml] def pdf(X: Matrix, mat: DenseMatrix): Vector = {
92+
require(!mat.isTransposed)
93+
val localU = u
94+
val localRootSigmaInvMat = rootSigmaInvMat
95+
val localRootSigmaInvMulMu = rootSigmaInvMulMu.toArray
96+
97+
BLAS.gemm(1.0, X, localRootSigmaInvMat.transpose, 0.0, mat)
98+
val arr = mat.values
99+
val m = mat.numRows
100+
val n = mat.numCols
101+
102+
val pdfArr = Array.ofDim[Double](m)
103+
var i = 0
104+
while (i < m) {
105+
var squaredSum = 0.0
106+
var index = i
107+
var j = 0
108+
while (j < n) {
109+
val d = arr(index) - localRootSigmaInvMulMu(j)
110+
squaredSum += d * d
111+
index += m
112+
j += 1
113+
}
114+
pdfArr(i) = math.exp(localU - 0.5 * squaredSum)
115+
i += 1
116+
}
117+
118+
Vectors.dense(pdfArr)
119+
}
120+
84121
/**
85122
* Calculate distribution dependent components used for the density function:
86123
* pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))

mllib-local/src/test/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussianSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
2727
test("univariate") {
2828
val x1 = Vectors.dense(0.0)
2929
val x2 = Vectors.dense(1.5)
30+
val mat = Matrices.fromVectors(Seq(x1, x2))
3031

3132
val mu = Vectors.dense(0.0)
3233
val sigma1 = Matrices.dense(1, 1, Array(1.0))
@@ -35,18 +36,21 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
3536
assert(dist1.logpdf(x2) ~== -2.0439385332046727 absTol 1E-5)
3637
assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
3738
assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
39+
assert(dist1.pdf(mat) ~== Vectors.dense(0.39894, 0.12952) absTol 1E-5)
3840

3941
val sigma2 = Matrices.dense(1, 1, Array(4.0))
4042
val dist2 = new MultivariateGaussian(mu, sigma2)
4143
assert(dist2.logpdf(x1) ~== -1.612085713764618 absTol 1E-5)
4244
assert(dist2.logpdf(x2) ~== -1.893335713764618 absTol 1E-5)
4345
assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
4446
assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
47+
assert(dist2.pdf(mat) ~== Vectors.dense(0.19947, 0.15057) absTol 1E-5)
4548
}
4649

4750
test("multivariate") {
4851
val x1 = Vectors.dense(0.0, 0.0)
4952
val x2 = Vectors.dense(1.0, 1.0)
53+
val mat = Matrices.fromVectors(Seq(x1, x2))
5054

5155
val mu = Vectors.dense(0.0, 0.0)
5256
val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
@@ -55,28 +59,33 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
5559
assert(dist1.logpdf(x2) ~== -2.8378770664093453 absTol 1E-5)
5660
assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
5761
assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
62+
assert(dist1.pdf(mat) ~== Vectors.dense(0.15915, 0.05855) absTol 1E-5)
5863

5964
val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
6065
val dist2 = new MultivariateGaussian(mu, sigma2)
6166
assert(dist2.logpdf(x1) ~== -2.810832140937002 absTol 1E-5)
6267
assert(dist2.logpdf(x2) ~== -3.3822607123655732 absTol 1E-5)
6368
assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
6469
assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
70+
assert(dist2.pdf(mat) ~== Vectors.dense(0.060155, 0.033971) absTol 1E-5)
6571
}
6672

6773
test("multivariate degenerate") {
6874
val x1 = Vectors.dense(0.0, 0.0)
6975
val x2 = Vectors.dense(1.0, 1.0)
76+
val mat = Matrices.fromVectors(Seq(x1, x2))
7077

7178
val mu = Vectors.dense(0.0, 0.0)
7279
val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
7380
val dist = new MultivariateGaussian(mu, sigma)
7481
assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
7582
assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
83+
assert(dist.pdf(mat) ~== Vectors.dense(0.11254, 0.068259) absTol 1E-5)
7684
}
7785

7886
test("SPARK-11302") {
7987
val x = Vectors.dense(629, 640, 1.7188, 618.19)
88+
val mat = Matrices.fromVectors(Seq(x))
8089
val mu = Vectors.dense(
8190
1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697)
8291
val sigma = Matrices.dense(4, 4, Array(
@@ -87,5 +96,6 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
8796
val dist = new MultivariateGaussian(mu, sigma)
8897
// Agrees with R's dmvnorm: 7.154782e-05
8998
assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9)
99+
assert(dist.pdf(mat) ~== Vectors.dense(7.154782224045512E-5) absTol 1E-5)
90100
}
91101
}

0 commit comments

Comments
 (0)