Add JavaRDD-based predict() overloads to MLlib models, with Java smoke tests.

ClassificationModel, RegressionModel and KMeansModel each gain a
predict(JavaRDD[Vector]) overload. Every overload unwraps the JavaRDD with
.rdd, delegates to the existing RDD-based predict, converts the result back
with toJavaRDD(), and casts the element type to the boxed Java type
(java.lang.Double for the classification/regression traits,
java.lang.Integer for KMeansModel) so Java callers get a usable signature.

Each Java suite gains a testPredictJavaRDD case that trains a model, calls
the new overload, and fetches the first prediction; the tests assert nothing
beyond the call chain completing.

NOTE(review): leading whitespace on context and added lines is presumed to be
the 2-space Scala / Java indentation implied by the nesting -- confirm against
the target tree, or apply with `git apply --ignore-whitespace`.

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index 6332301e30cbd..b7a1d90d24d72 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.mllib.classification
 
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
-import org.apache.spark.annotation.Experimental
 
 /**
  * :: Experimental ::
@@ -43,4 +44,12 @@ trait ClassificationModel extends Serializable {
    * @return predicted category from the trained model
    */
   def predict(testData: Vector): Double
+
+  /**
+   * Predict values for examples stored in a JavaRDD.
+   * @param testData JavaRDD representing data points to be predicted
+   * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
+   */
+  def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
+    predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index ce14b06241932..fba21aefaaacd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.clustering
 
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
@@ -40,6 +41,10 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
     points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
   }
 
+  /** Maps given points to their cluster indices. */
+  def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
+    predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
+
   /**
    * Return the K-means cost (sum of squared distances of points to their nearest center) for this
    * model on the given data.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index b27e158b43f9a..64b02f7a6e7a9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.mllib.regression
 
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.annotation.Experimental
 
 @Experimental
 trait RegressionModel extends Serializable {
@@ -38,4 +39,12 @@ trait RegressionModel extends Serializable {
    * @return Double prediction from the trained model
    */
   def predict(testData: Vector): Double
+
+  /**
+   * Predict values for examples stored in a JavaRDD.
+   * @param testData JavaRDD representing data points to be predicted
+   * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
+   */
+  def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
+    predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
 }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index c80b1134ed1b2..743a43a139c0c 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -19,6 +19,8 @@
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.junit.After;
@@ -87,4 +89,18 @@ public void runUsingStaticMethods() {
     int numAccurate2 = validatePrediction(POINTS, model2);
     Assert.assertEquals(POINTS.size(), numAccurate2);
   }
+
+  @Test
+  public void testPredictJavaRDD() {
+    JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
+    NaiveBayesModel model = NaiveBayes.train(examples.rdd());
+    JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
+      @Override
+      public Vector call(LabeledPoint v) throws Exception {
+        return v.features();
+      }});
+    JavaRDD<Double> predictions = model.predict(vectors);
+    // Should be able to get the first prediction.
+    predictions.first();
+  }
 }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index 49a614bd90cab..0c916ca378034 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -88,4 +88,18 @@ public void runKMeansUsingConstructor() {
       .run(data.rdd());
     assertEquals(expectedCenter, model.clusterCenters()[0]);
   }
+
+  @Test
+  public void testPredictJavaRDD() {
+    List<Vector> points = Lists.newArrayList(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
+    );
+    JavaRDD<Vector> data = sc.parallelize(points, 2);
+    KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
+    JavaRDD<Integer> predictions = model.predict(data);
+    // Should be able to get the first prediction.
+    predictions.first();
+  }
 }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 7151e553512b3..6dc6877691036 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -25,8 +25,10 @@
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.util.LinearDataGenerator;
 
 public class JavaLinearRegressionSuite implements Serializable {
@@ -92,4 +94,23 @@ public void runLinearRegressionUsingStaticMethods() {
     Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
   }
 
+  @Test
+  public void testPredictJavaRDD() {
+    int nPoints = 100;
+    double A = 0.0;
+    double[] weights = {10, 10};
+    JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+      LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+    LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+    LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
+    JavaRDD<Vector> vectors = testRDD.map(new Function<LabeledPoint, Vector>() {
+      @Override
+      public Vector call(LabeledPoint v) throws Exception {
+        return v.features();
+      }
+    });
+    JavaRDD<Double> predictions = model.predict(vectors);
+    // Should be able to get the first prediction.
+    predictions.first();
+  }
 }