Skip to content

Commit e358661

Browse files
committed
DecisionTree API change:
* Added 6 static train methods to match the Python API, but without default arguments (the Python default args are noted in the docs). Added factory classes for Algo and Impurity, but made them private[mllib].
1 parent c699850 commit e358661

File tree

4 files changed

+235
-48
lines changed

4 files changed

+235
-48
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
3232
import org.apache.spark.mllib.recommendation._
3333
import org.apache.spark.mllib.regression._
3434
import org.apache.spark.mllib.tree.configuration.Algo._
35-
import org.apache.spark.mllib.tree.configuration.Strategy
35+
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
3636
import org.apache.spark.mllib.tree.DecisionTree
37-
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
37+
import org.apache.spark.mllib.tree.impurity._
3838
import org.apache.spark.mllib.tree.model.DecisionTreeModel
3939
import org.apache.spark.mllib.stat.Statistics
4040
import org.apache.spark.mllib.stat.correlation.CorrelationNames
@@ -498,17 +498,8 @@ class PythonMLLibAPI extends Serializable {
498498

499499
val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
500500

501-
val algo: Algo = algoStr match {
502-
case "classification" => Classification
503-
case "regression" => Regression
504-
case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr")
505-
}
506-
val impurity: Impurity = impurityStr match {
507-
case "gini" => Gini
508-
case "entropy" => Entropy
509-
case "variance" => Variance
510-
case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr")
511-
}
501+
val algo = Algo.stringToAlgo(algoStr)
502+
val impurity = Impurities.stringToImpurity(impurityStr)
512503

513504
val strategy = new Strategy(
514505
algo = algo,

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 193 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20+
import scala.collection.JavaConverters._
21+
2022
import org.apache.spark.annotation.Experimental
2123
import org.apache.spark.Logging
2224
import org.apache.spark.mllib.regression.LabeledPoint
23-
import org.apache.spark.mllib.tree.configuration.Strategy
25+
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
2426
import org.apache.spark.mllib.tree.configuration.Algo._
2527
import org.apache.spark.mllib.tree.configuration.FeatureType._
2628
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
27-
import org.apache.spark.mllib.tree.impurity.Impurity
29+
import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
2830
import org.apache.spark.mllib.tree.model._
2931
import org.apache.spark.rdd.RDD
3032
import org.apache.spark.util.random.XORShiftRandom
@@ -213,10 +215,8 @@ object DecisionTree extends Serializable with Logging {
213215
}
214216

215217
/**
216-
* Method to train a decision tree model where the instances are represented as an RDD of
217-
* (label, features) pairs. The method supports binary classification and regression. For the
218-
* binary classification, the label for each instance should either be 0 or 1 to denote the two
219-
* classes.
218+
* Method to train a decision tree model.
219+
* The method supports binary and multiclass classification and regression.
220220
*
221221
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
222222
* For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -237,10 +237,8 @@ object DecisionTree extends Serializable with Logging {
237237
}
238238

239239
/**
240-
* Method to train a decision tree model where the instances are represented as an RDD of
241-
* (label, features) pairs. The method supports binary classification and regression. For the
242-
* binary classification, the label for each instance should either be 0 or 1 to denote the two
243-
* classes.
240+
* Method to train a decision tree model.
241+
* The method supports binary and multiclass classification and regression.
244242
*
245243
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
246244
* For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -263,11 +261,8 @@ object DecisionTree extends Serializable with Logging {
263261
}
264262

265263
/**
266-
* Method to train a decision tree model where the instances are represented as an RDD of
267-
* (label, features) pairs. The decision tree method supports binary classification and
268-
* regression. For the binary classification, the label for each instance should either be 0 or
269-
* 1 to denote the two classes. The method also supports categorical features inputs where the
270-
* number of categories can specified using the categoricalFeaturesInfo option.
264+
* Method to train a decision tree model.
265+
* The method supports binary and multiclass classification and regression.
271266
*
272267
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
273268
* For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -279,11 +274,9 @@ object DecisionTree extends Serializable with Logging {
279274
* @param numClassesForClassification number of classes for classification. Default value of 2.
280275
* @param maxBins maximum number of bins used for splitting features
281276
* @param quantileCalculationStrategy algorithm for calculating quantiles
282-
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
283-
* the number of discrete values they take. For example,
284-
* an entry (n -> k) implies the feature n is categorical with k
285-
* categories 0, 1, 2, ... , k-1. It's important to note that
286-
* features are zero-indexed.
277+
* @param categoricalFeaturesInfo Map storing arity of categorical features.
278+
* E.g., an entry (n -> k) indicates that feature n is categorical
279+
* with k categories indexed from 0: {0, 1, ..., k-1}.
287280
* @return DecisionTreeModel that can be used for prediction
288281
*/
289282
def train(
@@ -300,32 +293,197 @@ object DecisionTree extends Serializable with Logging {
300293
new DecisionTree(strategy).train(input)
301294
}
302295

303-
// Optional arguments in Python: maxBins
296+
/**
297+
* Method to train a decision tree model.
298+
* The method supports binary and multiclass classification and regression.
299+
* This version takes basic types, for consistency with Python API.
300+
*
301+
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
302+
* For classification, labels should take values {0, 1, ..., numClasses-1}.
303+
* For regression, labels are real numbers.
304+
* @param algo "classification" or "regression"
305+
* @param numClassesForClassification number of classes for classification. Default value of 2.
306+
* @param categoricalFeaturesInfo Map storing arity of categorical features.
307+
* E.g., an entry (n -> k) indicates that feature n is categorical
308+
* with k categories indexed from 0: {0, 1, ..., k-1}.
309+
* @param impurity criterion used for information gain calculation
310+
* @param maxDepth Maximum depth of the tree.
311+
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
312+
* @param maxBins maximum number of bins used for splitting features
313+
* (default Python value = 100)
314+
* @return DecisionTreeModel that can be used for prediction
315+
*/
304316
def train(
305317
input: RDD[LabeledPoint],
306-
algo: Algo,
318+
algo: String,
307319
numClassesForClassification: Int,
308-
categoricalFeaturesInfo: Map[Int,Int],
309-
impurity: Impurity,
320+
categoricalFeaturesInfo: Map[Int, Int],
321+
impurity: String,
310322
maxDepth: Int,
311-
maxBins: Int): DecisionTreeModel = ???
323+
maxBins: Int): DecisionTreeModel = {
324+
val algoType = Algo.stringToAlgo(algo)
325+
val impurityType = Impurities.stringToImpurity(impurity)
326+
train(input, algoType, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
327+
categoricalFeaturesInfo)
328+
}
312329

313-
// Optional arguments in Python: all but input, numClassesForClassification
330+
/**
331+
* Method to train a decision tree model.
332+
* The method supports binary and multiclass classification and regression.
333+
* This version takes basic types, for consistency with Python API.
334+
* This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
335+
*
336+
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
337+
* For classification, labels should take values {0, 1, ..., numClasses-1}.
338+
* For regression, labels are real numbers.
339+
* @param algo "classification" or "regression"
340+
* @param numClassesForClassification number of classes for classification. Default value of 2.
341+
* @param categoricalFeaturesInfo Map storing arity of categorical features.
342+
* E.g., an entry (n -> k) indicates that feature n is categorical
343+
* with k categories indexed from 0: {0, 1, ..., k-1}.
344+
* @param impurity criterion used for information gain calculation
345+
* @param maxDepth Maximum depth of the tree.
346+
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
347+
* @param maxBins maximum number of bins used for splitting features
348+
* (default Python value = 100)
349+
* @return DecisionTreeModel that can be used for prediction
350+
*/
351+
def train(
352+
input: RDD[LabeledPoint],
353+
algo: String,
354+
numClassesForClassification: Int,
355+
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
356+
impurity: String,
357+
maxDepth: Int,
358+
maxBins: Int): DecisionTreeModel = {
359+
train(input, algo, numClassesForClassification,
360+
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
361+
impurity, maxDepth, maxBins)
362+
}
363+
364+
/**
365+
* Method to train a decision tree model for binary or multiclass classification.
366+
* This version takes basic types, for consistency with Python API.
367+
*
368+
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
369+
* Labels should take values {0, 1, ..., numClasses-1}.
370+
* @param numClassesForClassification number of classes for classification.
371+
* @param categoricalFeaturesInfo Map storing arity of categorical features.
372+
* E.g., an entry (n -> k) indicates that feature n is categorical
373+
* with k categories indexed from 0: {0, 1, ..., k-1}.
374+
* (default Python value = {}, i.e., no categorical features)
375+
* @param impurity criterion used for information gain calculation
376+
* (default Python value = "gini")
377+
* @param maxDepth Maximum depth of the tree.
378+
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
379+
* (default Python value = 4)
380+
* @param maxBins maximum number of bins used for splitting features
381+
* (default Python value = 100)
382+
* @return DecisionTreeModel that can be used for prediction
383+
*/
314384
def trainClassifier(
315385
input: RDD[LabeledPoint],
316386
numClassesForClassification: Int,
317-
categoricalFeaturesInfo: Map[Int,Int],
318-
impurity: Impurity,
387+
categoricalFeaturesInfo: Map[Int, Int],
388+
impurity: String,
389+
maxDepth: Int,
390+
maxBins: Int): DecisionTreeModel = {
391+
train(input, "classification", numClassesForClassification, categoricalFeaturesInfo, impurity,
392+
maxDepth, maxBins)
393+
}
394+
395+
/**
396+
* Method to train a decision tree model for binary or multiclass classification.
397+
* This version takes basic types, for consistency with Python API.
398+
* This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
399+
*
400+
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
401+
* Labels should take values {0, 1, ..., numClasses-1}.
402+
* @param numClassesForClassification number of classes for classification.
403+
* @param categoricalFeaturesInfo Map storing arity of categorical features.
404+
* E.g., an entry (n -> k) indicates that feature n is categorical
405+
* with k categories indexed from 0: {0, 1, ..., k-1}.
406+
* (default Python value = {}, i.e., no categorical features)
407+
* @param impurity criterion used for information gain calculation
408+
* (default Python value = "gini")
409+
* @param maxDepth Maximum depth of the tree.
410+
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
411+
* (default Python value = 4)
412+
* @param maxBins maximum number of bins used for splitting features
413+
* (default Python value = 100)
414+
* @return DecisionTreeModel that can be used for prediction
415+
*/
416+
def trainClassifier(
417+
input: RDD[LabeledPoint],
418+
numClassesForClassification: Int,
419+
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
420+
impurity: String,
421+
maxDepth: Int,
422+
maxBins: Int): DecisionTreeModel = {
423+
trainClassifier(input, numClassesForClassification,
424+
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
425+
impurity, maxDepth, maxBins)
426+
}
427+
428+
/**
429+
* Method to train a decision tree model for regression.
430+
* This version takes basic types, for consistency with Python API.
431+
*
432+
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
433+
* Labels are real numbers.
434+
* @param categoricalFeaturesInfo Map storing arity of categorical features.
435+
* E.g., an entry (n -> k) indicates that feature n is categorical
436+
* with k categories indexed from 0: {0, 1, ..., k-1}.
437+
* (default Python value = {}, i.e., no categorical features)
438+
* @param impurity criterion used for information gain calculation
439+
* (default Python value = "variance")
440+
* @param maxDepth Maximum depth of the tree.
441+
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
442+
* (default Python value = 4)
443+
* @param maxBins maximum number of bins used for splitting features
444+
* (default Python value = 100)
445+
* @return DecisionTreeModel that can be used for prediction
446+
*/
447+
def trainRegressor(
448+
input: RDD[LabeledPoint],
449+
categoricalFeaturesInfo: Map[Int, Int],
450+
impurity: String,
319451
maxDepth: Int,
320-
maxBins: Int): DecisionTreeModel = ???
452+
maxBins: Int): DecisionTreeModel = {
453+
train(input, "regression", 0, categoricalFeaturesInfo, impurity, maxDepth, maxBins)
454+
}
321455

322-
// Optional arguments in Python: all but input
456+
/**
457+
* Method to train a decision tree model for regression.
458+
* This version takes basic types, for consistency with Python API.
459+
* This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
460+
*
461+
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
462+
* Labels are real numbers.
463+
* @param categoricalFeaturesInfo Map storing arity of categorical features.
464+
* E.g., an entry (n -> k) indicates that feature n is categorical
465+
* with k categories indexed from 0: {0, 1, ..., k-1}.
466+
* (default Python value = {}, i.e., no categorical features)
467+
* @param impurity criterion used for information gain calculation
468+
* (default Python value = "variance")
469+
* @param maxDepth Maximum depth of the tree.
470+
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
471+
* (default Python value = 4)
472+
* @param maxBins maximum number of bins used for splitting features
473+
* (default Python value = 100)
474+
* @return DecisionTreeModel that can be used for prediction
475+
*/
323476
def trainRegressor(
324477
input: RDD[LabeledPoint],
325-
categoricalFeaturesInfo: Map[Int,Int],
326-
impurity: Impurity,
478+
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
479+
impurity: String,
327480
maxDepth: Int,
328-
maxBins: Int): DecisionTreeModel = ???
481+
maxBins: Int): DecisionTreeModel = {
482+
trainRegressor(input,
483+
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
484+
impurity, maxDepth, maxBins)
485+
}
486+
329487

