MultivariateGaussian.scala
@@ -56,7 +56,7 @@ class MultivariateGaussian @Since("1.3.0") (

/**
* Compute distribution dependent constants:
- * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t
+ * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
* u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
*/
private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
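
For reference, the two constants declared above combine into the density as pdf(x) = exp(u - 0.5 * norm(rootSigmaInv * (x - mu))^2), per the doc comment. Below is a minimal standalone Breeze sketch of that combination; names such as GaussianConstantsSketch are illustrative only and are not Spark's private API.

import breeze.linalg.{DenseMatrix, DenseVector}

object GaussianConstantsSketch {
  // logpdf(x) = u - 0.5 * norm(rootSigmaInv * (x - mu))^2,
  // where u = log((2*pi)^(-k/2) * det(sigma)^(-1/2)) and rootSigmaInv = D^(-1/2) * U.t
  def logpdfFromConstants(x: DenseVector[Double], mu: DenseVector[Double],
                          rootSigmaInv: DenseMatrix[Double], u: Double): Double = {
    val v = rootSigmaInv * (x - mu)   // v = D^(-1/2) * U.t * (x - mu)
    u - 0.5 * (v dot v)               // log normalizer plus Mahalanobis term
  }

  def pdfFromConstants(x: DenseVector[Double], mu: DenseVector[Double],
                       rootSigmaInv: DenseMatrix[Double], u: Double): Double =
    math.exp(logpdfFromConstants(x, mu, rootSigmaInv, u))
}
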
@@ -104,11 +104,11 @@ class MultivariateGaussian @Since("1.3.0") (
*
* sigma = U * D * U.t
* inv(Sigma) = U * inv(D) * U.t
- * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U)
+ * = (D^{-1/2}^ * U.t).t * (D^{-1/2}^ * U.t)
*
* and thus
*
- * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^
+ * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U.t * (x-mu))^2^
*
* To guard against singular covariance matrices, this method computes both the
* pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered
@@ -130,7 +130,7 @@
// by inverting the square root of all non-zero values
val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))

- (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
+ (pinvS * u.t, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
} catch {
case uex: UnsupportedOperationException =>
throw new IllegalArgumentException("Covariance matrix has no non-zero singular values")
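
The change from pinvS * u to pinvS * u.t follows the corrected derivation in the doc comment above: with sigma = U * D * U.t, it is D^(-1/2) * U.t (not D^(-1/2) * U) whose Gram product gives inv(sigma). A minimal standalone Breeze sketch that checks this identity on a small, well-conditioned covariance; object and value names are illustrative, not part of the patch.

import breeze.linalg.{diag, eigSym, inv, norm, DenseMatrix, DenseVector}

object RootSigmaInvCheck {
  def main(args: Array[String]): Unit = {
    val sigma = DenseMatrix((2.0, 0.5), (0.5, 1.0))   // small SPD covariance for the check
    val eigSym.EigSym(d, u) = eigSym(sigma)           // sigma = u * diag(d) * u.t
    val rootSigmaInv = diag(d.map(v => math.sqrt(1.0 / v))) * u.t

    val x = DenseVector(1.0, -2.0)
    val mu = DenseVector(0.5, 0.0)
    val delta = x - mu

    // Quadratic form computed two ways: via the explicit inverse, and via rootSigmaInv.
    val direct = delta.t * inv(sigma) * delta
    val viaRoot = math.pow(norm(rootSigmaInv * delta), 2)
    println(s"direct = $direct, viaRoot = $viaRoot")  // the two values should agree up to rounding
  }
}
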
MultivariateGaussianSuite.scala
@@ -65,4 +65,19 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext
assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5)
assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5)
}

test("SPARK-11302") {
val x = Vectors.dense(629, 640, 1.7188, 618.19)
val mu = Vectors.dense(
1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697)
val sigma = Matrices.dense(4, 4, Array(
166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053,
169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484,
12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373,
164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207))
val dist = new MultivariateGaussian(mu, sigma)
// Agrees with R's dmvnorm: 7.154782e-05
assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9)
}

}
python/pyspark/mllib/clustering.py (4 changes: 2 additions & 2 deletions)
@@ -236,9 +236,9 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
... maxIterations=150, seed=10)
>>> labels = model.predict(clusterdata_2).collect()
- >>> labels[0]==labels[1]==labels[2]
+ >>> labels[0]==labels[1]
True
- >>> labels[3]==labels[4]
+ >>> labels[2]==labels[3]==labels[4]
True

.. versionadded:: 1.3.0