Skip to content

Commit 6acff75

Browse files
committed
use scopt for DecisionTreeRunner
1 parent be86069 commit 6acff75

File tree

1 file changed

+85
-89
lines changed

1 file changed

+85
-89
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 85 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@
1717

1818
package org.apache.spark.examples.mllib
1919

20-
import org.apache.spark.{Logging, SparkConf, SparkContext}
20+
import scopt.OptionParser
21+
22+
import org.apache.spark.{SparkConf, SparkContext}
2123
import org.apache.spark.SparkContext._
22-
import org.apache.spark.mllib.linalg.{Vector, Vectors}
24+
import org.apache.spark.mllib.linalg.Vector
2325
import org.apache.spark.mllib.regression.LabeledPoint
24-
import org.apache.spark.mllib.tree.DecisionTree
26+
import org.apache.spark.mllib.tree.{DecisionTree, impurity}
2527
import org.apache.spark.mllib.tree.configuration._
2628
import org.apache.spark.mllib.tree.configuration.Algo._
27-
import org.apache.spark.mllib.tree.impurity._
2829
import org.apache.spark.mllib.tree.model.DecisionTreeModel
30+
import org.apache.spark.mllib.util.MLUtils
2931
import org.apache.spark.rdd.RDD
3032

3133
/**
@@ -35,124 +37,118 @@ import org.apache.spark.rdd.RDD
3537
* }}}
3638
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
3739
*/
38-
object DecisionTreeRunner extends Logging {
40+
object DecisionTreeRunner {
41+
42+
object ImpurityType extends Enumeration {
43+
type ImpurityType = Value
44+
val Gini, Entropy, Variance = Value
45+
}
3946

40-
private val usage =
41-
"""
42-
|Usage: DecisionTreeRunner --algo <Classification, Regression> --trainDataDir path
43-
| --testDataDir path --maxDepth num [--impurity <Gini,Entropy,Variance>] [--maxBins num]
44-
""".stripMargin
47+
import ImpurityType._
48+
49+
case class Params(
50+
input: String = null,
51+
algo: Algo = Classification,
52+
maxDepth: Int = 5,
53+
impurity: ImpurityType = Gini,
54+
maxBins: Int = 20)
4555

4656
def main(args: Array[String]) {
57+
val defaultParams = Params()
58+
59+
val parser = new OptionParser[Params]("DecisionTreeRunner") {
60+
head("DecisionTreeRunner: an example decision tree app.")
61+
opt[String]("algo")
62+
.text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
63+
.action((x, c) => c.copy(algo = Algo.withName(x)))
64+
opt[String]("impurity")
65+
.text(s"impurity type (${ImpurityType.values.mkString(",")}), " +
66+
s"default: ${defaultParams.impurity}")
67+
.action((x, c) => c.copy(impurity = ImpurityType.withName(x)))
68+
opt[Int]("maxDepth")
69+
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
70+
.action((x, c) => c.copy(maxDepth = x))
71+
opt[Int]("maxBins")
72+
.text(s"max number of bins, default: ${defaultParams.maxBins}")
73+
.action((x, c) => c.copy(maxBins = x))
74+
arg[String]("<input>")
75+
.text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
76+
.required()
77+
.action((x, c) => c.copy(input = x))
78+
checkConfig { params =>
79+
if (params.algo == Classification &&
80+
(params.impurity == Gini || params.impurity == Entropy)) {
81+
success
82+
} else if (params.algo == Regression && params.impurity == Variance) {
83+
success
84+
} else {
85+
failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
86+
}
87+
}
88+
}
4789

48-
if (args.length < 2) {
49-
System.err.println(usage)
50-
System.exit(1)
90+
parser.parse(args, defaultParams).map { params =>
91+
run(params)
92+
}.getOrElse {
93+
sys.exit(1)
5194
}
95+
}
5296

97+
def run(params: Params) {
5398
val conf = new SparkConf().setAppName("DecisionTreeRunner")
5499
val sc = new SparkContext(conf)
55100

56-
val argList = args.toList
57-
type OptionMap = Map[Symbol, Any]
58-
59-
def nextOption(map : OptionMap, list: List[String]): OptionMap = {
60-
list match {
61-
case Nil => map
62-
case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail)
63-
case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
64-
case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
65-
case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail)
66-
case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string)
67-
, tail)
68-
case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string),
69-
tail)
70-
case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
71-
case option :: tail => logError("Unknown option " + option)
72-
sys.exit(1)
73-
}
74-
}
75-
val options = nextOption(Map(), argList)
76-
logDebug(options.toString())
101+
// Load training data and cache it.
102+
val examples = MLUtils.loadLabeledData(sc, params.input).cache()
77103

