Skip to content

Commit 07f7789

Browse files
dkobylarz authored and
jkbradley committed
[SPARK-8481] [MLLIB] GaussianMixtureModel.predict, GaussianMixtureModel.predictSoft variants for a single vector
This PR adds single-vector variants of GaussianMixtureModel.predict and GaussianMixtureModel.predictSoft. They are useful when applying a trained model in environments where a SparkContext is not required (or not desired) and predictions are made for single data points (vectors). A test case is included.
Author: Dariusz Kobylarz <[email protected]>
Closes #6906 from dkobylarz/branch-1.4 and squashes the following commits:
cef1f0a [Dariusz Kobylarz] [SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single vector
1 parent a292c49 commit 07f7789

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,12 @@ class GaussianMixtureModel(
6666
responsibilityMatrix.map(r => r.indexOf(r.max))
6767
}
6868

69+
/**
 * Maps the given point to its cluster index.
 *
 * Computes the soft assignments of `point` via `computeSoftAssignments` and returns the
 * position of the largest value.
 *
 * @param point the single vector to assign to a cluster
 * @return index of the entry with the largest soft-assignment value
 */
70+
def predict(point: Vector): Int = {
71+
// The point is converted to a dense Breeze vector before being handed to computeSoftAssignments.
val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
72+
// indexOf yields the first position holding the maximum, so ties go to the lowest index.
r.indexOf(r.max)
73+
}
74+
6975
/**
 * Java-friendly version of [[predict()]].
 *
 * Delegates to the RDD-based `predict` on `points.rdd`, converts the result with
 * `toJavaRDD()`, and casts it to `JavaRDD[java.lang.Integer]`.
 */
7076
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
7177
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
@@ -83,6 +89,13 @@ class GaussianMixtureModel(
8389
}
8490
}
8591

92+
/**
93+
* Given the input vector, return the membership values for all mixture components.
*
* This is the single-vector variant: it passes the point, converted to a dense Breeze
* vector, straight to `computeSoftAssignments` and returns that result unchanged.
*
* @param point the single vector to evaluate
* @return the array produced by `computeSoftAssignments` for this point
94+
*/
95+
def predictSoft(point: Vector): Array[Double] = {
96+
computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
97+
}
98+
8699
/**
87100
* Compute the partial assignments for each vector
88101
*/

mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
148148
}
149149
}
150150

151+
// Verifies that predicting one vector at a time yields the same cluster index as the
// RDD-based predict for every point in GaussianTestData.data.
test("model prediction, parallel and local") {
152+
val data = sc.parallelize(GaussianTestData.data)
153+
// Train a 2-component model with a fixed seed (0).
val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
154+
155+
val batchPredictions = gmm.predict(data)
156+
// Pair each batch prediction with its input point, then compare it with the
// single-vector overload applied to that same point.
batchPredictions.zip(data).collect().foreach { case (batchPred, datum) =>
157+
assert(batchPred === gmm.predict(datum))
158+
}
159+
}
160+
151161
object GaussianTestData {
152162

153163
val data = Array(

0 commit comments

Comments
 (0)