diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 92a5af708d04b..4abdc3150890d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -56,10 +56,10 @@ class MultivariateGaussian @Since("1.3.0") ( /** * Compute distribution dependent constants: - * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t + * sigmaInv = sigma^-1^, where sigma = U * D * U.t * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ - private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants + private val (sigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants /** Returns density of this multivariate Gaussian at given point, x */ @@ -83,8 +83,7 @@ class MultivariateGaussian @Since("1.3.0") ( /** Returns the log-density of this multivariate Gaussian at given point, x */ private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu - val v = rootSigmaInv * delta - u + v.t * v * -0.5 + u - 0.5 * (delta.t * (sigmaInv * delta)) } /** @@ -104,11 +103,6 @@ class MultivariateGaussian @Since("1.3.0") ( * * sigma = U * D * U.t * inv(Sigma) = U * inv(D) * U.t - * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) - * - * and thus - * - * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ * * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered @@ -126,11 +120,11 @@ class MultivariateGaussian @Since("1.3.0") ( // log(pseudo-determinant) is sum of the logs of all non-zero singular values val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum - // calculate the root-pseudo-inverse of the diagonal matrix of singular values - // by inverting the square root of all non-zero values - val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) + // calculate the pseudo-inverse of the diagonal matrix of singular values + // by inverting the non-zero values + val pinvS = diag(new DBV(d.map(v => if (v > tol) 1.0 / v else 0.0).toArray)) - (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) + (u * pinvS * u.t, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => throw new IllegalArgumentException("Covariance matrix has no non-zero singular values") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index aa60deb665aeb..6e7a003475458 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -65,4 +65,19 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5) assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5) } + + test("SPARK-11302") { + val x = Vectors.dense(629, 640, 1.7188, 618.19) + val mu = Vectors.dense( + 1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697) + val sigma = Matrices.dense(4, 4, Array( + 166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053, + 169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484, + 12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373, + 164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207)) + val dist = new MultivariateGaussian(mu, sigma) + // Agrees with R's dmvnorm: 7.154782e-05 + assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) + } + } diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 6964a45db2493..92a4bd791729a 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -218,9 +218,9 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, ... maxIterations=150, seed=10) >>> labels = model.predict(clusterdata_2).collect() - >>> labels[0]==labels[1]==labels[2] + >>> labels[0]==labels[1] True - >>> labels[3]==labels[4] + >>> labels[2]==labels[3]==labels[4] True """