
package org.apache.spark.mllib.tree

+import scala.collection.JavaConverters._
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
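
The new JavaConverters import supports the Java-friendly overloads introduced below, which turn a java.util.Map of boxed Integers into an immutable Scala Map before delegating to the Scala API. A minimal sketch of that conversion (the map contents are illustrative, not from the commit):

import scala.collection.JavaConverters._

// Illustrative: convert a Java map of boxed Integers into an immutable Scala Map[Int, Int],
// mirroring the conversion used by the Java-friendly train/trainClassifier/trainRegressor overloads.
val javaInfo = new java.util.HashMap[java.lang.Integer, java.lang.Integer]()
javaInfo.put(0, 4)  // hypothetical: feature 0 has 4 categories
val scalaInfo: Map[Int, Int] =
  javaInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap
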
@@ -213,10 +215,8 @@ object DecisionTree extends Serializable with Logging {
  }

  /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The method supports binary classification and regression. For the
-   * binary classification, the label for each instance should either be 0 or 1 to denote the two
-   * classes.
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
   *
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -237,10 +237,8 @@ object DecisionTree extends Serializable with Logging {
  }

  /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The method supports binary classification and regression. For the
-   * binary classification, the label for each instance should either be 0 or 1 to denote the two
-   * classes.
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
   *
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -263,11 +261,8 @@ object DecisionTree extends Serializable with Logging {
  }

  /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The decision tree method supports binary classification and
-   * regression. For the binary classification, the label for each instance should either be 0 or
-   * 1 to denote the two classes. The method also supports categorical features inputs where the
-   * number of categories can specified using the categoricalFeaturesInfo option.
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
   *
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -279,11 +274,9 @@ object DecisionTree extends Serializable with Logging {
   * @param numClassesForClassification number of classes for classification. Default value of 2.
   * @param maxBins maximum number of bins used for splitting features
   * @param quantileCalculationStrategy algorithm for calculating quantiles
-   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
-   *                                the number of discrete values they take. For example,
-   *                                an entry (n -> k) implies the feature n is categorical with k
-   *                                categories 0, 1, 2, ... , k-1. It's important to note that
-   *                                features are zero-indexed.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
   * @return DecisionTreeModel that can be used for prediction
   */
  def train(
@@ -300,32 +293,197 @@ object DecisionTree extends Serializable with Logging {
    new DecisionTree(strategy).train(input)
  }

-  // Optional arguments in Python: maxBins
+  /**
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
+   * This version takes basic types, for consistency with Python API.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
+   *              For regression, labels are real numbers.
+   * @param algo "classification" or "regression"
+   * @param numClassesForClassification number of classes for classification. Default value of 2.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param impurity criterion used for information gain calculation
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   * @param maxBins maximum number of bins used for splitting features
+   *                (default Python value = 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
  def train(
      input: RDD[LabeledPoint],
-      algo: Algo,
+      algo: String,
      numClassesForClassification: Int,
-      categoricalFeaturesInfo: Map[Int,Int],
-      impurity: Impurity,
+      categoricalFeaturesInfo: Map[Int, Int],
+      impurity: String,
      maxDepth: Int,
-      maxBins: Int): DecisionTreeModel = ???
+      maxBins: Int): DecisionTreeModel = {
+    val algoType = Algo.stringToAlgo(algo)
+    val impurityType = Impurities.stringToImpurity(impurity)
+    train(input, algoType, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
+      categoricalFeaturesInfo)
+  }

-  // Optional arguments in Python: all but input, numClassesForClassification
+  /**
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
+   * This version takes basic types, for consistency with Python API.
+   * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
+   *              For regression, labels are real numbers.
+   * @param algo "classification" or "regression"
+   * @param numClassesForClassification number of classes for classification. Default value of 2.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param impurity criterion used for information gain calculation
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   * @param maxBins maximum number of bins used for splitting features
+   *                (default Python value = 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      algo: String,
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    train(input, algo, numClassesForClassification,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      impurity, maxDepth, maxBins)
+  }
+
+  /**
+   * Method to train a decision tree model for binary or multiclass classification.
+   * This version takes basic types, for consistency with Python API.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels should take values {0, 1, ..., numClasses-1}.
+   * @param numClassesForClassification number of classes for classification.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   *                                (default Python value = {}, i.e., no categorical features)
+   * @param impurity criterion used for information gain calculation
+   *                 (default Python value = "gini")
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (default Python value = 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                (default Python value = 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
  def trainClassifier(
      input: RDD[LabeledPoint],
      numClassesForClassification: Int,
-      categoricalFeaturesInfo: Map[Int,Int],
-      impurity: Impurity,
+      categoricalFeaturesInfo: Map[Int, Int],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    train(input, "classification", numClassesForClassification, categoricalFeaturesInfo, impurity,
+      maxDepth, maxBins)
+  }
+
+  /**
+   * Method to train a decision tree model for binary or multiclass classification.
+   * This version takes basic types, for consistency with Python API.
+   * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels should take values {0, 1, ..., numClasses-1}.
+   * @param numClassesForClassification number of classes for classification.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   *                                (default Python value = {}, i.e., no categorical features)
+   * @param impurity criterion used for information gain calculation
+   *                 (default Python value = "gini")
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (default Python value = 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                (default Python value = 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    trainClassifier(input, numClassesForClassification,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      impurity, maxDepth, maxBins)
+  }
+
+  /**
+   * Method to train a decision tree model for regression.
+   * This version takes basic types, for consistency with Python API.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real numbers.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   *                                (default Python value = {}, i.e., no categorical features)
+   * @param impurity criterion used for information gain calculation
+   *                 (default Python value = "variance")
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (default Python value = 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                (default Python value = 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
+  def trainRegressor(
+      input: RDD[LabeledPoint],
+      categoricalFeaturesInfo: Map[Int, Int],
+      impurity: String,
      maxDepth: Int,
-      maxBins: Int): DecisionTreeModel = ???
+      maxBins: Int): DecisionTreeModel = {
+    train(input, "regression", 0, categoricalFeaturesInfo, impurity, maxDepth, maxBins)
+  }

-  // Optional arguments in Python: all but input
+  /**
+   * Method to train a decision tree model for regression.
+   * This version takes basic types, for consistency with Python API.
+   * This version is Java-friendly, taking a Java map for categoricalFeaturesInfo.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real numbers.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   *                                (default Python value = {}, i.e., no categorical features)
+   * @param impurity criterion used for information gain calculation
+   *                 (default Python value = "variance")
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (default Python value = 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                (default Python value = 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
  def trainRegressor(
      input: RDD[LabeledPoint],
-      categoricalFeaturesInfo: Map[Int, Int],
-      impurity: Impurity,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      impurity: String,
      maxDepth: Int,
-      maxBins: Int): DecisionTreeModel = ???
+      maxBins: Int): DecisionTreeModel = {
+    trainRegressor(input,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      impurity, maxDepth, maxBins)
+  }
+

  private val InvalidBinIndex = -1

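The new string-based overloads above are intended to mirror the Python API. As a rough usage sketch (not part of this commit; the SparkContext `sc`, the data path, and the feature arities are assumed for illustration), they would be called like this:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Hypothetical dataset: labels in {0, 1, 2}, feature 0 categorical with 4 values.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_multiclass_data.txt")

// Classification via the new basic-type overload: algo and impurity are plain strings.
val classifier = DecisionTree.trainClassifier(data, numClassesForClassification = 3,
  categoricalFeaturesInfo = Map(0 -> 4), impurity = "gini", maxDepth = 4, maxBins = 100)

// Regression variant: impurity is "variance"; numClassesForClassification is fixed to 0 internally.
val regressor = DecisionTree.trainRegressor(data, categoricalFeaturesInfo = Map[Int, Int](),
  impurity = "variance", maxDepth = 4, maxBins = 100)
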
@@ -1361,10 +1519,10 @@ object DecisionTree extends Serializable with Logging {
   *   (a) For multiclass classification with a low-arity feature
   *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
   *       the feature is split based on subsets of categories.
-   *       There are 2^(maxFeatureValue - 1) - 1 splits.
+   *       There are math.pow(2, maxFeatureValue - 1) - 1 splits.
   *   (b) For regression and binary classification,
   *       and for multiclass classification with a high-arity feature,
-   *       there is one split per category.
+   *

   * Categorical case (a) features are called unordered features.
   * Other cases are called ordered features.
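
As a quick sanity check of the split-count formula for unordered categorical features (an illustrative snippet, not part of the commit): a feature with maxFeatureValue = 3 categories {0, 1, 2} admits math.pow(2, 3 - 1) - 1 = 3 binary partitions.

// Illustrative check of the unordered-split count, assuming maxFeatureValue = 3.
val maxFeatureValue = 3
val numUnorderedSplits = math.pow(2, maxFeatureValue - 1).toInt - 1
assert(numUnorderedSplits == 3)  // {0}|{1,2}, {1}|{0,2}, {2}|{0,1}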