78-
// Load training data.
79-
val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
104+
val splits = examples.randomSplit(Array(0.8, 0.2))
105+
val training = splits(0).cache()
106+
val test = splits(1).cache()
80107

81-
// Identify the type of algorithm.
82-
val algoStr = options.get('algo).get.toString
83-
val algo = algoStr match {
84-
case "Classification" => Classification
85-
case "Regression" => Regression
86-
}
108+
val numTraining = training.count()
109+
val numTest = test.count()
87110

88-
// Identify the type of impurity.
89-
val impurityStr = options.getOrElse('impurity,
90-
if (algo == Classification) "Gini" else "Variance").toString
91-
val impurity = impurityStr match {
92-
case "Gini" => Gini
93-
case "Entropy" => Entropy
94-
case "Variance" => Variance
95-
}
111+
println(s"numTraining = $numTraining, numTest = $numTest.")
96112

97-
val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
98-
val maxBins = options.getOrElse('maxBins, "100").toString.toInt
113+
examples.unpersist(blocking = false)
99114

100-
val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
101-
val model = DecisionTree.train(trainData, strategy)
115+
val impurityCalculator = params.impurity match {
116+
case Gini => impurity.Gini
117+
case Entropy => impurity.Entropy
118+
case Variance => impurity.Variance
119+
}
102120

103-
// Load test data.
104-
val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
121+
val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
122+
val model = DecisionTree.train(training, strategy)
105123

106-
// Measure algorithm accuracy
107-
if (algo == Classification) {
108-
val accuracy = accuracyScore(model, testData)
109-
logDebug("accuracy = " + accuracy)
124+
if (params.algo == Classification) {
125+
val accuracy = accuracyScore(model, test)
126+
println(s"Test accuracy = $accuracy.")
110127
}
111128

112-
if (algo == Regression) {
113-
val mse = meanSquaredError(model, testData)
114-
logDebug("mean square error = " + mse)
129+
if (params.algo == Regression) {
130+
val mse = meanSquaredError(model, test)
131+
println(s"Test mean squared error = $mse.")
115132
}
116133

117134
sc.stop()
118135
}
119136

120-
/**
121-
* Load labeled data from a file. The data format used here is
122-
* <L>, <f1> <f2> ...,
123-
* where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
124-
*
125-
* @param sc SparkContext
126-
* @param dir Directory to the input data files.
127-
* @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
128-
* the label, and the second element represents the feature values (an array of Double).
129-
*/
130-
private def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
131-
sc.textFile(dir).map { line =>
132-
val parts = line.trim().split(",")
133-
val label = parts(0).toDouble
134-
val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble))
135-
LabeledPoint(label, features)
136-
}
137-
}
138-
139-
// TODO: Port this method to a generic metrics package.
140137
/**
141138
* Calculates the classifier accuracy.
142139
*/
143-
private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
144-
threshold: Double = 0.5): Double = {
145-
def predictedValue(features: Vector) = {
140+
private def accuracyScore(
141+
model: DecisionTreeModel,
142+
data: RDD[LabeledPoint],
143+
threshold: Double = 0.5): Double = {
144+
def predictedValue(features: Vector): Double = {
146145
if (model.predict(features) < threshold) 0.0 else 1.0
147146
}
148147
val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
149148
val count = data.count()
150-
logDebug("correct prediction count = " + correctCount)
151-
logDebug("data count = " + count)
152149
correctCount.toDouble / count
153150
}
154151

155-
// TODO: Port this method to a generic metrics package
156152
/**
157153
* Calculates the mean squared error for regression.
158154
*/

0 commit comments

Comments
 (0)