Skip to content

Commit e75f710

Browse files
dkobylarzCodingCat
authored andcommitted
[SPARK-8481] [MLLIB] GaussianMixtureModel predict accepting single vector
Resubmit of [apache#6906] for adding single-vec predict to GMMs CC: dkobylarz mengxr To be merged with master and branch-1.5 Primary author: dkobylarz Author: Dariusz Kobylarz <[email protected]> Closes apache#8039 from jkbradley/gmm-predict-vec and squashes the following commits: bfbedc4 [Dariusz Kobylarz] [SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single vector
1 parent 5b00a27 commit e75f710

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ class GaussianMixtureModel(
6666
responsibilityMatrix.map(r => r.indexOf(r.max))
6767
}
6868

69+
/** Maps given point to its cluster index. */
70+
def predict(point: Vector): Int = {
71+
val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
72+
r.indexOf(r.max)
73+
}
74+
6975
/** Java-friendly version of [[predict()]] */
7076
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
7177
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
@@ -83,6 +89,13 @@ class GaussianMixtureModel(
8389
}
8490
}
8591

92+
/**
93+
* Given the input vector, return the membership values to all mixture components.
94+
*/
95+
def predictSoft(point: Vector): Array[Double] = {
96+
computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
97+
}
98+
8699
/**
87100
* Compute the partial assignments for each vector
88101
*/

mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
148148
}
149149
}
150150

151+
test("model prediction, parallel and local") {
152+
val data = sc.parallelize(GaussianTestData.data)
153+
val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
154+
155+
val batchPredictions = gmm.predict(data)
156+
batchPredictions.zip(data).collect().foreach { case (batchPred, datum) =>
157+
assert(batchPred === gmm.predict(datum))
158+
}
159+
}
160+
151161
object GaussianTestData {
152162

153163
val data = Array(

0 commit comments

Comments
 (0)