Skip to content

Commit 04fa122

Browse files
srowen authored and mengxr committed
SPARK-2293. Replace RDD.zip usage by map with predict inside.
This is the only occurrence of this pattern in the examples that needs to be replaced. It only addresses the example change.

Author: Sean Owen <[email protected]>

Closes apache#1250 from srowen/SPARK-2293 and squashes the following commits:

6b1b28c [Sean Owen] Compute prediction-and-label RDD directly rather than by zipping, for efficiency
1 parent 5fccb56 commit 04fa122

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

docs/mllib-naive-bayes.md

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -51,9 +51,8 @@ val training = splits(0)
5151
val test = splits(1)
5252

5353
val model = NaiveBayes.train(training, lambda = 1.0)
54-
val prediction = model.predict(test.map(_.features))
5554

56-
val predictionAndLabel = prediction.zip(test.map(_.label))
55+
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
5756
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
5857
{% endhighlight %}
5958
</div>
@@ -71,6 +70,7 @@ can be used for evaluation and prediction.
7170
import org.apache.spark.api.java.JavaPairRDD;
7271
import org.apache.spark.api.java.JavaRDD;
7372
import org.apache.spark.api.java.function.Function;
73+
import org.apache.spark.api.java.function.PairFunction;
7474
import org.apache.spark.mllib.classification.NaiveBayes;
7575
import org.apache.spark.mllib.classification.NaiveBayesModel;
7676
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -81,18 +81,12 @@ JavaRDD<LabeledPoint> test = ... // test set
8181

8282
final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
8383

84-
JavaRDD<Double> prediction =
85-
test.map(new Function<LabeledPoint, Double>() {
86-
@Override public Double call(LabeledPoint p) {
87-
return model.predict(p.features());
88-
}
89-
});
9084
JavaPairRDD<Double, Double> predictionAndLabel =
91-
prediction.zip(test.map(new Function<LabeledPoint, Double>() {
92-
@Override public Double call(LabeledPoint p) {
93-
return p.label();
85+
test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
86+
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
87+
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
9488
}
95-
}));
89+
});
9690
double accuracy = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
9791
@Override public Boolean call(Tuple2<Double, Double> pl) {
9892
return pl._1() == pl._2();

0 commit comments

Comments
 (0)