330488
private val InvalidBinIndex = -1
331489

@@ -1361,10 +1519,10 @@ object DecisionTree extends Serializable with Logging {
13611519
* (a) For multiclass classification with a low-arity feature
13621520
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
13631521
* the feature is split based on subsets of categories.
1364-
* There are 2^(maxFeatureValue - 1) - 1 splits.
1522+
* There are math.pow(2, maxFeatureValue - 1) - 1 splits.
13651523
* (b) For regression and binary classification,
13661524
* and for multiclass classification with a high-arity feature,
1367-
* there is one split per category.
1525+
* there is one split per category.
13681526
13691527
* Categorical case (a) features are called unordered features.
13701528
* Other cases are called ordered features.

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental
2727
object Algo extends Enumeration {
2828
type Algo = Value
2929
val Classification, Regression = Value
30+
31+
private[mllib] def stringToAlgo(name: String): Algo = name match {
32+
case "classification" => Classification
33+
case "regression" => Regression
34+
case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name")
35+
}
3036
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.tree.impurity
19+
20+
/**
21+
* Factory class for Impurity types.
22+
*/
23+
private[mllib] object Impurities {
24+
25+
def stringToImpurity(name: String): Impurity = name match {
26+
case "gini" => Gini
27+
case "entropy" => Entropy
28+
case "variance" => Variance
29+
case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name")
30+
}
31+
32+
}

0 commit comments

Comments
 (0)