Skip to content

Commit 6b1b28c

Browse files
committed
Compute prediction-and-label RDD directly rather than by zipping, for efficiency
1 parent: 18f29b9 · commit: 6b1b28c

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

docs/mllib-naive-bayes.md

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -51,9 +51,8 @@ val training = splits(0)
5151
val test = splits(1)
5252

5353
val model = NaiveBayes.train(training, lambda = 1.0)
54-
val prediction = model.predict(test.map(_.features))
5554

56-
val predictionAndLabel = prediction.zip(test.map(_.label))
55+
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
5756
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
5857
{% endhighlight %}
5958
</div>
@@ -71,6 +70,7 @@ can be used for evaluation and prediction.
7170
import org.apache.spark.api.java.JavaPairRDD;
7271
import org.apache.spark.api.java.JavaRDD;
7372
import org.apache.spark.api.java.function.Function;
73+
import org.apache.spark.api.java.function.PairFunction;
7474
import org.apache.spark.mllib.classification.NaiveBayes;
7575
import org.apache.spark.mllib.classification.NaiveBayesModel;
7676
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -81,18 +81,12 @@ JavaRDD<LabeledPoint> test = ... // test set
8181

8282
final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
8383

84-
JavaRDD<Double> prediction =
85-
test.map(new Function<LabeledPoint, Double>() {
86-
@Override public Double call(LabeledPoint p) {
87-
return model.predict(p.features());
88-
}
89-
});
9084
JavaPairRDD<Double, Double> predictionAndLabel =
91-
prediction.zip(test.map(new Function<LabeledPoint, Double>() {
92-
@Override public Double call(LabeledPoint p) {
93-
return p.label();
85+
test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
86+
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
87+
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
9488
}
95-
}));
89+
});
9690
double accuracy = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
9791
@Override public Boolean call(Tuple2<Double, Double> pl) {
9892
return pl._1() == pl._2();

0 commit comments

Comments
 (0)