 
 package org.apache.spark.examples.mllib
 
-import org.apache.spark.{Logging, SparkConf, SparkContext}
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.{DecisionTree, impurity}
 import org.apache.spark.mllib.tree.configuration._
 import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model.DecisionTreeModel
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 
 /**
@@ -35,124 +37,118 @@ import org.apache.spark.rdd.RDD
  * }}}
  * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
  */
-object DecisionTreeRunner extends Logging {
+object DecisionTreeRunner {
+
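+  /** Impurity measures accepted on the command line. */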
+  object ImpurityType extends Enumeration {
+    type ImpurityType = Value
+    val Gini, Entropy, Variance = Value
+  }
 
-  private val usage =
-    """
-      |Usage: DecisionTreeRunner --algo <Classification, Regression> --trainDataDir path
-      |  --testDataDir path --maxDepth num [--impurity <Gini,Entropy,Variance>] [--maxBins num]
-    """.stripMargin
+  import ImpurityType._
+
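+  /** Parameters for the example, with defaults that command-line flags may override. */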
+  case class Params(
+      input: String = null,
+      algo: Algo = Classification,
+      maxDepth: Int = 5,
+      impurity: ImpurityType = Gini,
+      maxBins: Int = 20)
 
   def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("DecisionTreeRunner") {
+      head("DecisionTreeRunner: an example decision tree app.")
+      opt[String]("algo")
+        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+        .action((x, c) => c.copy(algo = Algo.withName(x)))
+      opt[String]("impurity")
+        .text(s"impurity type (${ImpurityType.values.mkString(",")}), " +
+          s"default: ${defaultParams.impurity}")
+        .action((x, c) => c.copy(impurity = ImpurityType.withName(x)))
+      opt[Int]("maxDepth")
+        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+        .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("maxBins")
+        .text(s"max number of bins, default: ${defaultParams.maxBins}")
+        .action((x, c) => c.copy(maxBins = x))
+      arg[String]("<input>")
+        .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
+        .required()
+        .action((x, c) => c.copy(input = x))
+      checkConfig { params =>
+        if (params.algo == Classification &&
+            (params.impurity == Gini || params.impurity == Entropy)) {
+          success
+        } else if (params.algo == Regression && params.impurity == Variance) {
+          success
+        } else {
+          failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+        }
+      }
+    }
 
-    if (args.length < 2) {
-      System.err.println(usage)
-      System.exit(1)
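+    // Run with the parsed parameters, or exit with a nonzero status if parsing fails.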
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
     }
+  }
 
+  def run(params: Params) {
     val conf = new SparkConf().setAppName("DecisionTreeRunner")
     val sc = new SparkContext(conf)
 
-    val argList = args.toList
-    type OptionMap = Map[Symbol, Any]
-
-    def nextOption(map: OptionMap, list: List[String]): OptionMap = {
-      list match {
-        case Nil => map
-        case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail)
-        case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
-        case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
-        case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail)
-        case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string)
-          , tail)
-        case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string),
-          tail)
-        case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
-        case option :: tail => logError("Unknown option " + option)
-          sys.exit(1)
-      }
-    }
-    val options = nextOption(Map(), argList)
-    logDebug(options.toString())
+    // Load training data and cache it.
+    val examples = MLUtils.loadLabeledData(sc, params.input).cache()
 
-    // Load training data.
-    val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
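+    // Randomly split the examples into 80% for training and 20% for testing.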
+    val splits = examples.randomSplit(Array(0.8, 0.2))
+    val training = splits(0).cache()
+    val test = splits(1).cache()
 
-    // Identify the type of algorithm.
-    val algoStr = options.get('algo).get.toString
-    val algo = algoStr match {
-      case "Classification" => Classification
-      case "Regression" => Regression
-    }
+    val numTraining = training.count()
+    val numTest = test.count()
 
-    // Identify the type of impurity.
-    val impurityStr = options.getOrElse('impurity,
-      if (algo == Classification) "Gini" else "Variance").toString
-    val impurity = impurityStr match {
-      case "Gini" => Gini
-      case "Entropy" => Entropy
-      case "Variance" => Variance
-    }
+    println(s"numTraining = $numTraining, numTest = $numTest.")
 
-    val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
-    val maxBins = options.getOrElse('maxBins, "100").toString.toInt
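+    // The training and test splits are cached, so the original RDD can be unpersisted.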
+    examples.unpersist(blocking = false)
 
-    val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
-    val model = DecisionTree.train(trainData, strategy)
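+    // Map the parsed impurity type to the corresponding MLlib impurity object.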
+    val impurityCalculator = params.impurity match {
+      case Gini => impurity.Gini
+      case Entropy => impurity.Entropy
+      case Variance => impurity.Variance
+    }
 
-    // Load test data.
-    val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
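+    // Configure the tree strategy and train the model on the training split.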
+    val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
+    val model = DecisionTree.train(training, strategy)
 
-    // Measure algorithm accuracy
-    if (algo == Classification) {
-      val accuracy = accuracyScore(model, testData)
-      logDebug("accuracy = " + accuracy)
+    if (params.algo == Classification) {
+      val accuracy = accuracyScore(model, test)
+      println(s"Test accuracy = $accuracy.")
     }
 
-    if (algo == Regression) {
-      val mse = meanSquaredError(model, testData)
-      logDebug("mean square error = " + mse)
+    if (params.algo == Regression) {
+      val mse = meanSquaredError(model, test)
+      println(s"Test mean squared error = $mse.")
     }
 
     sc.stop()
   }
 
-  /**
-   * Load labeled data from a file. The data format used here is
-   * <L>, <f1> <f2> ...,
-   * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
-   *
-   * @param sc SparkContext
-   * @param dir Directory to the input data files.
-   * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
-   *         the label, and the second element represents the feature values (an array of Double).
-   */
-  private def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
-    sc.textFile(dir).map { line =>
-      val parts = line.trim().split(",")
-      val label = parts(0).toDouble
-      val features = Vectors.dense(parts.slice(1, parts.length).map(_.toDouble))
-      LabeledPoint(label, features)
-    }
-  }
-
-  // TODO: Port this method to a generic metrics package.
   /**
    * Calculates the classifier accuracy.
    */
-  private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
-      threshold: Double = 0.5): Double = {
-    def predictedValue(features: Vector) = {
+  private def accuracyScore(
+      model: DecisionTreeModel,
+      data: RDD[LabeledPoint],
+      threshold: Double = 0.5): Double = {
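+    // Binarize the raw model prediction at the given threshold before comparing with the label.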
+    def predictedValue(features: Vector): Double = {
       if (model.predict(features) < threshold) 0.0 else 1.0
     }
     val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
     val count = data.count()
-    logDebug("correct prediction count = " + correctCount)
-    logDebug("data count = " + count)
     correctCount.toDouble / count
   }
 
-  // TODO: Port this method to a generic metrics package
   /**
    * Calculates the mean squared error for regression.
    */