1717
1818package org .apache .spark .mllib .tree
1919
20+ import org .apache .spark .api .java .JavaRDD
21+
22+ import scala .collection .JavaConverters ._
23+
2024import org .apache .spark .annotation .Experimental
2125import org .apache .spark .Logging
2226import org .apache .spark .mllib .regression .LabeledPoint
23- import org .apache .spark .mllib .tree .configuration .Strategy
27+ import org .apache .spark .mllib .tree .configuration .{ Algo , Strategy }
2428import org .apache .spark .mllib .tree .configuration .Algo ._
2529import org .apache .spark .mllib .tree .configuration .FeatureType ._
2630import org .apache .spark .mllib .tree .configuration .QuantileStrategy ._
27- import org .apache .spark .mllib .tree .impurity .Impurity
31+ import org .apache .spark .mllib .tree .impurity .{ Impurities , Gini , Entropy , Impurity }
2832import org .apache .spark .mllib .tree .model ._
2933import org .apache .spark .rdd .RDD
3034import org .apache .spark .util .random .XORShiftRandom
@@ -200,6 +204,10 @@ object DecisionTree extends Serializable with Logging {
200204 * Method to train a decision tree model.
201205 * The method supports binary and multiclass classification and regression.
202206 *
207+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier ]]
208+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor ]]
209+ * is recommended to clearly separate classification and regression.
210+ *
203211 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
204212 * For classification, labels should take values {0, 1, ..., numClasses-1}.
205213 * For regression, labels are real numbers.
@@ -213,10 +221,12 @@ object DecisionTree extends Serializable with Logging {
213221 }
214222
215223 /**
216- * Method to train a decision tree model where the instances are represented as an RDD of
217- * (label, features) pairs. The method supports binary classification and regression. For the
218- * binary classification, the label for each instance should either be 0 or 1 to denote the two
219- * classes.
224+ * Method to train a decision tree model.
225+ * The method supports binary and multiclass classification and regression.
226+ *
227+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier ]]
228+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor ]]
229+ * is recommended to clearly separate classification and regression.
220230 *
221231 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
222232 * For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -237,10 +247,12 @@ object DecisionTree extends Serializable with Logging {
237247 }
238248
239249 /**
240- * Method to train a decision tree model where the instances are represented as an RDD of
241- * (label, features) pairs. The method supports binary classification and regression. For the
242- * binary classification, the label for each instance should either be 0 or 1 to denote the two
243- * classes.
250+ * Method to train a decision tree model.
251+ * The method supports binary and multiclass classification and regression.
252+ *
253+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier ]]
254+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor ]]
255+ * is recommended to clearly separate classification and regression.
244256 *
245257 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
246258 * For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -263,11 +275,12 @@ object DecisionTree extends Serializable with Logging {
263275 }
264276
265277 /**
266- * Method to train a decision tree model where the instances are represented as an RDD of
267- * (label, features) pairs. The decision tree method supports binary classification and
268- * regression. For the binary classification, the label for each instance should either be 0 or
269- * 1 to denote the two classes. The method also supports categorical features inputs where the
270- * number of categories can specified using the categoricalFeaturesInfo option.
278+ * Method to train a decision tree model.
279+ * The method supports binary and multiclass classification and regression.
280+ *
281+ * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier ]]
282+ * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor ]]
283+ * is recommended to clearly separate classification and regression.
271284 *
272285 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
273286 * For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -279,11 +292,9 @@ object DecisionTree extends Serializable with Logging {
279292 * @param numClassesForClassification number of classes for classification. Default value of 2.
280293 * @param maxBins maximum number of bins used for splitting features
281294 * @param quantileCalculationStrategy algorithm for calculating quantiles
282- * @param categoricalFeaturesInfo A map storing information about the categorical variables and
283- * the number of discrete values they take. For example,
284- * an entry (n -> k) implies the feature n is categorical with k
285- * categories 0, 1, 2, ... , k-1. It's important to note that
286- * features are zero-indexed.
295+ * @param categoricalFeaturesInfo Map storing arity of categorical features.
296+ * E.g., an entry (n -> k) indicates that feature n is categorical
297+ * with k categories indexed from 0: {0, 1, ..., k-1}.
287298 * @return DecisionTreeModel that can be used for prediction
288299 */
289300 def train (
@@ -300,6 +311,93 @@ object DecisionTree extends Serializable with Logging {
300311 new DecisionTree (strategy).train(input)
301312 }
302313
314+ /**
315+ * Method to train a decision tree model for binary or multiclass classification.
316+ *
317+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
318+ * Labels should take values {0, 1, ..., numClasses-1}.
319+ * @param numClassesForClassification number of classes for classification.
320+ * @param categoricalFeaturesInfo Map storing arity of categorical features.
321+ * E.g., an entry (n -> k) indicates that feature n is categorical
322+ * with k categories indexed from 0: {0, 1, ..., k-1}.
323+ * @param impurity Criterion used for information gain calculation.
324+ * Supported values: "gini" (recommended) or "entropy".
325+ * @param maxDepth Maximum depth of the tree.
326+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
327+ * (suggested value: 4)
328+ * @param maxBins maximum number of bins used for splitting features
329+ * (suggested value: 100)
330+ * @return DecisionTreeModel that can be used for prediction
331+ */
332+ def trainClassifier (
333+ input : RDD [LabeledPoint ],
334+ numClassesForClassification : Int ,
335+ categoricalFeaturesInfo : Map [Int , Int ],
336+ impurity : String ,
337+ maxDepth : Int ,
338+ maxBins : Int ): DecisionTreeModel = {
339+ val impurityType = Impurities .fromString(impurity)
340+ train(input, Classification , impurityType, maxDepth, numClassesForClassification, maxBins, Sort ,
341+ categoricalFeaturesInfo)
342+ }
343+
344+ /**
345+ * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier ]]
346+ */
347+ def trainClassifier (
348+ input : JavaRDD [LabeledPoint ],
349+ numClassesForClassification : Int ,
350+ categoricalFeaturesInfo : java.util.Map [java.lang.Integer , java.lang.Integer ],
351+ impurity : String ,
352+ maxDepth : Int ,
353+ maxBins : Int ): DecisionTreeModel = {
354+ trainClassifier(input.rdd, numClassesForClassification,
355+ categoricalFeaturesInfo.asInstanceOf [java.util.Map [Int , Int ]].asScala.toMap,
356+ impurity, maxDepth, maxBins)
357+ }
358+
359+ /**
360+ * Method to train a decision tree model for regression.
361+ *
362+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
363+ * Labels are real numbers.
364+ * @param categoricalFeaturesInfo Map storing arity of categorical features.
365+ * E.g., an entry (n -> k) indicates that feature n is categorical
366+ * with k categories indexed from 0: {0, 1, ..., k-1}.
367+ * @param impurity Criterion used for information gain calculation.
368+ * Supported values: "variance".
369+ * @param maxDepth Maximum depth of the tree.
370+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
371+ * (suggested value: 4)
372+ * @param maxBins maximum number of bins used for splitting features
373+ * (suggested value: 100)
374+ * @return DecisionTreeModel that can be used for prediction
375+ */
376+ def trainRegressor (
377+ input : RDD [LabeledPoint ],
378+ categoricalFeaturesInfo : Map [Int , Int ],
379+ impurity : String ,
380+ maxDepth : Int ,
381+ maxBins : Int ): DecisionTreeModel = {
382+ val impurityType = Impurities .fromString(impurity)
383+ train(input, Regression , impurityType, maxDepth, 0 , maxBins, Sort , categoricalFeaturesInfo)
384+ }
385+
386+ /**
387+ * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor ]]
388+ */
389+ def trainRegressor (
390+ input : JavaRDD [LabeledPoint ],
391+ categoricalFeaturesInfo : java.util.Map [java.lang.Integer , java.lang.Integer ],
392+ impurity : String ,
393+ maxDepth : Int ,
394+ maxBins : Int ): DecisionTreeModel = {
395+ trainRegressor(input.rdd,
396+ categoricalFeaturesInfo.asInstanceOf [java.util.Map [Int , Int ]].asScala.toMap,
397+ impurity, maxDepth, maxBins)
398+ }
399+
400+
303401 private val InvalidBinIndex = - 1
304402
305403 /**
@@ -1331,16 +1429,15 @@ object DecisionTree extends Serializable with Logging {
13311429 * Categorical features:
13321430 * For each feature, there is 1 bin per split.
13331431 * Splits and bins are handled in 2 ways:
1334- * (a) For multiclass classification with a low-arity feature
1432+ * (a) "unordered features"
1433+ * For multiclass classification with a low-arity feature
13351434 * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
13361435 * the feature is split based on subsets of categories.
1337- * There are 2^(maxFeatureValue - 1) - 1 splits.
1338- * (b) For regression and binary classification,
1436+ * There are math.pow(2, maxFeatureValue - 1) - 1 splits.
1437+ * (b) "ordered features"
1438+ * For regression and binary classification,
13391439 * and for multiclass classification with a high-arity feature,
1340- * there is one split per category.
1341-
1342- * Categorical case (a) features are called unordered features.
1343- * Other cases are called ordered features.
1440+ * there is one bin per category.
13441441 *
13451442 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]]
13461443 * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy ]] instance containing
0 commit comments