@@ -162,47 +162,64 @@ val labelAndPreds = data.map { point =>
162162}
163163val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / data.count
164164println("Training Error = " + trainErr)
165+ println("Learned classification tree model:\n" + model)
165166{% endhighlight %}
166167</div >
167168
168169<div data-lang="java">
169170{% highlight java %}
171+ import java.util.HashMap;
170172import scala.Tuple2;
173+ import org.apache.spark.api.java.function.Function2;
171174import org.apache.spark.api.java.JavaPairRDD;
172175import org.apache.spark.api.java.JavaRDD;
176+ import org.apache.spark.api.java.JavaSparkContext;
173177import org.apache.spark.api.java.function.Function;
174178import org.apache.spark.api.java.function.PairFunction;
175179import org.apache.spark.mllib.regression.LabeledPoint;
176180import org.apache.spark.mllib.tree.DecisionTree;
177181import org.apache.spark.mllib.tree.model.DecisionTreeModel;
178-
179- JavaRDD<LabeledPoint > data = ... // data set
180-
181- // Train a DecisionTree model.
182+ import org.apache.spark.mllib.util.MLUtils;
183+ import org.apache.spark.SparkConf;
184+
185+ SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
186+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
187+
188+ String datapath = "data/mllib/sample_libsvm_data.txt";
189+ JavaRDD<LabeledPoint > data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
190+ // Compute the number of classes from the data.
191+ Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
192+ @Override public Double call(LabeledPoint p) {
193+ return p.label();
194+ }
195+ }).countByValue().size();
196+
197+ // Set parameters.
182198// Empty categoricalFeaturesInfo indicates all features are continuous.
183- Integer numClasses = ... // number of classes
184199HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
185200String impurity = "gini";
186201Integer maxDepth = 5;
187202Integer maxBins = 100;
188203
204+ // Train a DecisionTree model for classification.
189205final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
190206 categoricalFeaturesInfo, impurity, maxDepth, maxBins);
191207
192208// Evaluate model on training instances and compute training error
193- JavaPairRDD<Double, Double> predictionAndLabel =
209+ JavaPairRDD<Double, Double> predictionAndLabel =
194210 data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
195211 @Override public Tuple2<Double, Double> call(LabeledPoint p) {
196212 return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
197213 }
198214 });
199- Double trainErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
215+ Double trainErr =
216+ 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
200217 @Override public Boolean call(Tuple2<Double, Double> pl) {
201218 return !pl._1().equals(pl._2());
202219 }
203220 }).count() / data.count();
204- System.out.print ("Training error: " + trainErr);
205- System.out.print ("Learned model:\n" + model);
221+ System.out.println ("Training error: " + trainErr);
222+ System.out.println ("Learned classification tree model:\n" + model);
206223{% endhighlight %}
207224</div >
208225
@@ -225,6 +242,8 @@ predictions = model.predict(data.map(lambda x: x.features))
225242labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
226243trainErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(data.count())
227244print('Training Error = ' + str(trainErr))
245+ print('Learned classification tree model:')
246+ print(model)
228247{% endhighlight %}
229248
230249Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
@@ -268,47 +287,63 @@ val labelsAndPredictions = data.map { point =>
268287}
269288val trainMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
270289println("Training Mean Squared Error = " + trainMSE)
290+ println("Learned regression tree model:\n" + model)
271291{% endhighlight %}
272292</div >
273293
274294<div data-lang="java">
275295{% highlight java %}
296+ import java.util.HashMap;
297+ import scala.Tuple2;
298+ import org.apache.spark.api.java.function.Function2;
276299import org.apache.spark.api.java.JavaPairRDD;
277300import org.apache.spark.api.java.JavaRDD;
301+ import org.apache.spark.api.java.JavaSparkContext;
278302import org.apache.spark.api.java.function.Function;
279303import org.apache.spark.api.java.function.PairFunction;
280304import org.apache.spark.mllib.regression.LabeledPoint;
281305import org.apache.spark.mllib.tree.DecisionTree;
282306import org.apache.spark.mllib.tree.model.DecisionTreeModel;
283- import scala.Tuple2;
307+ import org.apache.spark.mllib.util.MLUtils;
308+ import org.apache.spark.SparkConf;
284309
285- JavaRDD<LabeledPoint > data = ... // data set
310+ SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
311+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
286312
287- // Train a DecisionTree model.
313+ String datapath = "data/mllib/sample_libsvm_data.txt";
314+ JavaRDD<LabeledPoint > data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
315+
316+ // Set parameters.
288317// Empty categoricalFeaturesInfo indicates all features are continuous.
289318HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
290319String impurity = "variance";
291320Integer maxDepth = 5;
292321Integer maxBins = 100;
293322
323+ // Train a DecisionTree model for regression.
294324final DecisionTreeModel model = DecisionTree.trainRegressor(data,
295325 categoricalFeaturesInfo, impurity, maxDepth, maxBins);
296326
297327// Evaluate model on training instances and compute training error
298- JavaPairRDD<Double, Double> predictionAndLabel =
328+ JavaPairRDD<Double, Double> predictionAndLabel =
299329 data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
300330 @Override public Tuple2<Double, Double> call(LabeledPoint p) {
301331 return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
302332 }
303333 });
304- Double trainMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
334+ Double trainMSE =
335+ predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
305336 @Override public Double call(Tuple2<Double, Double> pl) {
306- Double diff = pl._1() - pl._2();
337+ Double diff = pl._1() - pl._2();
307338 return diff * diff;
308339 }
309- }).sum() / data.count();
310- System.out.print("Training Mean Squared Error: " + trainMSE);
311- System.out.print("Learned model:\n" + model);
340+ }).reduce(new Function2<Double, Double, Double>() {
341+ @Override public Double call(Double a, Double b) {
342+ return a + b;
343+ }
344+ }) / data.count();
345+ System.out.println("Training Mean Squared Error: " + trainMSE);
346+ System.out.println("Learned regression tree model:\n" + model);
312347{% endhighlight %}
313348</div >
314349
@@ -331,6 +366,8 @@ predictions = model.predict(data.map(lambda x: x.features))
331366labelsAndPredictions = data.map(lambda lp: lp.label).zip(predictions)
332367trainMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(data.count())
333368print('Training Mean Squared Error = ' + str(trainMSE))
369+ print('Learned regression tree model:')
370+ print(model)
334371{% endhighlight %}
335372
336373Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
0 commit comments