@@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame}
4444 * {{{
4545 * ./bin/run-example ml.DecisionTreeExample [options]
4646 * }}}
47+ * Note that Decision Trees can take a large amount of memory. If the run-example command above
48+ * fails, try running via spark-submit and specifying at least 1g of memory.
49+ * For local mode, run
50+ * {{{
51+ * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample \
52+ *   --driver-memory 1g [examples JAR path] [options]
53+ * }}}
4754 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
4855 */
4956object DecisionTreeExample {
@@ -70,7 +77,7 @@ object DecisionTreeExample {
7077 val parser = new OptionParser [Params ](" DecisionTreeExample" ) {
7178 head(" DecisionTreeExample: an example decision tree app." )
7279 opt[String ](" algo" )
73- .text(s " algorithm (Classification, Regression ), default: ${defaultParams.algo}" )
80+ .text(s " algorithm (classification, regression ), default: ${defaultParams.algo}" )
7481 .action((x, c) => c.copy(algo = x))
7582 opt[Int ](" maxDepth" )
7683 .text(s " max depth of the tree, default: ${defaultParams.maxDepth}" )
@@ -222,18 +229,23 @@ object DecisionTreeExample {
222229 // (1) For classification, re-index classes.
223230 val labelColName = if (algo == " classification" ) " indexedLabel" else " label"
224231 if (algo == " classification" ) {
225- val labelIndexer = new StringIndexer ().setInputCol(" labelString" ).setOutputCol(labelColName)
232+ val labelIndexer = new StringIndexer ()
233+ .setInputCol(" labelString" )
234+ .setOutputCol(labelColName)
226235 stages += labelIndexer
227236 }
228237 // (2) Identify categorical features using VectorIndexer.
229238 // Features with more than maxCategories values will be treated as continuous.
230- val featuresIndexer = new VectorIndexer ().setInputCol(" features" )
231- .setOutputCol(" indexedFeatures" ).setMaxCategories(10 )
239+ val featuresIndexer = new VectorIndexer ()
240+ .setInputCol(" features" )
241+ .setOutputCol(" indexedFeatures" )
242+ .setMaxCategories(10 )
232243 stages += featuresIndexer
233244 // (3) Learn DecisionTree
234245 val dt = algo match {
235246 case " classification" =>
236- new DecisionTreeClassifier ().setFeaturesCol(" indexedFeatures" )
247+ new DecisionTreeClassifier ()
248+ .setFeaturesCol(" indexedFeatures" )
237249 .setLabelCol(labelColName)
238250 .setMaxDepth(params.maxDepth)
239251 .setMaxBins(params.maxBins)
@@ -242,7 +254,8 @@ object DecisionTreeExample {
242254 .setCacheNodeIds(params.cacheNodeIds)
243255 .setCheckpointInterval(params.checkpointInterval)
244256 case " regression" =>
245- new DecisionTreeRegressor ().setFeaturesCol(" indexedFeatures" )
257+ new DecisionTreeRegressor ()
258+ .setFeaturesCol(" indexedFeatures" )
246259 .setLabelCol(labelColName)
247260 .setMaxDepth(params.maxDepth)
248261 .setMaxBins(params.maxBins)
0 commit comments