Skip to content

Commit 86ee81e

Browse files
srowen authored and mengxr committed
[SPARK-11302][MLLIB] 2) Multivariate Gaussian Model with Covariance matrix returns incorrect answer in some cases
Fix computation of root-sigma-inverse in multivariate Gaussian; add a test and fix related Python mixture model test. Supersedes #9293.

Author: Sean Owen <[email protected]>

Closes #9309 from srowen/SPARK-11302.2.

(cherry picked from commit 826e1e3)
Signed-off-by: Xiangrui Meng <[email protected]>
1 parent abb0ca7 commit 86ee81e

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ class MultivariateGaussian @Since("1.3.0") (
5656

5757
/**
5858
* Compute distribution dependent constants:
59-
* rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t
59+
* rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
6060
* u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
6161
*/
6262
private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
@@ -104,11 +104,11 @@ class MultivariateGaussian @Since("1.3.0") (
104104
*
105105
* sigma = U * D * U.t
106106
* inv(Sigma) = U * inv(D) * U.t
107-
* = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U)
107+
* = (D^{-1/2}^ * U.t).t * (D^{-1/2}^ * U.t)
108108
*
109109
* and thus
110110
*
111-
* -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^
111+
* -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U.t * (x-mu))^2^
112112
*
113113
* To guard against singular covariance matrices, this method computes both the
114114
* pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered
@@ -130,7 +130,7 @@ class MultivariateGaussian @Since("1.3.0") (
130130
// by inverting the square root of all non-zero values
131131
val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
132132

133-
(pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
133+
(pinvS * u.t, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
134134
} catch {
135135
case uex: UnsupportedOperationException =>
136136
throw new IllegalArgumentException("Covariance matrix has no non-zero singular values")

mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,19 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext
6565
assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
6666
assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
6767
}
68+
69+
test("SPARK-11302") {
70+
val x = Vectors.dense(629, 640, 1.7188, 618.19)
71+
val mu = Vectors.dense(
72+
1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697)
73+
val sigma = Matrices.dense(4, 4, Array(
74+
166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053,
75+
169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484,
76+
12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373,
77+
164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207))
78+
val dist = new MultivariateGaussian(mu, sigma)
79+
// Agrees with R's dmvnorm: 7.154782e-05
80+
assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9)
81+
}
82+
6883
}

python/pyspark/mllib/clustering.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -205,9 +205,9 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
205205
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
206206
... maxIterations=150, seed=10)
207207
>>> labels = model.predict(clusterdata_2).collect()
208-
>>> labels[0]==labels[1]==labels[2]
208+
>>> labels[0]==labels[1]
209209
True
210-
>>> labels[3]==labels[4]
210+
>>> labels[2]==labels[3]==labels[4]
211211
True
212212
"""
213213

0 commit comments

Comments
 (0)