From 930a98e1389ba2a4aa72e85e8ce0662ff95c903a Mon Sep 17 00:00:00 2001 From: lewuathe Date: Tue, 6 Jan 2015 23:02:02 +0900 Subject: [PATCH 1/2] [SPARK-5019] Update GMM API to use MultivariateGaussian --- .../spark/examples/mllib/DenseGmmEM.scala | 3 ++- .../mllib/clustering/GaussianMixtureEM.scala | 4 +--- .../mllib/clustering/GaussianMixtureModel.scala | 17 +++++++++-------- .../GMMExpectationMaximizationSuite.scala | 14 ++++++++------ 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala index 948c350953e27..b275ee4e8cd62 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala @@ -52,9 +52,10 @@ object DenseGmmEM { .setMaxIterations(maxIterations) .run(data) + val gaussians = clusters.gaussians for (i <- 0 until clusters.k) { println("weight=%f\nmu=%s\nsigma=\n%s\n" format - (clusters.weight(i), clusters.mu(i), clusters.sigma(i))) + (clusters.weight(i), gaussians(i).mu, gaussians(i).sigma)) } println("Cluster labels (first <= 100):") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala index bdf984aee4dae..3cff967d2009e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -121,9 +121,7 @@ class GaussianMixtureEM private ( // diagonal covariance matrices using component variances // derived from the samples val (weights, gaussians) = initialModel match { - case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => - new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix) - }) + case Some(gmm) => (gmm.weight, gmm.gaussians) case None => { val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index b461ea4f0f06e..973adf00cca31 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -36,13 +36,18 @@ import org.apache.spark.mllib.util.MLUtils * covariance matrix for Gaussian i */ class GaussianMixtureModel( - val weight: Array[Double], - val mu: Array[Vector], - val sigma: Array[Matrix]) extends Serializable { + val weight: Array[Double], + private val mu: Array[Vector], + private val sigma: Array[Matrix]) extends Serializable { /** Number of gaussians in mixture */ def k: Int = weight.length + /** Multivariate Gaussian models which compose GMM **/ + val gaussians: Array[MultivariateGaussian] = (0 until k).map {i => + new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) + }.toArray + /** Maps given points to their cluster indices. */ def predict(points: RDD[Vector]): RDD[Int] = { val responsibilityMatrix = predictSoft(points) @@ -55,11 +60,7 @@ class GaussianMixtureModel( */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext - val dists = sc.broadcast { - (0 until k).map { i => - new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) - }.toArray - } + val dists = sc.broadcast(gaussians) val weights = sc.broadcast(weight) points.map { x => computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala index 23feb82874b70..102ef74454159 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala @@ -37,10 +37,11 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0)) val gmm = new GaussianMixtureEM().setK(1).run(data) + val gaussians = gmm.gaussians assert(gmm.weight(0) ~== Ew absTol 1E-5) - assert(gmm.mu(0) ~== Emu absTol 1E-5) - assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + assert(Vectors.fromBreeze(gaussians(0).mu) ~== Emu absTol 1E-5) + assert(Matrices.fromBreeze(gaussians(0).sigma) ~== Esigma absTol 1E-5) } test("two clusters") { @@ -67,12 +68,13 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex .setK(2) .setInitialModel(initialGmm) .run(data) + val gaussians = gmm.gaussians assert(gmm.weight(0) ~== Ew(0) absTol 1E-3) assert(gmm.weight(1) ~== Ew(1) absTol 1E-3) - assert(gmm.mu(0) ~== Emu(0) absTol 1E-3) - assert(gmm.mu(1) ~== Emu(1) absTol 1E-3) - assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3) - assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3) + assert(Vectors.fromBreeze(gaussians(0).mu) ~== Emu(0) absTol 1E-3) + assert(Vectors.fromBreeze(gaussians(1).mu) ~== Emu(1) absTol 1E-3) + assert(Matrices.fromBreeze(gaussians(0).sigma) ~== Esigma(0) absTol 1E-3) + assert(Matrices.fromBreeze(gaussians(1).sigma) ~== Esigma(1) absTol 1E-3) } } From 6b177dc123e5b0901f2099658128564ff2bd6e33 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Tue, 6 Jan 2015 23:13:39 +0900 Subject: [PATCH 2/2] [SPARK-5019] Fix styles --- .../apache/spark/mllib/clustering/GaussianMixtureModel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 973adf00cca31..b7a90750485ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -43,7 +43,7 @@ class GaussianMixtureModel( /** Number of gaussians in mixture */ def k: Int = weight.length - /** Multivariate Gaussian models which compose GMM **/ + /** Multivariate Gaussian models which compose GMM */ val gaussians: Array[MultivariateGaussian] = (0 until k).map {i => new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) }.toArray