Skip to content

Commit b9bee04

Browse files
committed
Updated DT examples
1 parent 57eee9f commit b9bee04

File tree

2 files changed

+68
-30
lines changed

2 files changed

+68
-30
lines changed

docs/mllib-decision-tree.md

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -162,47 +162,64 @@ val labelAndPreds = data.map { point =>
162162
}
163163
val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count
164164
println("Training Error = " + trainErr)
165+
println("Learned classification tree model:\n" + model)
165166
{% endhighlight %}
166167
</div>
167168

168169
<div data-lang="java">
169170
{% highlight java %}
171+
import java.util.HashMap;
170172
import scala.Tuple2;
173+
import org.apache.spark.api.java.function.Function2;
171174
import org.apache.spark.api.java.JavaPairRDD;
172175
import org.apache.spark.api.java.JavaRDD;
176+
import org.apache.spark.api.java.JavaSparkContext;
173177
import org.apache.spark.api.java.function.Function;
174178
import org.apache.spark.api.java.function.PairFunction;
175179
import org.apache.spark.mllib.regression.LabeledPoint;
176180
import org.apache.spark.mllib.tree.DecisionTree;
177181
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
178-
179-
JavaRDD<LabeledPoint> data = ... // data set
180-
181-
// Train a DecisionTree model.
182+
import org.apache.spark.mllib.util.MLUtils;
183+
import org.apache.spark.SparkConf;
184+
185+
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
186+
JavaSparkContext sc = new JavaSparkContext(sparkConf);
187+
188+
String datapath = "data/mllib/sample_libsvm_data.txt";
189+
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
190+
// Compute the number of classes from the data.
191+
Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
192+
@Override public Double call(LabeledPoint p) {
193+
return p.label();
194+
}
195+
}).countByValue().size();
196+
197+
// Set parameters.
182198
// Empty categoricalFeaturesInfo indicates all features are continuous.
183-
Integer numClasses = ... // number of classes
184199
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
185200
String impurity = "gini";
186201
Integer maxDepth = 5;
187202
Integer maxBins = 100;
188203

204+
// Train a DecisionTree model for classification.
189205
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
190206
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
191207

192208
// Evaluate model on training instances and compute training error
193-
JavaPairRDD<Double, Double> predictionAndLabel =
209+
JavaPairRDD<Double, Double> predictionAndLabel =
194210
data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
195211
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
196212
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
197213
}
198214
});
199-
Double trainErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
215+
Double trainErr =
216+
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
200217
@Override public Boolean call(Tuple2<Double, Double> pl) {
201218
return !pl._1().equals(pl._2());
202219
}
203220
}).count() / data.count();
204-
System.out.print("Training error: " + trainErr);
205-
System.out.print("Learned model:\n" + model);
221+
System.out.println("Training error: " + trainErr);
222+
System.out.println("Learned classification tree model:\n" + model);
206223
{% endhighlight %}
207224
</div>
208225

@@ -225,6 +242,8 @@ predictions = model.predict(data.map(lambda x: x.features))
225242
labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
226243
trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count())
227244
print('Training Error = ' + str(trainErr))
245+
print('Learned classification tree model:')
246+
print(model)
228247
{% endhighlight %}
229248

230249
Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
@@ -268,47 +287,63 @@ val labelsAndPredictions = data.map { point =>
268287
}
269288
val trainMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
270289
println("Training Mean Squared Error = " + trainMSE)
290+
println("Learned regression tree model:\n" + model)
271291
{% endhighlight %}
272292
</div>
273293

274294
<div data-lang="java">
275295
{% highlight java %}
296+
import java.util.HashMap;
297+
import scala.Tuple2;
298+
import org.apache.spark.api.java.function.Function2;
276299
import org.apache.spark.api.java.JavaPairRDD;
277300
import org.apache.spark.api.java.JavaRDD;
301+
import org.apache.spark.api.java.JavaSparkContext;
278302
import org.apache.spark.api.java.function.Function;
279303
import org.apache.spark.api.java.function.PairFunction;
280304
import org.apache.spark.mllib.regression.LabeledPoint;
281305
import org.apache.spark.mllib.tree.DecisionTree;
282306
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
283-
import scala.Tuple2;
307+
import org.apache.spark.mllib.util.MLUtils;
308+
import org.apache.spark.SparkConf;
284309

285-
JavaRDD<LabeledPoint> data = ... // data set
310+
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
311+
JavaSparkContext sc = new JavaSparkContext(sparkConf);
286312

287-
// Train a DecisionTree model.
313+
String datapath = "data/mllib/sample_libsvm_data.txt";
314+
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
315+
316+
// Set parameters.
288317
// Empty categoricalFeaturesInfo indicates all features are continuous.
289318
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
290319
String impurity = "variance";
291320
Integer maxDepth = 5;
292321
Integer maxBins = 100;
293322

