
Commit aa519e3

mengxr authored and rxin committed
[SPARK-1636][MLLIB] Move main methods to examples
* `NaiveBayes` -> `SparseNaiveBayes`
* `KMeans` -> `DenseKMeans`
* `SVMWithSGD` and `LogisticRegressionWithSGD` -> `BinaryClassification`
* `ALS` -> `MovieLensALS`
* `LinearRegressionWithSGD`, `LassoWithSGD`, and `RidgeRegressionWithSGD` -> `LinearRegression`
* `DecisionTree` -> `DecisionTreeRunner`

`scopt` is used for parsing command-line parameters. `scopt` is MIT-licensed and depends only on `scala-library`.

Example help message:

~~~
BinaryClassification: an example app for binary classification.
Usage: BinaryClassification [options] <input>

  --numIterations <value>  number of iterations
  --stepSize <value>       initial step size, default: 1.0
  --algorithm <value>      algorithm (SVM,LR), default: LR
  --regType <value>        regularization type (L1,L2), default: L2
  --regParam <value>       regularization parameter, default: 0.1
  <input>                  input paths to labeled examples in LIBSVM format
~~~

Author: Xiangrui Meng <[email protected]>

Closes #584 from mengxr/mllib-main and squashes the following commits:

7b58c60 [Xiangrui Meng] minor
6e35d7e [Xiangrui Meng] make imports explicit and fix code style
c6178c9 [Xiangrui Meng] update TS PCA/SVD to use new spark-submit
6acff75 [Xiangrui Meng] use scopt for DecisionTreeRunner
be86069 [Xiangrui Meng] use main instead of extending App
b3edf68 [Xiangrui Meng] move DecisionTree's main method to examples
8bfaa5a [Xiangrui Meng] change NaiveBayesParams to Params
fe23dcb [Xiangrui Meng] remove main from KMeans and add DenseKMeans as an example
67f4448 [Xiangrui Meng] remove main methods from linear regression algorithms and add LinearRegression example
b066bbc [Xiangrui Meng] remove main from ALS and add MovieLensALS example
b040f3b [Xiangrui Meng] change BinaryClassificationParams to Params
577945b [Xiangrui Meng] remove unused imports from NB
3d299bc [Xiangrui Meng] remove main from LR/SVM and add an example app for binary classification
f70878e [Xiangrui Meng] remove main from NaiveBayes and add an example NaiveBayes app
01ec2cd [Xiangrui Meng] Merge branch 'master' into mllib-main
9420692 [Xiangrui Meng] add scopt to examples dependencies

(cherry picked from commit 3f38334)

Signed-off-by: Reynold Xin <[email protected]>
1 parent 0995787 · commit aa519e3
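The scopt DSL maps almost one-to-one onto the help text shown in the commit message. Below is a minimal sketch of a parser in the same style (the `ExampleParams` class and the option subset are illustrative only, not part of the commit; the full parsers appear in the diffs that follow):

~~~scala
import scopt.OptionParser

// Illustrative config class; the real example apps each define their own Params.
case class ExampleParams(input: String = null, numIterations: Int = 100)

object ScoptUsageSketch {
  def main(args: Array[String]) {
    val parser = new OptionParser[ExampleParams]("BinaryClassification") {
      head("BinaryClassification: an example app for binary classification.")
      opt[Int]("numIterations")
        .text("number of iterations")                 // rendered in the --help output
        .action((x, c) => c.copy(numIterations = x))  // folds the parsed value into the config
      arg[String]("<input>")
        .required()
        .text("input paths to labeled examples in LIBSVM format")
        .action((x, c) => c.copy(input = x))
    }
    // parse returns Some(config) on success, or None (after printing usage) on error.
    parser.parse(args, ExampleParams()).map { params =>
      println(s"would run with $params")
    }.getOrElse(sys.exit(1))
  }
}
~~~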

File tree

19 files changed: +795 −321 lines


examples/pom.xml

Lines changed: 5 additions & 0 deletions
~~~diff
@@ -166,6 +166,11 @@
         </exclusion>
       </exclusions>
     </dependency>
+    <dependency>
+      <groupId>com.github.scopt</groupId>
+      <artifactId>scopt_${scala.binary.version}</artifactId>
+      <version>3.2.0</version>
+    </dependency>
   </dependencies>

   <build>
~~~
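For readers building against scopt from sbt rather than Maven, the equivalent dependency would be a one-liner (a sketch, not part of this commit; `%%` appends the Scala binary version, e.g. `scopt_2.10`, just as `${scala.binary.version}` does in the POM):

~~~scala
// Hypothetical sbt equivalent of the Maven dependency added above.
libraryDependencies += "com.github.scopt" %% "scopt" % "3.2.0"
~~~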
examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala (new file)

Lines changed: 145 additions & 0 deletions
~~~scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import org.apache.log4j.{Level, Logger}
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD}
import org.apache.spark.mllib.evaluation.binary.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater}

/**
 * An example app for binary classification. Run with
 * {{{
 * ./bin/run-example org.apache.spark.examples.mllib.BinaryClassification
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object BinaryClassification {

  object Algorithm extends Enumeration {
    type Algorithm = Value
    val SVM, LR = Value
  }

  object RegType extends Enumeration {
    type RegType = Value
    val L1, L2 = Value
  }

  import Algorithm._
  import RegType._

  case class Params(
      input: String = null,
      numIterations: Int = 100,
      stepSize: Double = 1.0,
      algorithm: Algorithm = LR,
      regType: RegType = L2,
      regParam: Double = 0.1)

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("BinaryClassification") {
      head("BinaryClassification: an example app for binary classification.")
      opt[Int]("numIterations")
        .text("number of iterations")
        .action((x, c) => c.copy(numIterations = x))
      opt[Double]("stepSize")
        .text(s"initial step size, default: ${defaultParams.stepSize}")
        .action((x, c) => c.copy(stepSize = x))
      opt[String]("algorithm")
        .text(s"algorithm (${Algorithm.values.mkString(",")}), " +
          s"default: ${defaultParams.algorithm}")
        .action((x, c) => c.copy(algorithm = Algorithm.withName(x)))
      opt[String]("regType")
        .text(s"regularization type (${RegType.values.mkString(",")}), " +
          s"default: ${defaultParams.regType}")
        .action((x, c) => c.copy(regType = RegType.withName(x)))
      opt[Double]("regParam")
        .text(s"regularization parameter, default: ${defaultParams.regParam}")
        .action((x, c) => c.copy(regParam = x))
      arg[String]("<input>")
        .required()
        .text("input paths to labeled examples in LIBSVM format")
        .action((x, c) => c.copy(input = x))
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    } getOrElse {
      sys.exit(1)
    }
  }

  def run(params: Params) {
    val conf = new SparkConf().setAppName(s"BinaryClassification with $params")
    val sc = new SparkContext(conf)

    Logger.getRootLogger.setLevel(Level.WARN)

    val examples = MLUtils.loadLibSVMData(sc, params.input).cache()

    val splits = examples.randomSplit(Array(0.8, 0.2))
    val training = splits(0).cache()
    val test = splits(1).cache()

    val numTraining = training.count()
    val numTest = test.count()
    println(s"Training: $numTraining, test: $numTest.")

    examples.unpersist(blocking = false)

    val updater = params.regType match {
      case L1 => new L1Updater()
      case L2 => new SquaredL2Updater()
    }

    val model = params.algorithm match {
      case LR =>
        val algorithm = new LogisticRegressionWithSGD()
        algorithm.optimizer
          .setNumIterations(params.numIterations)
          .setStepSize(params.stepSize)
          .setUpdater(updater)
          .setRegParam(params.regParam)
        algorithm.run(training).clearThreshold()
      case SVM =>
        val algorithm = new SVMWithSGD()
        algorithm.optimizer
          .setNumIterations(params.numIterations)
          .setStepSize(params.stepSize)
          .setUpdater(updater)
          .setRegParam(params.regParam)
        algorithm.run(training).clearThreshold()
    }

    val prediction = model.predict(test.map(_.features))
    val predictionAndLabel = prediction.zip(test.map(_.label))

    val metrics = new BinaryClassificationMetrics(predictionAndLabel)

    println(s"Test areaUnderPR = ${metrics.areaUnderPR()}.")
    println(s"Test areaUnderROC = ${metrics.areaUnderROC()}.")

    sc.stop()
  }
}
~~~
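A note on the evaluation step in `run`: `clearThreshold()` makes the returned model's `predict` emit raw scores rather than 0/1 labels, and `BinaryClassificationMetrics` sweeps thresholds over those scores to compute the ROC and PR curves. Below is a minimal standalone sketch of the same evaluation, assuming a local `SparkContext` and hand-made (score, label) pairs (the package path is the one used at this commit; it moved in later releases):

~~~scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.binary.BinaryClassificationMetrics

object MetricsSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("MetricsSketch"))
    // (raw score, true label) pairs, as produced by model.predict(...).zip(labels)
    // after clearThreshold() -- the metrics class sweeps thresholds over the scores.
    val scoreAndLabels = sc.parallelize(Seq((0.9, 1.0), (0.6, 1.0), (0.4, 0.0), (0.1, 0.0)))
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    println(s"areaUnderROC = ${metrics.areaUnderROC()}")
    println(s"areaUnderPR = ${metrics.areaUnderPR()}")
    sc.stop()
  }
}
~~~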
examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala (new file)

Lines changed: 161 additions & 0 deletions
~~~scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

/**
 * An example runner for decision tree. Run with
 * {{{
 * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object DecisionTreeRunner {

  object ImpurityType extends Enumeration {
    type ImpurityType = Value
    val Gini, Entropy, Variance = Value
  }

  import ImpurityType._

  case class Params(
      input: String = null,
      algo: Algo = Classification,
      maxDepth: Int = 5,
      impurity: ImpurityType = Gini,
      maxBins: Int = 20)

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("DecisionTreeRunner") {
      head("DecisionTreeRunner: an example decision tree app.")
      opt[String]("algo")
        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
        .action((x, c) => c.copy(algo = Algo.withName(x)))
      opt[String]("impurity")
        .text(s"impurity type (${ImpurityType.values.mkString(",")}), " +
          s"default: ${defaultParams.impurity}")
        .action((x, c) => c.copy(impurity = ImpurityType.withName(x)))
      opt[Int]("maxDepth")
        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
        .action((x, c) => c.copy(maxDepth = x))
      opt[Int]("maxBins")
        .text(s"max number of bins, default: ${defaultParams.maxBins}")
        .action((x, c) => c.copy(maxBins = x))
      arg[String]("<input>")
        .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
        .required()
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        if (params.algo == Classification &&
            (params.impurity == Gini || params.impurity == Entropy)) {
          success
        } else if (params.algo == Regression && params.impurity == Variance) {
          success
        } else {
          failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
        }
      }
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      sys.exit(1)
    }
  }

  def run(params: Params) {
    val conf = new SparkConf().setAppName("DecisionTreeRunner")
    val sc = new SparkContext(conf)

    // Load training data and cache it.
    val examples = MLUtils.loadLabeledData(sc, params.input).cache()

    val splits = examples.randomSplit(Array(0.8, 0.2))
    val training = splits(0).cache()
    val test = splits(1).cache()

    val numTraining = training.count()
    val numTest = test.count()

    println(s"numTraining = $numTraining, numTest = $numTest.")

    examples.unpersist(blocking = false)

    val impurityCalculator = params.impurity match {
      case Gini => impurity.Gini
      case Entropy => impurity.Entropy
      case Variance => impurity.Variance
    }

    val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
    val model = DecisionTree.train(training, strategy)

    if (params.algo == Classification) {
      val accuracy = accuracyScore(model, test)
      println(s"Test accuracy = $accuracy.")
    }

    if (params.algo == Regression) {
      val mse = meanSquaredError(model, test)
      println(s"Test mean squared error = $mse.")
    }

    sc.stop()
  }

  /**
   * Calculates the classifier accuracy.
   */
  private def accuracyScore(
      model: DecisionTreeModel,
      data: RDD[LabeledPoint],
      threshold: Double = 0.5): Double = {
    def predictedValue(features: Vector): Double = {
      if (model.predict(features) < threshold) 0.0 else 1.0
    }
    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
    val count = data.count()
    correctCount.toDouble / count
  }

  /**
   * Calculates the mean squared error for regression.
   */
  private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
    data.map { y =>
      val err = tree.predict(y.features) - y.label
      err * err
    }.mean()
  }
}
~~~
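For reference, the `<input>` argument above expects the dense format `label,f0 f1 f2 ...` consumed by `MLUtils.loadLabeledData`. Below is a hand-rolled sketch of that parsing, assuming exactly one comma per line (the `DenseInputSketch` object, the `loadDense` helper, and the sample line are hypothetical, for illustration only; the real loader is the one called in `run`):

~~~scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

object DenseInputSketch {
  // Each input line looks like "1.0,0.5 2.3 4.1": the label, a comma,
  // then space-separated feature values.
  def loadDense(sc: SparkContext, path: String): RDD[LabeledPoint] = {
    sc.textFile(path).map { line =>
      val Array(label, features) = line.split(',')
      LabeledPoint(label.toDouble, Vectors.dense(features.trim.split(' ').map(_.toDouble)))
    }
  }
}
~~~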
