@@ -27,6 +27,7 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
2727 test(" univariate" ) {
2828 val x1 = Vectors .dense(0.0 )
2929 val x2 = Vectors .dense(1.5 )
30+ val mat = Matrices .fromVectors(Seq (x1, x2))
3031
3132 val mu = Vectors .dense(0.0 )
3233 val sigma1 = Matrices .dense(1 , 1 , Array (1.0 ))
@@ -35,18 +36,21 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
3536 assert(dist1.logpdf(x2) ~== - 2.0439385332046727 absTol 1E-5 )
3637 assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5 )
3738 assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5 )
39+ assert(dist1.pdf(mat) ~== Vectors .dense(0.39894 , 0.12952 ) absTol 1E-5 )
3840
3941 val sigma2 = Matrices .dense(1 , 1 , Array (4.0 ))
4042 val dist2 = new MultivariateGaussian (mu, sigma2)
4143 assert(dist2.logpdf(x1) ~== - 1.612085713764618 absTol 1E-5 )
4244 assert(dist2.logpdf(x2) ~== - 1.893335713764618 absTol 1E-5 )
4345 assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5 )
4446 assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5 )
47+ assert(dist2.pdf(mat) ~== Vectors .dense(0.19947 , 0.15057 ) absTol 1E-5 )
4548 }
4649
4750 test(" multivariate" ) {
4851 val x1 = Vectors .dense(0.0 , 0.0 )
4952 val x2 = Vectors .dense(1.0 , 1.0 )
53+ val mat = Matrices .fromVectors(Seq (x1, x2))
5054
5155 val mu = Vectors .dense(0.0 , 0.0 )
5256 val sigma1 = Matrices .dense(2 , 2 , Array (1.0 , 0.0 , 0.0 , 1.0 ))
@@ -55,28 +59,33 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
5559 assert(dist1.logpdf(x2) ~== - 2.8378770664093453 absTol 1E-5 )
5660 assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5 )
5761 assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5 )
62+ assert(dist1.pdf(mat) ~== Vectors .dense(0.15915 , 0.05855 ) absTol 1E-5 )
5863
5964 val sigma2 = Matrices .dense(2 , 2 , Array (4.0 , - 1.0 , - 1.0 , 2.0 ))
6065 val dist2 = new MultivariateGaussian (mu, sigma2)
6166 assert(dist2.logpdf(x1) ~== - 2.810832140937002 absTol 1E-5 )
6267 assert(dist2.logpdf(x2) ~== - 3.3822607123655732 absTol 1E-5 )
6368 assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5 )
6469 assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5 )
70+ assert(dist2.pdf(mat) ~== Vectors .dense(0.060155 , 0.033971 ) absTol 1E-5 )
6571 }
6672
6773 test(" multivariate degenerate" ) {
6874 val x1 = Vectors .dense(0.0 , 0.0 )
6975 val x2 = Vectors .dense(1.0 , 1.0 )
76+ val mat = Matrices .fromVectors(Seq (x1, x2))
7077
7178 val mu = Vectors .dense(0.0 , 0.0 )
7279 val sigma = Matrices .dense(2 , 2 , Array (1.0 , 1.0 , 1.0 , 1.0 ))
7380 val dist = new MultivariateGaussian (mu, sigma)
7481 assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5 )
7582 assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5 )
83+ assert(dist.pdf(mat) ~== Vectors .dense(0.11254 , 0.068259 ) absTol 1E-5 )
7684 }
7785
7886 test(" SPARK-11302" ) {
7987 val x = Vectors .dense(629 , 640 , 1.7188 , 618.19 )
88+ val mat = Matrices .fromVectors(Seq (x))
8089 val mu = Vectors .dense(
8190 1055.3910505836575 , 1070.489299610895 , 1.39020554474708 , 1040.5907503867697 )
8291 val sigma = Matrices .dense(4 , 4 , Array (
@@ -87,5 +96,6 @@ class MultivariateGaussianSuite extends SparkMLFunSuite {
8796 val dist = new MultivariateGaussian (mu, sigma)
8897 // Agrees with R's dmvnorm: 7.154782e-05
8998 assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9 )
99+ assert(dist.pdf(mat) ~== Vectors .dense(7.154782224045512E-5 ) absTol 1E-5 )
90100 }
91101}
0 commit comments