
Commit f70878e

remove main from NaiveBayes and add an example NaiveBayes app
1 parent 01ec2cd commit f70878e

File tree

2 files changed (+100, -19 lines)
examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import org.apache.log4j.{Level, Logger}
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.util.{MLUtils, MulticlassLabelParser}

/**
 * An example naive Bayes app. Run with
 * {{{
 * ./bin/spark-example org.apache.spark.examples.mllib.SparseNaiveBayes [options] <input>
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object SparseNaiveBayes extends App {

  case class NaiveBayesParams(
      input: String = null,
      minPartitions: Int = 0,
      numFeatures: Int = -1,
      lambda: Double = 1.0)

  val defaultParams = NaiveBayesParams()

  val parser = new OptionParser[NaiveBayesParams]("SparseNaiveBayes") {
    head("SparseNaiveBayes: an example naive Bayes app for LIBSVM data.")
    opt[Int]("numPartitions")
      .text("min number of partitions")
      .action((x, c) => c.copy(minPartitions = x))
    opt[Int]("numFeatures")
      .text("number of features")
      .action((x, c) => c.copy(numFeatures = x))
    opt[Double]("lambda")
      .text(s"lambda (smoothing constant), default: ${defaultParams.lambda}")
      .action((x, c) => c.copy(lambda = x))
    arg[String]("<input>")
      .text("input paths to labeled examples in LIBSVM format")
      .required()
      .action((x, c) => c.copy(input = x))
  }

  parser.parse(args, defaultParams).map { params =>
    run(params)
  }.getOrElse {
    sys.exit(1)
  }

  def run(params: NaiveBayesParams) {
    val conf = new SparkConf().setAppName(s"SparseNaiveBayes with $params")
    val sc = new SparkContext(conf)

    Logger.getRootLogger.setLevel(Level.WARN)

    val minPartitions =
      if (params.minPartitions > 0) params.minPartitions else sc.defaultMinPartitions

    val examples = MLUtils.loadLibSVMData(sc, params.input, MulticlassLabelParser,
      params.numFeatures, minPartitions)
    // Cache examples because it will be used in both training and evaluation.
    examples.cache()

    val splits = examples.randomSplit(Array(0.8, 0.2))
    val training = splits(0)
    val test = splits(1)

    val numTraining = training.count()
    val numTest = test.count()

    println(s"numTraining = $numTraining, numTest = $numTest.")

    val model = new NaiveBayes().setLambda(params.lambda).run(training)

    val prediction = model.predict(test.map(_.features))
    val predictionAndLabel = prediction.zip(test.map(_.label))
    val accuracy = predictionAndLabel.filter(x => x._1 == x._2).count().toDouble / numTest

    println(s"Test accuracy = $accuracy.")

    sc.stop()
  }
}
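The new example loads LIBSVM data, splits it 80/20, trains a naive Bayes model, and reports test accuracy. For anyone using it as a template (as the scaladoc suggests), below is a minimal sketch of the same flow without the scopt option parsing. It assumes only the MLlib API already used above (MLUtils.loadLibSVMData with MulticlassLabelParser); the object name, input path, and parameter values are placeholder assumptions, not part of this commit.

// Minimal template sketch: same core flow as SparseNaiveBayes, without option parsing.
// The object name, input path, and parameter values are placeholders, not part of the commit.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.util.{MLUtils, MulticlassLabelParser}

object MyNaiveBayesApp {
  def main(args: Array[String]) {
    // Submit with spark-submit; master and deploy settings come from the submit command.
    val sc = new SparkContext(new SparkConf().setAppName("MyNaiveBayesApp"))

    // Load LIBSVM-formatted examples; numFeatures = -1 and defaultMinPartitions
    // mirror the example's defaults.
    val examples = MLUtils.loadLibSVMData(sc, "data/examples.libsvm", MulticlassLabelParser,
      -1, sc.defaultMinPartitions)
    examples.cache() // used for both training and evaluation

    // 80/20 train/test split, as in the example above.
    val Array(training, test) = examples.randomSplit(Array(0.8, 0.2))

    // Train with additive smoothing and measure accuracy on the held-out split.
    val model = new NaiveBayes().setLambda(1.0).run(training)
    val predictionAndLabel = model.predict(test.map(_.features)).zip(test.map(_.label))
    val accuracy = predictionAndLabel.filter(x => x._1 == x._2).count().toDouble / test.count()
    println(s"Test accuracy = $accuracy")

    sc.stop()
  }
}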

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 0 additions & 19 deletions
@@ -158,23 +158,4 @@ object NaiveBayes {
   def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
     new NaiveBayes(lambda).run(input)
   }
-
-  def main(args: Array[String]) {
-    if (args.length != 2 && args.length != 3) {
-      println("Usage: NaiveBayes <master> <input_dir> [<lambda>]")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "NaiveBayes")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = if (args.length == 2) {
-      NaiveBayes.train(data)
-    } else {
-      NaiveBayes.train(data, args(2).toDouble)
-    }
-
-    println("Pi\n: " + model.pi)
-    println("Theta:\n" + model.theta)
-
-    sc.stop()
-  }
 }
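With the ad-hoc main removed, NaiveBayes is exercised through the new example app or called programmatically via NaiveBayes.train. Below is a rough sketch of what the removed main did, rewritten as a standalone user application; it relies only on MLUtils.loadLabeledData, NaiveBayes.train, and the model's pi and theta fields as they appear in the removed code, while the object name, input path, and lambda value are placeholder assumptions.

// Rough standalone equivalent of the removed main (sketch only).
// The object name, input path, and lambda value are placeholders, not part of this commit.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.util.MLUtils

object NaiveBayesDriver {
  def main(args: Array[String]) {
    // The master is no longer a positional argument; it comes from spark-submit or SparkConf.
    val sc = new SparkContext(new SparkConf().setAppName("NaiveBayesDriver"))

    val data = MLUtils.loadLabeledData(sc, "data/labeled_points.txt")
    val model = NaiveBayes.train(data, 1.0)

    println("Pi:\n" + model.pi)
    println("Theta:\n" + model.theta)

    sc.stop()
  }
}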
