mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala

@@ -17,9 +17,10 @@
 
 package org.apache.spark.mllib.classification
 
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
-import org.apache.spark.annotation.Experimental
 
 /**
  * :: Experimental ::
@@ -43,4 +44,12 @@ trait ClassificationModel extends Serializable {
    * @return predicted category from the trained model
    */
   def predict(testData: Vector): Double
+
+  /**
+   * Predict values for examples stored in a JavaRDD.
+   * @param testData JavaRDD representing data points to be predicted
+   * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
+   */
+  def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
+    predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
 }
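
For context, a minimal Java usage sketch of the overload added above. The wrapper class, method name, and sample vectors are illustrative only, not part of the patch; any trained ClassificationModel (e.g. a NaiveBayesModel) and a live JavaSparkContext would do.

import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

class PredictSketch {
  // Illustrative helper: classify a handful of points with an already-trained model.
  static JavaRDD<Double> classify(JavaSparkContext sc, ClassificationModel model) {
    JavaRDD<Vector> points = sc.parallelize(Arrays.asList(
        Vectors.dense(0.5, 1.0),
        Vectors.dense(-1.5, 2.0)));
    // The new overload takes a JavaRDD and returns boxed java.lang.Doubles,
    // so the caller needs no conversion between Scala and Java RDD types.
    return model.predict(points);
  }
}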

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.clustering
 
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
@@ -40,6 +41,10 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Serializable {
     points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
   }
 
+  /** Maps given points to their cluster indices. */
+  def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
+    predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
+
   /**
    * Return the K-means cost (sum of squared distances of points to their nearest center) for this
    * model on the given data.
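
A rough sketch of calling the new KMeansModel.predict(JavaRDD) from Java, here to tally how many points fall in each cluster. Class and method names are illustrative; the model is assumed to be already trained.

import java.util.Arrays;
import java.util.Map;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

class ClusterSizeSketch {
  // Illustrative helper: map points to cluster indices, then count per index.
  static Map<Integer, Long> clusterSizes(JavaSparkContext sc, KMeansModel model) {
    JavaRDD<Vector> points = sc.parallelize(Arrays.asList(
        Vectors.dense(1.0, 2.0, 6.0),
        Vectors.dense(1.0, 4.0, 6.0)));
    // predict(JavaRDD) returns boxed Integers, directly usable from Java.
    JavaRDD<Integer> assignments = model.predict(points);
    return assignments.countByValue();
  }
}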

mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala

@@ -17,9 +17,10 @@
 
 package org.apache.spark.mllib.regression
 
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.annotation.Experimental
 
 @Experimental
 trait RegressionModel extends Serializable {
@@ -38,4 +39,12 @@ trait RegressionModel extends Serializable {
    * @return Double prediction from the trained model
    */
   def predict(testData: Vector): Double
+
+  /**
+   * Predict values for examples stored in a JavaRDD.
+   * @param testData JavaRDD representing data points to be predicted
+   * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
+   */
+  def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
+    predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
 }
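
Similarly, a sketch of consuming the regression overload from Java, here computing mean squared error against held-out labels. All names are illustrative; note that zip requires the two RDDs to have the same number of elements per partition, so in practice both should derive from the same parent RDD.

import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.DoubleFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.RegressionModel;
import scala.Tuple2;

class MseSketch {
  // Illustrative helper: average squared difference between prediction and label.
  static double mse(RegressionModel model, JavaRDD<Vector> features, JavaRDD<Double> labels) {
    JavaRDD<Double> predictions = model.predict(features);
    // Pair each prediction with its label, then square the residuals.
    JavaDoubleRDD squaredErrors = predictions.zip(labels).mapToDouble(
      new DoubleFunction<Tuple2<Double, Double>>() {
        @Override
        public double call(Tuple2<Double, Double> t) {
          double diff = t._1() - t._2();
          return diff * diff;
        }
      });
    return squaredErrors.mean();
  }
}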

mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java

@@ -19,6 +19,8 @@
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.junit.After;
@@ -87,4 +89,18 @@ public void runUsingStaticMethods() {
     int numAccurate2 = validatePrediction(POINTS, model2);
     Assert.assertEquals(POINTS.size(), numAccurate2);
   }
+
+  @Test
+  public void testPredictJavaRDD() {
+    JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
+    NaiveBayesModel model = NaiveBayes.train(examples.rdd());
+    JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
+      @Override
+      public Vector call(LabeledPoint v) throws Exception {
+        return v.features();
+      }});
+    JavaRDD<Double> predictions = model.predict(vectors);
+    // Should be able to get the first prediction.
+    predictions.first();
+  }
 }

mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java

@@ -88,4 +88,18 @@ public void runKMeansUsingConstructor() {
       .run(data.rdd());
     assertEquals(expectedCenter, model.clusterCenters()[0]);
   }
+
+  @Test
+  public void testPredictJavaRDD() {
+    List<Vector> points = Lists.newArrayList(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
+    );
+    JavaRDD<Vector> data = sc.parallelize(points, 2);
+    KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
+    JavaRDD<Integer> predictions = model.predict(data);
+    // Should be able to get the first prediction.
+    predictions.first();
+  }
 }

mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java

@@ -25,8 +25,10 @@
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.util.LinearDataGenerator;
 
 public class JavaLinearRegressionSuite implements Serializable {
@@ -92,4 +94,23 @@ public void runLinearRegressionUsingStaticMethods() {
     Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
   }
 
+  @Test
+  public void testPredictJavaRDD() {
+    int nPoints = 100;
+    double A = 0.0;
+    double[] weights = {10, 10};
+    JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+        LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+    LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+    LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
+    JavaRDD<Vector> vectors = testRDD.map(new Function<LabeledPoint, Vector>() {
+      @Override
+      public Vector call(LabeledPoint v) throws Exception {
+        return v.features();
+      }
+    });
+    JavaRDD<Double> predictions = model.predict(vectors);
+    // Should be able to get the first prediction.
+    predictions.first();
+  }
 }
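
One side note on the tests: the anonymous Function classes are the Java 7 idiom of the time. In Spark 1.0 and later, org.apache.spark.api.java.function.Function is a single-abstract-method interface, so on Java 8 the same feature extraction could be written as a method reference, roughly:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;

class LambdaSketch {
  // Java 8 equivalent of the anonymous Function used in the tests above.
  static JavaRDD<Vector> features(JavaRDD<LabeledPoint> examples) {
    return examples.map(LabeledPoint::features);
  }
}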