323+
// Train a DecisionTree model for regression.
294324
final DecisionTreeModel model = DecisionTree.trainRegressor(data,
295325
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
296326

297327
// Evaluate model on training instances and compute training error
298-
JavaPairRDD<Double, Double> predictionAndLabel =
328+
JavaPairRDD<Double, Double> predictionAndLabel =
299329
data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
300330
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
301331
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
302332
}
303333
});
304-
Double trainMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
334+
Double trainMSE =
335+
predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
305336
@Override public Double call(Tuple2<Double, Double> pl) {
306-
Double diff = pl._1() - pl._2();
337+
Double diff = pl._1() - pl._2();
307338
return diff * diff;
308339
}
309-
}).sum() / data.count();
310-
System.out.print("Training Mean Squared Error: " + trainMSE);
311-
System.out.print("Learned model:\n" + model);
340+
}).reduce(new Function2<Double, Double, Double>() {
341+
@Override public Double call(Double a, Double b) {
342+
return a + b;
343+
}
344+
}) / data.count();
345+
System.out.println("Training Mean Squared Error: " + trainMSE);
346+
System.out.println("Learned regression tree model:\n" + model);
312347
{% endhighlight %}
313348
</div>
314349

@@ -331,6 +366,8 @@ predictions = model.predict(data.map(lambda x: x.features))
331366
labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
332367
trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count())
333368
print('Training Mean Squared Error = ' + str(trainMSE))
369+
print('Learned regression tree model:')
370+
print(model)
334371
{% endhighlight %}
335372

336373
Note: When making predictions for a dataset, it is more efficient to do batch prediction rather

examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919

2020
import java.util.HashMap;
2121

22-
import scala.reflect.ClassTag;
2322
import scala.Tuple2;
2423

2524
import org.apache.spark.api.java.function.Function2;
26-
import org.apache.spark.api.java.JavaPairRDD;
25+
import org.apache.spark.api.java.JavaPairRDD;
2726
import org.apache.spark.api.java.JavaRDD;
2827
import org.apache.spark.api.java.JavaSparkContext;
2928
import org.apache.spark.api.java.function.Function;
@@ -34,30 +33,33 @@
3433
import org.apache.spark.mllib.util.MLUtils;
3534
import org.apache.spark.SparkConf;
3635

37-
3836
/**
3937
* Classification and regression using decision trees.
4038
*/
4139
public final class JavaDecisionTree {
4240

4341
public static void main(String[] args) {
44-
if (args.length != 1) {
42+
String datapath = "data/mllib/sample_libsvm_data.txt";
43+
if (args.length == 1) {
44+
datapath = args[0];
45+
} else if (args.length > 1) {
4546
System.err.println("Usage: JavaDecisionTree <libsvm format data file>");
4647
System.exit(1);
4748
}
4849
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
4950
JavaSparkContext sc = new JavaSparkContext(sparkConf);
50-
String datapath = args[0];
5151

52-
JavaRDD<LabeledPoint> data = JavaRDD.fromRDD(MLUtils.loadLibSVMFile(sc.sc(), datapath));
52+
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
5353

5454
// Compute the number of classes from the data.
5555
Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
5656
@Override public Double call(LabeledPoint p) {
5757
return p.label();
5858
}
5959
}).countByValue().size();
60-
// Empty categoricalFeaturesInfo indicates all features are continuous.
60+
61+
// Set parameters.
62+
// Empty categoricalFeaturesInfo indicates all features are continuous.
6163
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
6264
String impurity = "gini";
6365
Integer maxDepth = 5;
@@ -80,12 +82,11 @@ public static void main(String[] args) {
8082
return !pl._1().equals(pl._2());
8183
}
8284
}).count() / data.count();
83-
System.out.print("Training error: " + trainErr);
84-
System.out.print("Learned classification tree model:\n" + model);
85+
System.out.println("Training error: " + trainErr);
86+
System.out.println("Learned classification tree model:\n" + model);
8587

8688
// Train a DecisionTree model for regression.
8789
impurity = "variance";
88-
8990
final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data,
9091
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
9192

@@ -107,8 +108,8 @@ public static void main(String[] args) {
107108
return a + b;
108109
}
109110
}) / data.count();
110-
System.out.print("Training Mean Squared Error: " + trainMSE);
111-
System.out.print("Learned regression tree model:\n" + regressionModel);
111+
System.out.println("Training Mean Squared Error: " + trainMSE);
112+
System.out.println("Learned regression tree model:\n" + regressionModel);
112113

113114
sc.stop();
114115
}

0 commit comments

Comments
 (0)