diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 9cbd880897578..9b8bf7063b0fe 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -116,12 +116,10 @@ maximum tree depth of 5. The training error is calculated to measure the algorit
{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.rdd.DatasetInfo +import org.apache.spark.mllib.tree.DecisionTreeClassifier import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Gini // Load and parse the data file val data = sc.textFile("data/mllib/sample_tree_data.csv") @@ -129,10 +127,17 @@ val parsedData = data.map { line => val parts = line.split(',').map(_.toDouble) LabeledPoint(parts(0), Vectors.dense(parts.tail)) } +val numFeatures = parsedData.take(1)(0).features.size +val datasetInfo = new DatasetInfo(numClasses = 2, numFeatures = numFeatures) // Run training algorithm to build the model -val maxDepth = 5 -val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth) +val dtParams = DecisionTreeClassifier.defaultParams() +dtParams.impurity = "gini" +dtParams.maxDepth = 4 +val model = DecisionTreeClassifier.train(parsedData, datasetInfo, dtParams) + +// Print model in human-readable format. +model.print() // Evaluate model on training examples and compute training error val labelAndPreds = parsedData.map { point => @@ -155,12 +160,10 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.rdd.DatasetInfo +import org.apache.spark.mllib.tree.DecisionTreeRegressor import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Variance // Load and parse the data file val data = sc.textFile("data/mllib/sample_tree_data.csv") @@ -168,10 +171,17 @@ val parsedData = data.map { line => val parts = line.split(',').map(_.toDouble) LabeledPoint(parts(0), Vectors.dense(parts.tail)) } +val numFeatures = parsedData.take(1)(0).features.size +val datasetInfo = new DatasetInfo(numClasses = 0, numFeatures = numFeatures) // Run training algorithm to build the model -val maxDepth = 5 -val model = DecisionTree.train(parsedData, Regression, Variance, maxDepth) +val dtParams = DecisionTreeRegressor.defaultParams() +dtParams.impurity = "variance" +dtParams.maxDepth = 4 +val model = DecisionTreeRegressor.train(parsedData, datasetInfo, dtParams) + +// Print model in human-readable format. +model.print() // Evaluate model on training examples and compute training error val valuesAndPreds = parsedData.map { point => @@ -179,7 +189,7 @@ val valuesAndPreds = parsedData.map { point => (point.label, prediction) } val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("training Mean Squared Error = " + MSE) +println("Training Mean Squared Error = " + MSE) {% endhighlight %}
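The two walkthroughs above report error on the training set only. As a hedged addition (not part of this patch), the same `DecisionTreeClassifier` API can also be scored on a held-out split; the sketch below reuses `parsedData`, `datasetInfo`, and `dtParams` from the classification example and introduces only the hypothetical names `trainSet`, `testSet`, and `heldOutModel`.

{% highlight scala %}
// Hold out 20% of the data for evaluation (sketch; assumes only the API shown above).
val Array(trainSet, testSet) = parsedData.randomSplit(Array(0.8, 0.2))
val heldOutModel = DecisionTreeClassifier.train(trainSet, datasetInfo, dtParams)

// Score the held-out split the same way the example scores the training set.
val testLabelAndPreds = testSet.map { point =>
  (point.label, heldOutModel.predict(point.features))
}
val testErr = testLabelAndPreds.filter(r => r._1 != r._2).count.toDouble / testSet.count
println("Test Error = " + testErr)
{% endhighlight %}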
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 43f13fe24f0d0..62d5e4b032887 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -21,11 +21,10 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree, impurity} -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} -import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.{DecisionTreeClassifier, DecisionTreeRegressor} +import org.apache.spark.mllib.tree.configuration.{DTClassifierParams, DTRegressorParams} import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -36,23 +35,24 @@ import org.apache.spark.rdd.RDD * ./bin/spark-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify + * [[org.apache.spark.mllib.rdd.DatasetInfo.categoricalFeaturesInfo]]. */ object DecisionTreeRunner { - object ImpurityType extends Enumeration { - type ImpurityType = Value - val Gini, Entropy, Variance = Value - } - - import ImpurityType._ - case class Params( input: String = null, - algo: Algo = Classification, - numClassesForClassification: Int = 2, - maxDepth: Int = 5, - impurity: ImpurityType = Gini, - maxBins: Int = 100) + dataFormat: String = null, + algo: String = "classification", + impurity: Option[String] = None, + maxDepth: Int = 4, + maxBins: Int = 100, + fracTest: Double = 0.2) + + private val defaultCImpurity = new DTClassifierParams().impurity + private val defaultRImpurity = new DTRegressorParams().impurity def main(args: Array[String]) { val defaultParams = Params() @@ -60,35 +60,41 @@ object DecisionTreeRunner { val parser = new OptionParser[Params]("DecisionTreeRunner") { head("DecisionTreeRunner: an example decision tree app.") opt[String]("algo") - .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") - .action((x, c) => c.copy(algo = Algo.withName(x))) + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) opt[String]("impurity") - .text(s"impurity type (${ImpurityType.values.mkString(",")}), " + - s"default: ${defaultParams.impurity}") - .action((x, c) => c.copy(impurity = ImpurityType.withName(x))) + .text( + s"impurity type\n" + + s"\tFor classification: ${DTClassifierParams.supportedImpurities.mkString(",")}\n" + + s"\t default: $defaultCImpurity" + + s"\tFor regression: ${DTRegressorParams.supportedImpurities.mkString(",")}\n" + + s"\t default: $defaultRImpurity") + .action((x, c) => c.copy(impurity = Some(x))) opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) - opt[Int]("numClassesForClassification") - .text(s"number of classes for classification, " - + s"default: 
${defaultParams.numClassesForClassification}") - .action((x, c) => c.copy(numClassesForClassification = x)) opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) arg[String]("<input>") - .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)") + .text("input paths to labeled examples") .required() .action((x, c) => c.copy(input = x)) + arg[String]("<dataFormat>") + .text("data format: dense/libsvm") + .required() + .action((x, c) => c.copy(dataFormat = x)) checkConfig { params => - if (params.algo == Classification && - (params.impurity == Gini || params.impurity == Entropy)) { - success - } else if (params.algo == Regression && params.impurity == Variance) { - success - } else { - failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") + if (!List("classification", "regression").contains(params.algo)) { + failure(s"Did not recognize Algo: ${params.algo}") + } else if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success } } } @@ -104,42 +110,92 @@ object DecisionTreeRunner { val sc = new SparkContext(conf) // Load training data and cache it. - val examples = MLUtils.loadLabeledPoints(sc, params.input).cache() + val origExamples = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache() + } + // For classification, re-index classes if needed. + val (examples, numClasses) = params.algo match { + case "classification" => { + // classCounts: class --> # examples in class + val classCounts = origExamples.map(_.label).countByValue + val numClasses = classCounts.size + // classIndex: class --> index in 0,...,numClasses-1 + val classIndex = { + if (classCounts.keySet != Set[Double](0.0, 1.0)) { + classCounts.keys.toList.sorted.zipWithIndex.toMap + } else { + Map[Double, Int]() + } + } + val examples = { + if (classIndex.isEmpty) { + origExamples + } else { + origExamples.map(lp => LabeledPoint(classIndex(lp.label), lp.features)) + } + } + println(s"numClasses = $numClasses.") + println(s"Per-class example fractions, counts:") + println(s"Class\tFrac\tCount") + classCounts.keys.toList.sorted.foreach(c => { + val frac = classCounts(c) / (0.0 + examples.count()) + println(s"$c\t$frac\t${classCounts(c)}") + }) + (examples, numClasses) + } + case "regression" => { + (origExamples, 0) + } + case _ => { + throw new IllegalArgumentException(s"Algo ${params.algo} not supported.") + } + } + // Split into training, test.
- val splits = examples.randomSplit(Array(0.8, 0.2)) + val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) val training = splits(0).cache() val test = splits(1).cache() - val numTraining = training.count() val numTest = test.count() - println(s"numTraining = $numTraining, numTest = $numTest.") + println(s"numTraining = $numTraining, numTest = $numTest") examples.unpersist(blocking = false) - val impurityCalculator = params.impurity match { - case Gini => impurity.Gini - case Entropy => impurity.Entropy - case Variance => impurity.Variance - } - - val strategy - = new Strategy( - algo = params.algo, - impurity = impurityCalculator, - maxDepth = params.maxDepth, - maxBins = params.maxBins, - numClassesForClassification = params.numClassesForClassification) - val model = DecisionTree.train(training, strategy) - - if (params.algo == Classification) { - val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy.") - } + val numFeatures = examples.take(1)(0).features.size + val datasetInfo = new DatasetInfo(numClasses, numFeatures) - if (params.algo == Regression) { - val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse.") + params.algo match { + case "classification" => { + val dtParams = DecisionTreeClassifier.defaultParams() + dtParams.maxDepth = params.maxDepth + dtParams.maxBins = params.maxBins + dtParams.impurity = params.impurity.getOrElse(defaultCImpurity) + val dtLearner = new DecisionTreeClassifier(dtParams) + val model = dtLearner.run(training, datasetInfo) + println(model.toString) + val accuracy = accuracyScore(model, test) + println(s"Test accuracy = $accuracy") + } + case "regression" => { + val dtParams = DecisionTreeRegressor.defaultParams() + dtParams.maxDepth = params.maxDepth + dtParams.maxBins = params.maxBins + dtParams.impurity = params.impurity.getOrElse(defaultRImpurity) + val dtLearner = new DecisionTreeRegressor(dtParams) + val model = dtLearner.run(training, datasetInfo) + println(model.toString) + val mse = meanSquaredError(model, test) + println(s"Test mean squared error = $mse") + } + case _ => { + throw new IllegalArgumentException(s"Algo ${params.algo} not supported.") + } } sc.stop() } @@ -159,9 +215,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + private def meanSquaredError( + model: DecisionTreeModel, + data: RDD[LabeledPoint]): Double = { data.map { y => - val err = tree.predict(y.features) - y.label + val err = model.predict(y.features) - y.label err * err }.mean() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetInfo.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetInfo.scala new file mode 100644 index 0000000000000..b6643b7aae6ea --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetInfo.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +/** + * :: Experimental :: + * A class for holding dataset metadata. + * @param numClasses Number of classes for classification. Values of 0 or 1 indicate regression. + * @param numFeatures Number of features. + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ +class DatasetInfo ( + val numClasses: Int, + val numFeatures: Int, + val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) + extends Serializable { + + /** + * Indicates if this dataset's label is real-valued (numClasses < 2). + */ + def isRegression: Boolean = { + numClasses < 2 + } + + /** + * Indicates if this dataset's label is categorical (numClasses >= 2). + */ + def isClassification: Boolean = { + numClasses >= 2 + } + + /** + * Indicates if this dataset's label is categorical with >2 categories. + */ + def isMulticlass: Boolean = { + numClasses > 2 + } + + /** + * Indicates if this dataset's label is categorical with >2 categories, + * and there is at least one categorical feature. + */ + def isMulticlassWithCategoricalFeatures: Boolean = { + isMulticlass && categoricalFeaturesInfo.nonEmpty + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ad32e3f4560fe..c193339591025 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,48 +19,60 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Experimental import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.DTParams import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.configuration.QuantileStrategies +import org.apache.spark.mllib.tree.configuration.QuantileStrategy import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom + /** * :: Experimental :: - * A class that implements a decision tree algorithm for classification and regression. It - * supports both continuous and categorical features. - * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, - * categorical), depth of the tree, quantile calculation strategy, etc. + * An abstract class for decision tree algorithms for classification and regression. 
+ * It supports both continuous and categorical features. + * @param params The configuration parameters for the tree algorithm. */ @Experimental -class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { +private[mllib] abstract class DecisionTree (params: DTParams) + extends Serializable with Logging { - /** + protected final val InvalidBinIndex = -1 + + // depth of the decision tree + protected val maxDepth: Int = params.maxDepth + + protected val maxBins: Int = params.maxBins + + protected val quantileStrategy: QuantileStrategy.QuantileStrategy = + QuantileStrategies.strategy(params.quantileStrategy) + + protected val maxMemoryInMB: Int = params.maxMemoryInMB + + /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @return a DecisionTreeModel that can be used for prediction + * @param datasetInfo Dataset metadata. + * @return top node of a DecisionTreeModel */ - def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + protected def runSub( + input: RDD[LabeledPoint], + datasetInfo: DatasetInfo): Node = { // Cache input RDD for speedup during multiple passes. input.cache() - logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins) = findSplitsBins(input, datasetInfo) val numBins = bins(0).length logDebug("numBins = " + numBins) - // depth of the decision tree - val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -69,25 +81,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) - // num features - val numFeatures = input.take(1)(0).features.size // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 + val maxMemoryUsage = maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins, - strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures, - strategy.algo) - + val numElementsPerNode = getElementsPerNode(datasetInfo, numBins) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) // nodes at a level is 2^level. level is zero indexed. 
val maxLevelForSingleGroup = math.max( - (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) + (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, + 0) logDebug("max level for single group = " + maxLevelForSingleGroup) /* @@ -100,15 +108,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var level = 0 var break = false - while (level < maxDepth && !break) { + while (level <= maxDepth && !break) { logDebug("#####################################") logDebug("level = " + level) logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup) + val splitsStatsForLevel = findBestSplits(input, datasetInfo, parentImpurities, + level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. @@ -138,13 +146,92 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) - new DecisionTreeModel(topNode, strategy.algo) + topNode } + /** + * For a given categorical feature, use a subsample of the data + * to choose how to arrange possible splits. + * This examines each category and computes a centroid. + * These centroids are later used to sort the possible splits. + * @return Mapping: category (for the given feature) --> centroid + */ + protected def computeCentroidForCategories( + featureIndex: Int, + sampledInput: Array[LabeledPoint], + datasetInfo: DatasetInfo): Map[Double,Double] + + /** + * Extracts left and right split aggregates. + * @param binData Array[Double] of size 2 * numFeatures * numBins + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], + * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, + * (numBins - 1), numClasses) + */ + protected def extractLeftRightNodeAggregates( + binData: Array[Double], + datasetInfo: DatasetInfo, + numBins: Int): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) + + /** + * Get the number of stats elements stored per node in bin aggregates. + */ + protected def getElementsPerNode( + datasetInfo: DatasetInfo, + numBins: Int): Int + + /** + * Performs a sequential aggregation of bins stats over a partition. + * + * @param agg Array[Double] storing aggregate calculation of size + * numClasses * numBins * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @param datasetInfo Dataset metadata. + * @param numNodes Number of nodes in this (level of tree, group), + * where nodes at deeper (larger) levels may be divided into groups. + * @param bins Number of bins = 1 + number of possible splits. + * @return agg + */ + protected def binSeqOpSub( + agg: Array[Double], + arr: Array[Double], + datasetInfo: DatasetInfo, + numNodes: Int, + bins: Array[Array[Bin]]): Array[Double] + + /** + * Calculates the information gain for all splits based upon left/right split aggregates. 
+ * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + protected def calculateGainForSplit( + leftNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Array[Double]]], + topImpurity: Double, + datasetInfo: DatasetInfo, + level: Int): InformationGainStats + + /** + * Get bin data for one node. + */ + protected def getBinDataForNode( + node: Int, + binAggregates: Array[Double], + datasetInfo: DatasetInfo, + numNodes: Int, + numBins: Int): Array[Double] + /** * Extract the decision tree node information for the given tree level and node index */ - private def extractNodeInfo( + protected def extractNodeInfo( nodeSplitStats: (Split, InformationGainStats), level: Int, index: Int, @@ -152,7 +239,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = math.pow(2, level).toInt - 1 + index - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val isLeaf = (stats.gain <= 0) || (level == maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node @@ -161,7 +248,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /** * Extract the decision tree node information for the children of the node */ - private def extractInfoForLowerLevels( + protected def extractInfoForLowerLevels( level: Int, index: Int, maxDepth: Int, @@ -173,7 +260,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth - 1) { + if (level < maxDepth) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { @@ -192,110 +279,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo i += 1 } } -} - -object DecisionTree extends Serializable with Logging { - - /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. The parameters for the algorithm are specified using the strategy parameter. - * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree - * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, - * categorical), depth of the tree, quantile calculation strategy, etc. - * @return a DecisionTreeModel that can be used for prediction - */ - def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input) - } - - /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. 
For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. - * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @return a DecisionTreeModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth) - new DecisionTree(strategy).train(input) - } - - /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. - * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @param numClassesForClassification number of classes for classification. Default value of 2. - * @return a DecisionTreeModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int, - numClassesForClassification: Int): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification) - new DecisionTree(strategy).train(input) - } - - /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The decision tree method supports binary classification and - * regression. For the binary classification, the label for each instance should either be 0 or - * 1 to denote the two classes. The method also supports categorical features inputs where the - * number of categories can specified using the categoricalFeaturesInfo option. - * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data for DecisionTree - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree - * @param numClassesForClassification number of classes for classification. Default value of 2. - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. 
- * @return a DecisionTreeModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int, - numClassesForClassification: Int, - maxBins: Int, - quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, - quantileCalculationStrategy, categoricalFeaturesInfo) - new DecisionTree(strategy).train(input) - } - - private val InvalidBinIndex = -1 /** * Returns an array of optimal splits for all nodes at a given level. Splits the task into @@ -303,9 +286,8 @@ object DecisionTree extends Serializable with Logging { * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree + * @param datasetInfo Metadata for input. * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -315,44 +297,44 @@ object DecisionTree extends Serializable with Logging { */ protected[tree] def findBestSplits( input: RDD[LabeledPoint], + datasetInfo: DatasetInfo, parentImpurities: Array[Double], - strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, + val bestSplitsForGroup = findBestSplitsPerGroup(input, datasetInfo, parentImpurities, level, filters, splits, bins, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + findBestSplitsPerGroup(input, datasetInfo, parentImpurities, level, filters, splits, bins) } } - /** + /** * Returns an array of optimal splits for a group of nodes at a given level * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree + * @param datasetInfo Metadata for input. 
* @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -361,10 +343,10 @@ object DecisionTree extends Serializable with Logging { * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. */ - private def findBestSplitsPerGroup( + protected def findBestSplitsPerGroup( input: RDD[LabeledPoint], + datasetInfo: DatasetInfo, parentImpurities: Array[Double], - strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], @@ -396,22 +378,19 @@ object DecisionTree extends Serializable with Logging { * drastically reduce the communication overhead. */ - // common calculations for multiple nested methods + // Common calculations for multiple nested methods: + + // Number of nodes to handle for each group in this level. val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) - // Find the number of features by looking at the first sample. - val numFeatures = input.first().features.size - logDebug("numFeatures = " + numFeatures) + logDebug("numFeatures = " + datasetInfo.numFeatures) val numBins = bins(0).length logDebug("numBins = " + numBins) - val numClasses = strategy.numClassesForClassification - logDebug("numClasses = " + numClasses) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) - val isMulticlassClassificationWithCategoricalFeatures - = strategy.isMulticlassWithCategoricalFeatures - logDebug("isMultiClassWithCategoricalFeatures = " + - isMulticlassClassificationWithCategoricalFeatures) + logDebug("numClasses = " + datasetInfo.numClasses) + val isMulticlass = datasetInfo.isMulticlass + logDebug("isMulticlass = " + isMulticlass) + val isMulticlassWithCategoricalFeatures = datasetInfo.isMulticlassWithCategoricalFeatures + logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex @@ -465,10 +444,13 @@ object DecisionTree extends Serializable with Logging { } /** - * Find bin for one feature. + * Find bin for one (labeledPoint, feature). */ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -510,7 +492,7 @@ object DecisionTree extends Serializable with Logging { * Sequential search helper method to find bin for categorical feature. 
*/ def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { @@ -535,7 +517,8 @@ object DecisionTree extends Serializable with Logging { } else { // Perform sequential search to find bin for categorical features. val binIndex = { - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + val isUnorderedFeature = isMulticlass && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { sequentialBinSearchForOrderedCategoricalFeatureInClassification() @@ -549,16 +532,24 @@ object DecisionTree extends Serializable with Logging { } /** - * Finds bins for all nodes (and all features) at a given level. + * Finds bins for all nodes (and all features) in a given (level, group). * For l nodes, k features the storage is as follows: * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, * where b_ij is an integer between 0 and numBins - 1 for regressions and binary * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. + * + * For unordered features, the "bin index" returned is actually the feature value (category). + * + * @return Array of size 1 + numFeatures * numNodes, where + * arr(0) = label for labeledPoint, and + * arr(1 + numFeatures * nodeIndex + featureIndex) = + * bin index for this labeledPoint + * (or InvalidBinIndex if labeledPoint is not handled by this node). */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. - val arr = new Array[Double](1 + (numFeatures * numNodes)) + val arr = new Array[Double](1 + (datasetInfo.numFeatures * numNodes)) // First element of the array is the label of the instance. arr(0) = labeledPoint.label // Iterate over nodes. @@ -567,25 +558,25 @@ object DecisionTree extends Serializable with Logging { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node. val sampleValid = isSampleValid(parentFilters, labeledPoint) - val shift = 1 + numFeatures * nodeIndex + val shift = 1 + datasetInfo.numFeatures * nodeIndex if (!sampleValid) { // Mark one bin as -1 is sufficient. 
arr(shift) = InvalidBinIndex } else { var featureIndex = 0 - while (featureIndex < numFeatures) { - val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) + while (featureIndex < datasetInfo.numFeatures) { + val featureInfo = datasetInfo.categoricalFeaturesInfo.get(featureIndex) val isFeatureContinuous = featureInfo.isEmpty if (isFeatureContinuous) { arr(shift + featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) } else { val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isSpaceSufficientForAllCategoricalSplits) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 + arr(shift + featureIndex) = + findBin(featureIndex, labeledPoint, isFeatureContinuous, + isSpaceSufficientForAllCategoricalSplits) } featureIndex += 1 } @@ -595,185 +586,16 @@ object DecisionTree extends Serializable with Logging { arr } - // Find feature bins for all nodes at a level. + // Find feature bins for all nodes in this (level, group). val binMappedRDD = input.map(x => findBinsForLevel(x)) - def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int) = { - - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses - val labelInt = label.toInt - agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 - } - - def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], - label: Double, agg: Array[Double], rightChildShift: Int) = { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses - // Find all matching bins and increment their values - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 - var binIndex = 0 - while (binIndex < numCategoricalBins) { - val labelInt = label.toInt - if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + 1 - } else { - agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + 1 - } - binIndex += 1 - } - } - - /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. 
- * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures * numNodes for classification - */ - def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - featureIndex += 1 - } - } - nodeIndex += 1 - } - } - - /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. - * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numClasses * numSplits * numFeatures * numNodes for classification - */ - def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - val rightChildShift = numClasses * numBins * numFeatures * numNodes - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, - rightChildShift) - } else { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } - } - featureIndex += 1 - } - } - nodeIndex += 1 - } - } - - /** - * Performs a sequential aggregation over a partition for regression. For l nodes, k features, - * the count, sum, sum of squares of one of the p bins is incremented. - * - * @param agg Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for regression - */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. 
- val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update count, sum, and sum^2 for one bin. - val aggShift = 3 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 - agg(aggIndex) = agg(aggIndex) + 1 - agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label*label - featureIndex += 1 - } - } - nodeIndex += 1 - } - } - - /** - * Performs a sequential aggregation over a partition. - */ + // Performs a sequential aggregation over a partition. def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { - strategy.algo match { - case Classification => - if(isMulticlassClassificationWithCategoricalFeatures) { - unorderedClassificationBinSeqOp(arr, agg) - } else { - orderedClassificationBinSeqOp(arr, agg) - } - case Regression => regressionBinSeqOp(arr, agg) - } - agg + binSeqOpSub(agg, arr, datasetInfo, numNodes, bins) } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses, - isMulticlassClassificationWithCategoricalFeatures, strategy.algo) + val binAggregateLength = numNodes * getElementsPerNode(datasetInfo, numBins) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -794,299 +616,10 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregates. val binAggregates = { - binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) } logDebug("binAggregates.length = " + binAggregates.length) - /** - * Calculates the information gain for all splits based upon left/right split aggregates. - * @param leftNodeAgg left node aggregates - * @param featureIndex feature index - * @param splitIndex split index - * @param rightNodeAgg right node aggregate - * @param topImpurity impurity of the parent node - * @return information gain and statistics for all splits - */ - def calculateGainForSplit( - leftNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int, - splitIndex: Int, - rightNodeAgg: Array[Array[Array[Double]]], - topImpurity: Double): InformationGainStats = { - strategy.algo match { - case Classification => - var classIndex = 0 - val leftCounts: Array[Double] = new Array[Double](numClasses) - val rightCounts: Array[Double] = new Array[Double](numClasses) - var leftTotalCount = 0.0 - var rightTotalCount = 0.0 - while (classIndex < numClasses) { - val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) - val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) - leftCounts(classIndex) = leftClassCount - leftTotalCount += leftClassCount - rightCounts(classIndex) = rightClassCount - rightTotalCount += rightClassCount - classIndex += 1 - } - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. 
- val rootNodeCounts = new Array[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) - classIndex += 1 - } - strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) - } - } - - if (leftTotalCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) - } - if (rightTotalCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) - } - - val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) - val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount) - - val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) - val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - - val totalCount = leftTotalCount + rightTotalCount - - // Sum of count for each label - val leftRightCounts: Array[Double] - = leftCounts.zip(rightCounts) - .map{case (leftCount, rightCount) => leftCount + rightCount} - - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) - else (maxIndex, maxValue, currentIndex + 1) - } - if (result._1 < 0) 0 else result._1 - } - - val predict = indexOfLargestArrayElement(leftRightCounts) - val prob = leftRightCounts(predict) / totalCount - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) - case Regression => - val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) - val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) - val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) - - val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) - val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) - val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - strategy.impurity.calculate(count, sum, sumSquares) - } - } - - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity ,topImpurity, - Double.MinValue, leftSum / leftCount) - } - - val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) - - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) - } - } - - /** - * Extracts left and right split aggregates. 
- * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], - * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, - * (numBins - 1), numClasses) - */ - def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { - - - def findAggForOrderedFeatureClassification( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins - - var classIndex = 0 - while (classIndex < numClasses) { - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(classIndex) - = binData(shift + (numClasses * (numBins - 1)) + classIndex) - classIndex += 1 - } - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - var innerClassIndex = 0 - while (innerClassIndex < numClasses) { - leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) - = binData(shift + numClasses * splitIndex + innerClassIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = - binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) - innerClassIndex += 1 - } - splitIndex += 1 - } - } - - def findAggForUnorderedFeatureClassification( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - val rightChildShift = numClasses * numBins * numFeatures - var splitIndex = 0 - while (splitIndex < numBins - 1) { - var classIndex = 0 - while (classIndex < numClasses) { - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins + splitIndex * numClasses - val leftBinValue = binData(shift + classIndex) - val rightBinValue = binData(rightChildShift + shift + classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue - rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue - classIndex += 1 - } - splitIndex += 1 - } - } - - def findAggForRegression( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - // shift for this featureIndex - val shift = 3 * featureIndex * numBins - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(0) = - binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(numBins - 2)(1) = - binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(numBins - 2)(2) = - binData(shift + (3 * (numBins - 1)) + 2) - - // Iterate over all splits. 
- var splitIndex = 1 - while (splitIndex < numBins - 1) { - var i = 0 // index for regression histograms - while (i < 3) { // count, sum, sum^2 - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) + - leftNodeAgg(featureIndex)(splitIndex - 1)(i) - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) = - binData(shift + (3 * (numBins - 1 - splitIndex) + i)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i) - i += 1 - } - splitIndex += 1 - } - } - - strategy.algo match { - case Classification => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (isMulticlassClassificationWithCategoricalFeatures){ - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - } - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - featureIndex += 1 - } - - (leftNodeAgg, rightNodeAgg) - case Regression => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - } - } - /** * Calculates information gain for all nodes splits. */ @@ -1094,12 +627,12 @@ object DecisionTree extends Serializable with Logging { leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], nodeImpurity: Double): Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) + val gains = Array.ofDim[InformationGainStats](datasetInfo.numFeatures, numBins - 1) - for (featureIndex <- 0 until numFeatures) { + for (featureIndex <- 0 until datasetInfo.numFeatures) { for (splitIndex <- 0 until numBins - 1) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, - splitIndex, rightNodeAgg, nodeImpurity) + splitIndex, rightNodeAgg, nodeImpurity, datasetInfo, level) } } gains @@ -1107,7 +640,7 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. - * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param binData Bin data slice for this node, given by getBinDataForNode. 
* @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ @@ -1118,30 +651,31 @@ object DecisionTree extends Serializable with Logging { logDebug("node impurity = " + nodeImpurity) // Extract left right node aggregates. - val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + val (leftNodeAgg, rightNodeAgg) = + extractLeftRightNodeAggregates(binData, datasetInfo, numBins) // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - val (bestFeatureIndex,bestSplitIndex, gainStats) = { + val (bestFeatureIndex, bestSplitIndex, gainStats) = { // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) // Iterate over features. var featureIndex = 0 - while (featureIndex < numFeatures) { + while (featureIndex < datasetInfo.numFeatures) { // Iterate over all splits. var splitIndex = 0 val maxSplitIndex : Double = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { numBins - 1 } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val isSpaceSufficientForAllCategoricalSplits = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + if (isMulticlass && isSpaceSufficientForAllCategoricalSplits) { math.pow(2.0, featureCategories - 1).toInt - 1 } else { // Binary classification featureCategories @@ -1168,43 +702,14 @@ object DecisionTree extends Serializable with Logging { (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } - /** - * Get bin data for one node. - */ - def getBinDataForNode(node: Int): Array[Double] = { - strategy.algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - val shift = numClasses * node * numBins * numFeatures - val rightChildShift = numClasses * numBins * numFeatures * numNodes - val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + numClasses * numBins * numFeatures) - leftChildData ++ rightChildData - } - binsForNode - } else { - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode - } - case Regression => - val shift = 3 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) - binsForNode - } - } - - // Calculate best splits for all nodes at a given level + // Calculate best splits for all nodes in this (level, group). 
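Once `calculateGainsForAllNodeSplits` has filled the gains matrix, the best split is a plain argmax over (feature, split) pairs, as in the loop above. A simplified sketch of that scan using raw gain values in place of `InformationGainStats` (names are illustrative):

{% highlight scala %}
object BestSplitSketch {
  /** Return (featureIndex, splitIndex, gain) of the largest entry in gains(feature)(split). */
  def bestSplit(gains: Array[Array[Double]]): (Int, Int, Double) = {
    var bestFeature = -1
    var bestSplitIdx = -1
    var bestGain = Double.MinValue
    var f = 0
    while (f < gains.length) {
      var s = 0
      while (s < gains(f).length) {
        if (gains(f)(s) > bestGain) {
          bestGain = gains(f)(s)
          bestFeature = f
          bestSplitIdx = s
        }
        s += 1
      }
      f += 1
    }
    (bestFeature, bestSplitIdx, bestGain)
  }

  def main(args: Array[String]): Unit = {
    val gains = Array(Array(0.01, 0.20, 0.05), Array(0.15, 0.02, 0.12))
    println(bestSplit(gains)) // (0,1,0.2)
  }
}
{% endhighlight %}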
val bestSplits = new Array[(Split, InformationGainStats)](numNodes) // Iterating over all nodes at this level var node = 0 while (node < numNodes) { val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift - val binsForNode: Array[Double] = getBinDataForNode(node) + val binsForNode: Array[Double] + = getBinDataForNode(node, binAggregates, datasetInfo, numNodes, numBins) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) logDebug("parent node impurity = " + parentNodeImpurity) @@ -1214,43 +719,47 @@ object DecisionTree extends Serializable with Logging { bestSplits } - private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = { - algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - 2 * numClasses * numBins * numFeatures - } else { - numClasses * numBins * numFeatures - } - case Regression => 3 * numBins * numFeatures - } - } - /** - * Returns split and bins for decision tree calculation. - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree - * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree - * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache - * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + * Returns splits and bins for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * There are 2^(maxFeatureValue - 1) - 1 splits. + * (b) For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one split per category. + + * Categorical case (a) features are called unordered features. + * Other cases are called ordered features. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return A tuple of (splits,bins). + * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numBins - 1). + * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). 
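As the updated scaladoc explains, an unordered categorical feature is split on subsets of categories, and the subset for split index `i` is read off the binary representation of `i + 1` (the idea behind `extractMultiClassCategories`). A standalone sketch of that decoding, together with the 2^(arity - 1) - 1 split count it implies (names are illustrative, not MLlib APIs):

{% highlight scala %}
object CategorySubsetSketch {
  /**
   * Decode one unordered categorical split: the positions of the 1-bits in
   * (splitIndex + 1), read over the feature's arity, are the categories sent left.
   */
  def categoriesForSplit(splitIndex: Int, arity: Int): List[Double] = {
    var categories = List.empty[Double]
    var bits = splitIndex + 1
    var position = 0
    while (position < arity) {
      if ((bits & 1) == 1) categories = position.toDouble :: categories
      bits = bits >> 1
      position += 1
    }
    categories
  }

  def main(args: Array[String]): Unit = {
    // 13 = binary 01101, so categories 0, 2 and 3 go to the left child.
    println(categoriesForSplit(splitIndex = 12, arity = 5)) // List(3.0, 2.0, 0.0)
    // For arity 3 there are 2^(3 - 1) - 1 = 3 unordered splits in total.
    (0 until 3).foreach(i => println(categoriesForSplit(i, 3)))
    // List(0.0), List(1.0), List(1.0, 0.0)
  }
}
{% endhighlight %}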
*/ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + datasetInfo: DatasetInfo): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() - // Find the number of features by looking at the first sample - val numFeatures = input.take(1)(0).features.size + val numFeatures = datasetInfo.numFeatures - val maxBins = strategy.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) - + val isMulticlass = datasetInfo.isMulticlass + logDebug("isMulticlass = " + isMulticlass) /* * Ensure #bins is always greater than the categories. For multiclass classification, @@ -1258,8 +767,8 @@ object DecisionTree extends Serializable with Logging { * It's a limitation of the current implementation but a reasonable trade-off since features * with large number of categories get favored over continuous features. */ - if (strategy.categoricalFeaturesInfo.size > 0) { - val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + if (datasetInfo.categoricalFeaturesInfo.size > 0) { + val maxCategoriesForFeatures = datasetInfo.categoricalFeaturesInfo.maxBy(_._2)._2 require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + "in categorical features") } @@ -1271,14 +780,15 @@ object DecisionTree extends Serializable with Logging { logDebug("fraction of data used for calculating quantiles = " + fraction) // sampled input for RDD calculation - val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val sampledInput = + input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) - strategy.quantileCalculationStrategy match { - case Sort => + quantileStrategy match { + case QuantileStrategy.Sort => val splits = Array.ofDim[Split](numFeatures, numBins - 1) val bins = Array.ofDim[Bin](numFeatures, numBins) @@ -1288,7 +798,7 @@ object DecisionTree extends Serializable with Logging { var featureIndex = 0 while (featureIndex < numFeatures){ // Check whether the feature is continuous. 
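With the `sort` quantile strategy, candidate thresholds for a continuous feature are taken from a sorted sample at a fixed stride of `numSamples / numBins`. A rough standalone sketch of that idea; the exact index arithmetic in MLlib may differ slightly, and the names here are illustrative:

{% highlight scala %}
object SortQuantileSketch {
  /**
   * Pick numBins - 1 candidate thresholds for a continuous feature by taking
   * (approximately) equally spaced order statistics from a sorted sample.
   */
  def candidateThresholds(featureSamples: Array[Double], numBins: Int): Array[Double] = {
    val sorted = featureSamples.sorted
    val stride = sorted.length.toDouble / numBins
    Array.tabulate(numBins - 1) { i =>
      sorted(((i + 1) * stride).toInt)
    }
  }

  def main(args: Array[String]): Unit = {
    val samples = Array(5.0, 1.0, 9.0, 3.0, 7.0, 2.0, 8.0, 4.0, 6.0, 0.0)
    // With 4 bins we get 3 thresholds at roughly the 25th/50th/75th percentiles.
    println(candidateThresholds(samples, numBins = 4).mkString(", ")) // 2.0, 5.0, 7.0
  }
}
{% endhighlight %}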
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride: Double = numSamples.toDouble / numBins @@ -1299,18 +809,19 @@ object DecisionTree extends Serializable with Logging { splits(featureIndex)(index) = split } } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - // 2^(maxFeatureValue- 1) - 1 combinations + // classification that satisfy the space constraint. + val isUnorderedFeature = isMulticlass && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { + // 2^(maxFeatureValue - 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { - val categories: List[Double] - = extractMultiClassCategories(index + 1, featureCategories) + val categories: List[Double] = + DecisionTree.extractMultiClassCategories(index + 1, featureCategories) splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, categories) bins(featureIndex)(index) = { @@ -1330,29 +841,11 @@ object DecisionTree extends Serializable with Logging { } index += 1 } - } else { - - val centroidForCategories = { - if (isMulticlassClassification) { - // For categorical variables in multiclass classification, - // each bin is a category. The bins are sorted and they - // are ordered by calculating the impurity of their corresponding labels. - sampledInput.map(lp => (lp.features(featureIndex), lp.label)) - .groupBy(_._1) - .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) - .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum))) - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // each bin is a category. The bins are sorted and they - // are ordered by calculating the centroid of their corresponding labels. - sampledInput.map(lp => (lp.features(featureIndex), lp.label)) - .groupBy(_._1) - .mapValues(x => x.map(_._2).sum / x.map(_._1).length) - } - } + } else { // ordered feature + val centroidForCategories = + computeCentroidForCategories(featureIndex, sampledInput, datasetInfo) - logDebug("centriod for categories = " + centroidForCategories.mkString(",")) + logDebug("centroid for categories = " + centroidForCategories.mkString(",")) // Check for missing categorical variables and putting them last in the sorted list. 
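For ordered categorical features, each category gets a centroid (the mean label, in the regression and binary classification case), the categories are sorted by centroid, and categories missing from the sample are pushed to the end. A standalone sketch of that ordering, using `Double.MaxValue` as a stand-in for "missing category sorts last" (names are illustrative, not the MLlib implementation):

{% highlight scala %}
object CategoryCentroidSketch {
  /**
   * Order the categories of one categorical feature by the mean label of the
   * examples in each category; categories never seen in the sample go last.
   */
  def categoriesSortedByCentroid(
      labeledCategories: Seq[(Double, Double)], // (category, label) pairs
      arity: Int): List[Double] = {
    val centroids: Map[Double, Double] = labeledCategories
      .groupBy(_._1)
      .map { case (category, pairs) => category -> pairs.map(_._2).sum / pairs.size }
    val withMissing = (0 until arity).map { c =>
      (c.toDouble, centroids.getOrElse(c.toDouble, Double.MaxValue))
    }
    withMissing.sortBy(_._2).map(_._1).toList
  }

  def main(args: Array[String]): Unit = {
    // Category 2.0 has mean label 0.0, category 0.0 has 0.5, category 1.0 has 1.0;
    // category 3.0 never appears in the sample, so it sorts last.
    val sample = Seq((0.0, 1.0), (0.0, 0.0), (1.0, 1.0), (2.0, 0.0))
    println(categoriesSortedByCentroid(sample, arity = 4)) // List(2.0, 0.0, 1.0, 3.0)
  }
}
{% endhighlight %}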
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() @@ -1367,7 +860,7 @@ object DecisionTree extends Serializable with Logging { // bins sorted by centroids val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) var categoriesForSplit = List[Double]() categoriesSortedByCentroid.iterator.zipWithIndex.foreach { @@ -1380,7 +873,7 @@ object DecisionTree extends Serializable with Logging { new Bin(new DummyCategoricalSplit(featureIndex, Categorical), splits(featureIndex)(0), Categorical, key) } else { - new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + new Bin(splits(featureIndex)(index - 1), splits(featureIndex)(index), Categorical, key) } } @@ -1393,35 +886,41 @@ object DecisionTree extends Serializable with Logging { // Find all bins. featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) for (index <- 1 until numBins - 1){ - val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + val bin = new Bin(splits(featureIndex)(index - 1), splits(featureIndex)(index), Continuous, Double.MinValue) bins(featureIndex)(index) = bin } - bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), + bins(featureIndex)(numBins - 1) = new Bin(splits(featureIndex)(numBins - 2), new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) } featureIndex += 1 } (splits,bins) - case MinMax => + case QuantileStrategy.MinMax => throw new UnsupportedOperationException("minmax not supported yet.") - case ApproxHist => + case QuantileStrategy.ApproxHist => throw new UnsupportedOperationException("approximate histogram not supported yet.") } } +} + + +@Experimental +object DecisionTree extends Serializable with Logging { + /** - * Nested method to extract list of eligible categories given an index. It extracts the - * position of ones in a binary representation of the input. If binary - * representation of an number is 01101 (13), the output list should (3.0, 2.0, + * Extract list of eligible categories given an index. + * It extracts the position of ones in a binary representation of the input. + * If binary representation of an number is 01101 (13), the output list should (3.0, 2.0, * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones. */ - private[tree] def extractMultiClassCategories( + protected[tree] def extractMultiClassCategories( input: Int, maxFeatureValue: Int): List[Double] = { var categories = List[Double]() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala new file mode 100644 index 0000000000000..8099c4023f01b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala @@ -0,0 +1,574 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.annotation.Experimental +import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.DatasetInfo +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.DTClassifierParams +import org.apache.spark.mllib.tree.impurity.ClassificationImpurities +import org.apache.spark.mllib.tree.model.{InformationGainStats, Bin, DecisionTreeClassifierModel} +import org.apache.spark.rdd.RDD + + +/** + * :: Experimental :: + * A class that implements a decision tree algorithm for classification. + * It supports both continuous and categorical features. + * @param params The configuration parameters for the tree algorithm. + */ +@Experimental +class DecisionTreeClassifier (params: DTClassifierParams) extends DecisionTree(params) { + + private val impurityFunctor = ClassificationImpurities.impurity(params.impurity) + + /** + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * @param datasetInfo Dataset metadata specifying number of classes, features, etc. + * @return a DecisionTreeClassifierModel that can be used for prediction + */ + def run( + input: RDD[LabeledPoint], + datasetInfo: DatasetInfo): DecisionTreeClassifierModel = { + + require(datasetInfo.isClassification) + logDebug("algo = Classification") + + val topNode = super.runSub(input, datasetInfo) + new DecisionTreeClassifierModel(topNode) + } + + protected def computeCentroidForCategories( + featureIndex: Int, + sampledInput: Array[LabeledPoint], + datasetInfo: DatasetInfo): Map[Double,Double] = { + if (datasetInfo.isMulticlass) { + // For categorical variables in multiclass classification, + // each bin is a category. The bins are sorted and they + // are ordered by calculating the impurity of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) + .map(x => (x._1, x._2.values.toArray)) + .map(x => (x._1, impurityFunctor.calculate(x._2,x._2.sum))) + } else { // binary classification + // For categorical variables in binary classification, + // each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + } + } + + /** + * Extracts left and right split aggregates. + * @param binData Aggregate array slice from getBinDataForNode. + * For unordered features, this is leftChildData ++ rightChildData, + * each of which is indexed by (feature, split/bin, class), + * with class being the least significant bit. 
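The `impurityFunctor.calculate(counts, totalCount)` calls above reduce a per-class count vector to a single impurity value; for the default `"gini"` impurity that is the standard Gini index 1 - Σ p_i². A standalone sketch of that computation (names are illustrative, not the MLlib impurity classes):

{% highlight scala %}
object GiniSketch {
  /** Gini impurity 1 - sum_i p_i^2 from per-class counts, as in calculate(counts, total). */
  def gini(counts: Array[Double], totalCount: Double): Double = {
    if (totalCount == 0) return 0.0
    1.0 - counts.map { c => val p = c / totalCount; p * p }.sum
  }

  def main(args: Array[String]): Unit = {
    println(gini(Array(5.0, 5.0), 10.0))       // 0.5  (maximally mixed, 2 classes)
    println(gini(Array(10.0, 0.0), 10.0))      // 0.0  (pure node)
    println(gini(Array(6.0, 3.0, 1.0), 10.0))  // 1 - (0.36 + 0.09 + 0.01) ≈ 0.54
  }
}
{% endhighlight %}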
+ * For ordered features, this is of size numClasses * numBins * numFeatures. + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], + * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, + * (numBins - 1), numClasses) + */ + protected def extractLeftRightNodeAggregates( + binData: Array[Double], + datasetInfo: DatasetInfo, + numBins: Int): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + + def findAggForOrderedFeatureClassification( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val numClasses = datasetInfo.numClasses + val shift = numClasses * featureIndex * numBins + + var classIndex = 0 + while (classIndex < numClasses) { + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(classIndex) + = binData(shift + (numClasses * (numBins - 1)) + classIndex) + classIndex += 1 + } + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + var innerClassIndex = 0 + while (innerClassIndex < numClasses) { + leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) + = binData(shift + numClasses * splitIndex + innerClassIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = + binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) + innerClassIndex += 1 + } + splitIndex += 1 + } + } + + /** + * Reshape binData for this feature. + * Indexes binData as (feature, split, class) with class as the least significant bit. + * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value + */ + def findAggForUnorderedFeatureClassification( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + val rightChildShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures + var splitIndex = 0 + while (splitIndex < numBins - 1) { + var classIndex = 0 + while (classIndex < datasetInfo.numClasses) { + // shift for this featureIndex + val shift = + datasetInfo.numClasses * featureIndex * numBins + splitIndex * datasetInfo.numClasses + val leftBinValue = binData(shift + classIndex) + val rightBinValue = binData(rightChildShift + shift + classIndex) + leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue + rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue + classIndex += 1 + } + splitIndex += 1 + } + } + + // Initialize left and right split aggregates. 
+ val leftNodeAgg = + Array.ofDim[Double](datasetInfo.numFeatures, numBins - 1, datasetInfo.numClasses) + val rightNodeAgg = + Array.ofDim[Double](datasetInfo.numFeatures, numBins - 1, datasetInfo.numClasses) + var featureIndex = 0 + while (featureIndex < datasetInfo.numFeatures) { + if (datasetInfo.isMulticlassWithCategoricalFeatures){ + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } + } + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } + featureIndex += 1 + } + + (leftNodeAgg, rightNodeAgg) + } + + /** + * Get number of values to be stored per node in the bin aggregate counts. + * @param datasetInfo Dataset metadata + * @param numBins Number of bins = 1 + number of possible splits. + * @return + */ + protected def getElementsPerNode( + datasetInfo: DatasetInfo, + numBins: Int): Int = { + if (datasetInfo.isMulticlassWithCategoricalFeatures) { + 2 * datasetInfo.numClasses * numBins * datasetInfo.numFeatures + } else { + datasetInfo.numClasses * numBins * datasetInfo.numFeatures + } + } + + /** + * Performs a sequential aggregation over a partition for classification. + * For l nodes, k features, + * either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes for ordered features, or + * 2 * numClasses * numBins * numFeatures * numNodes for unordered features + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @return Array storing aggregate calculation, of size: + * + */ + protected def binSeqOpSub( + agg: Array[Double], + arr: Array[Double], + datasetInfo: DatasetInfo, + numNodes: Int, + bins: Array[Array[Bin]]): Array[Double] = { + val numBins = bins(0).length + if(datasetInfo.isMulticlassWithCategoricalFeatures) { + multiclassWithCategoricalBinSeqOp(arr, agg, datasetInfo, numNodes, bins) + } else { + binaryOrNoCategoricalBinSeqOp(arr, agg, datasetInfo, numNodes, numBins) + } + agg + } + + /** + * Calculates the information gain for all splits based upon left/right split aggregates. 
+ * @param leftNodeAgg Left node aggregates: + * leftNodeAgg(feature)(split)(class) = weight of examples + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg Right node aggregates: + * rightNodeAgg(feature)(split)(class) = weight of examples + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + protected def calculateGainForSplit( + leftNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Array[Double]]], + topImpurity: Double, + datasetInfo: DatasetInfo, + level: Int): InformationGainStats = { + + val numClasses = datasetInfo.numClasses + + val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) + val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) + var leftTotalCount = leftCounts.sum + var rightTotalCount = rightCounts.sum + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val rootNodeCounts = new Array[Double](numClasses) + var classIndex = 0 + while (classIndex < numClasses) { + rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) + classIndex += 1 + } + impurityFunctor.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) + } + } + + val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } + + // Sum of count for each label + val leftRightCounts: Array[Double] = + leftCounts.zip(rightCounts).map{ case (leftCount, rightCount) => leftCount + rightCount } + + def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) + else (maxIndex, maxValue, currentIndex + 1) + } + if (result._1 < 0) 0 else result._1 + } + + val predict = indexOfLargestArrayElement(leftRightCounts) + val prob = leftRightCounts(predict) / totalCount + + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + impurityFunctor.calculate(leftCounts, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + impurityFunctor.calculate(rightCounts, rightTotalCount) + } + + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + } + + /** + * Get bin data for one node. + * + * @param node Node index in this (level, group). + * @param binAggregates For unordered features, + * the first half of binAggregates contains leftChildData, + * and the second half contains rightChildData. + * Each half is of size numNodes * numFeatures * numBins * numClasses. + * For ordered features, + * this is of size numNodes * numFeatures * numBins * numClasses. + * Indexing uses node as the most significant bit. + * @return For unordered features, returns leftChildData ++ rightChildData, + * each of which is indexed by (feature, bin/split, class), + * with class being the least significant bit. + * For ordered features, returns data of size numClasses * numBins * numFeatures. 
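`calculateGainForSplit` combines the left and right class counts into the parent impurity and subtracts the count-weighted child impurities. A standalone sketch of that information-gain computation using Gini impurity (names are illustrative; the MLlib code additionally tracks the prediction and its probability):

{% highlight scala %}
object InfoGainSketch {
  private def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0) 0.0
    else 1.0 - counts.map { c => val p = c / total; p * p }.sum
  }

  /**
   * Information gain of a split: parent impurity minus the count-weighted
   * impurities of the left and right children.
   */
  def infoGain(leftCounts: Array[Double], rightCounts: Array[Double]): Double = {
    val leftTotal = leftCounts.sum
    val rightTotal = rightCounts.sum
    val total = leftTotal + rightTotal
    val parentCounts = leftCounts.zip(rightCounts).map { case (l, r) => l + r }
    val leftWeight = leftTotal / total
    val rightWeight = rightTotal / total
    gini(parentCounts) - leftWeight * gini(leftCounts) - rightWeight * gini(rightCounts)
  }

  def main(args: Array[String]): Unit = {
    // A split that separates the classes perfectly has gain equal to the parent impurity (0.5).
    println(infoGain(Array(5.0, 0.0), Array(0.0, 5.0))) // 0.5
    // A split that changes nothing has zero gain.
    println(infoGain(Array(3.0, 3.0), Array(2.0, 2.0))) // 0.0
  }
}
{% endhighlight %}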
+ */ + protected def getBinDataForNode( + node: Int, + binAggregates: Array[Double], + datasetInfo: DatasetInfo, + numNodes: Int, + numBins: Int): Array[Double] = { + if (datasetInfo.isMulticlassWithCategoricalFeatures) { + val shift = datasetInfo.numClasses * node * numBins * datasetInfo.numFeatures + val rightChildShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures * numNodes + val binsForNode = { + val leftChildData = binAggregates.slice( + shift, + shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) + val rightChildData = binAggregates.slice( + rightChildShift + shift, + rightChildShift + shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) + leftChildData ++ rightChildData + } + binsForNode + } else { + val shift = datasetInfo.numClasses * node * numBins * datasetInfo.numFeatures + val binsForNode = binAggregates.slice( + shift, + shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) + binsForNode + } + } + + /** + * Increment aggregate in location for (node, feature, bin, label) + * to indicate that, for this (example, + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. + */ + private def updateBinForOrderedFeature( + arr: Array[Double], + agg: Array[Double], + nodeIndex: Int, + label: Double, + featureIndex: Int, + datasetInfo: DatasetInfo, + numBins: Int) = { + + // Find the bin index for this feature. + val arrIndex = 1 + datasetInfo.numFeatures * nodeIndex + featureIndex + // Update the left or right count for one bin. + val aggShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures * nodeIndex + val aggIndex = aggShift + datasetInfo.numClasses * featureIndex * numBins + + arr(arrIndex).toInt * datasetInfo.numClasses + agg(aggIndex + label.toInt) += 1 + } + + /** + * + * @param arr arr(0) = label. + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param agg Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param rightChildShift + * @param bins + */ + private def updateBinForUnorderedFeature( + arr: Array[Double], + agg: Array[Double], + nodeIndex: Int, + featureIndex: Int, + label: Double, + rightChildShift: Int, + datasetInfo: DatasetInfo, + numBins: Int, + bins: Array[Array[Bin]]) = { + + // Find the bin index for this feature. + val arrIndex = 1 + datasetInfo.numFeatures * nodeIndex + featureIndex + val featureValue = arr(arrIndex) + // Update the left or right count for one bin. + val aggShift = + nodeIndex * datasetInfo.numFeatures * numBins * datasetInfo.numClasses + + featureIndex * numBins * datasetInfo.numClasses + + label.toInt + // Find all matching bins and increment their values + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + var binIndex = 0 + while (binIndex < numCategoricalBins) { + val aggIndex = aggShift + binIndex * datasetInfo.numClasses + if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { + agg(aggIndex) += 1 + } else { + agg(rightChildShift + aggIndex) += 1 + } + binIndex += 1 + } + } + + /** + * Helper for binSeqOp + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. 
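The bin aggregates are a flat `Array[Double]` indexed by (node, feature, bin, class), with class as the least significant position; `updateBinForOrderedFeature` increments one such cell and `getBinDataForNode` slices out one node's block. A standalone sketch of the index arithmetic (names are illustrative):

{% highlight scala %}
object BinAggIndexSketch {
  /**
   * Flat offset into the 1-D bin-aggregate array for (node, feature, bin, class),
   * with class as the least significant "digit", matching the layout
   * numClasses * numBins * numFeatures * numNodes.
   */
  def aggIndex(
      node: Int, feature: Int, bin: Int, cls: Int,
      numFeatures: Int, numBins: Int, numClasses: Int): Int = {
    ((node * numFeatures + feature) * numBins + bin) * numClasses + cls
  }

  def main(args: Array[String]): Unit = {
    val (numFeatures, numBins, numClasses) = (3, 4, 2)
    // node = 1, feature = 2, bin = 3, class = 1:
    println(aggIndex(1, 2, 3, 1, numFeatures, numBins, numClasses)) // 47
    // Per-node block size (ordered-feature case) that getBinDataForNode slices out.
    println(numClasses * numBins * numFeatures) // 24
  }
}
{% endhighlight %}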
+ * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes + * @param datasetInfo + * @param numNodes + * @param numBins + */ + private def binaryOrNoCategoricalBinSeqOp( + arr: Array[Double], + agg: Array[Double], + datasetInfo: DatasetInfo, + numNodes: Int, + numBins: Int) = { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + datasetInfo.numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < datasetInfo.numFeatures) { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, numBins) + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Helper for binSeqOp. + * + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation of size + * numClasses * numBins * numFeatures * numNodes + * // Size set by getElementsPerNode(): + * // 2 * numClasses * numBins * numFeatures * numNodes + * SHOULD BE indexed by (node, feature, bin, class), + * with class being the least significant bit. (based on future use) + * @param datasetInfo Dataset metadata. + * @param numNodes Number of nodes in this (level, group). + * @param bins + */ + private def multiclassWithCategoricalBinSeqOp( + arr: Array[Double], + agg: Array[Double], + datasetInfo: DatasetInfo, + numNodes: Int, + bins: Array[Array[Bin]]) = { + val numBins = bins(0).length + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + datasetInfo.numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + val rightChildShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures * numNodes + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < datasetInfo.numFeatures) { + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, + numBins) + } else { + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + updateBinForUnorderedFeature(arr, agg, nodeIndex, featureIndex, label, + rightChildShift, datasetInfo, numBins, bins) + } else { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, + numBins) + } + } + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + +} + + +@Experimental +object DecisionTreeClassifier extends Serializable with Logging { + + /** + * Get a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTreeClassifier]]. + */ + def defaultParams(): DTClassifierParams = { + new DTClassifierParams() + } + + /** + * Train a decision tree model for binary or multiclass classification, + * using the default set of learning parameters. 
+ * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param datasetInfo Dataset metadata (number of features, number of classes, etc.) + * @return DecisionTreeClassifierModel which can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + datasetInfo: DatasetInfo): DecisionTreeClassifierModel = { + require(datasetInfo.numClasses >= 2) + new DecisionTreeClassifier(new DTClassifierParams()).run(input, datasetInfo) + } + + /** + * Train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param datasetInfo Dataset metadata (number of features, number of classes, etc.) + * @param params The configuration parameters for the tree learning algorithm + * (tree depth, quantile calculation strategy, etc.) + * @return DecisionTreeClassifierModel which can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + datasetInfo: DatasetInfo, + params: DTClassifierParams): DecisionTreeClassifierModel = { + require(datasetInfo.numClasses >= 2) + new DecisionTreeClassifier(params).run(input, datasetInfo) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala new file mode 100644 index 0000000000000..98b6e6dde3894 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.annotation.Experimental +import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.DatasetInfo +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{QuantileStrategies, DTRegressorParams} +import org.apache.spark.mllib.tree.impurity.{RegressionImpurities, RegressionImpurity} +import org.apache.spark.mllib.tree.model.{InformationGainStats, Bin, DecisionTreeRegressorModel} +import org.apache.spark.rdd.RDD + + +/** + * :: Experimental :: + * A class that implements a decision tree algorithm for regression. + * It supports both continuous and categorical features. + * @param params The configuration parameters for the tree algorithm. 
+ */ +@Experimental +class DecisionTreeRegressor (params: DTRegressorParams) extends DecisionTree(params) { + + private val impurityFunctor = RegressionImpurities.impurity(params.impurity) + + /** + * Method to train a decision tree model over an RDD + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param datasetInfo Dataset metadata specifying number of classes, features, etc. + * @return a DecisionTreeRegressorModel that can be used for prediction + */ + def run( + input: RDD[LabeledPoint], + datasetInfo: DatasetInfo): DecisionTreeRegressorModel = { + + require(datasetInfo.isRegression) + logDebug("algo = Regression") + + val topNode = super.runSub(input, datasetInfo) + + new DecisionTreeRegressorModel(topNode) + } + + protected def computeCentroidForCategories( + featureIndex: Int, + sampledInput: Array[LabeledPoint], + datasetInfo: DatasetInfo): Map[Double,Double] = { + // For categorical variables in regression, each bin is a category. + // The bins are sorted and ordered by calculating the centroid of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + } + + /** + * Extracts left and right split aggregates. + * @param binData Array[Double] of size 2 * numFeatures * numBins + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], + * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, + * (numBins - 1), 3) + */ + protected def extractLeftRightNodeAggregates( + binData: Array[Double], + datasetInfo: DatasetInfo, + numBins: Int): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](datasetInfo.numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](datasetInfo.numFeatures, numBins - 1, 3) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < datasetInfo.numFeatures) { + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(numBins - 2)(2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. 
+ var splitIndex = 1 + while (splitIndex < numBins - 1) { + var i = 0 // index for regression histograms + while (i < 3) { // count, sum, sum^2 + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) + + leftNodeAgg(featureIndex)(splitIndex - 1)(i) + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) = + binData(shift + (3 * (numBins - 1 - splitIndex) + i)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i) + i += 1 + } + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } + + protected def getElementsPerNode( + datasetInfo: DatasetInfo, + numBins: Int): Int = { + 3 * numBins * datasetInfo.numFeatures + } + + /** + * Performs a sequential aggregation of bins stats over a partition for regression. + * For l nodes, k features, + * the count, sum, sum of squares of one of the p bins is incremented. + * + * @param agg Array[Double] storing aggregate calculation of size + * 3 * numBins * numFeatures * numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 3 * numBins * numFeatures * numNodes for regression + */ + protected def binSeqOpSub( + agg: Array[Double], + arr: Array[Double], + datasetInfo: DatasetInfo, + numNodes: Int, + bins: Array[Array[Bin]]): Array[Double] = { + val numBins = bins(0).length + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + datasetInfo.numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < datasetInfo.numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + datasetInfo.numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update count, sum, and sum^2 for one bin. + val aggShift = 3 * numBins * datasetInfo.numFeatures * nodeIndex + val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 + agg(aggIndex) = agg(aggIndex) + 1 + agg(aggIndex + 1) = agg(aggIndex + 1) + label + agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + featureIndex += 1 + } + } + nodeIndex += 1 + } + agg + } + + /** + * Calculates the information gain for all splits based upon left/right split aggregates. 
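For regression, each bin keeps the (count, sum, sum of squares) triple that `binSeqOpSub` accumulates, which is enough to recover the label variance used as the impurity. A standalone sketch using the usual E[x²] - mean² identity (names are illustrative, not the MLlib Variance impurity class):

{% highlight scala %}
object VarianceImpuritySketch {
  /** Variance of the labels from the (count, sum, sum of squares) triple kept per bin. */
  def variance(count: Double, sum: Double, sumSquares: Double): Double = {
    if (count == 0) return 0.0
    val mean = sum / count
    sumSquares / count - mean * mean
  }

  def main(args: Array[String]): Unit = {
    // Labels 1.0, 2.0, 3.0: count = 3, sum = 6, sumSquares = 14, variance = 14/3 - 4 = 2/3.
    println(variance(3.0, 6.0, 14.0)) // ≈ 0.6667
    // A constant-label bin (four labels of 2.0) has zero variance.
    println(variance(4.0, 8.0, 16.0)) // 0.0
  }
}
{% endhighlight %}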
+ * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + protected def calculateGainForSplit( + leftNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Array[Double]]], + topImpurity: Double, + datasetInfo: DatasetInfo, + level: Int): InformationGainStats = { + + val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) + val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) + val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) + + val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) + val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) + val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + impurityFunctor.calculate(count, sum, sumSquares) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity ,topImpurity, + Double.MinValue, leftSum / leftCount) + } + + val leftImpurity = impurityFunctor.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = impurityFunctor.calculate(rightCount, rightSum, rightSumSquares) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + } + + /** + * Get bin data for one node. + */ + protected def getBinDataForNode( + node: Int, + binAggregates: Array[Double], + datasetInfo: DatasetInfo, + numNodes: Int, + numBins: Int): Array[Double] = { + val shift = 3 * node * numBins * datasetInfo.numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * datasetInfo.numFeatures) + binsForNode + } + +} + + +@Experimental +object DecisionTreeRegressor extends Serializable with Logging { + + /** + * Get a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTreeRegressor]]. + */ + def defaultParams(): DTRegressorParams = { + new DTRegressorParams() + } + + /** + * Train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should be real values. + * @param datasetInfo Dataset metadata (number of features, number of classes, etc.) + * @return DecisionTreeRegressorModel which can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + datasetInfo: DatasetInfo): DecisionTreeRegressorModel = { + new DecisionTreeRegressor(new DTRegressorParams()).run(input, datasetInfo) + } + + /** + * Train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should be real values. 
+ * @param datasetInfo Dataset metadata (number of features, number of classes, etc.) + * @param params The configuration parameters for the tree learning algorithm + * (tree depth, quantile calculation strategy, etc.) + * @return DecisionTreeRegressorModel which can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + datasetInfo: DatasetInfo, + params: DTRegressorParams = new DTRegressorParams()): DecisionTreeRegressorModel = { + new DecisionTreeRegressor(params).run(input, datasetInfo) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala new file mode 100644 index 0000000000000..eec79b9f89b8c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.impurity.ClassificationImpurities + +/** + * :: Experimental :: + * Stores all the configuration options for DecisionTreeClassifier construction + * @param impurity Criterion used for information gain calculation. + * Currently supported: "gini", "entropy" + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param maxBins maximum number of bins used for splitting features + * @param quantileStrategy algorithm for calculating quantiles + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. + */ +@Experimental +class DTClassifierParams ( + var impurity: String = "gini", + maxDepth: Int = 4, + maxBins: Int = 100, + quantileStrategy: String = "sort", + maxMemoryInMB: Int = 128) + extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { + + def getImpurity: String = this.impurity + + def setImpurity(impurity: String) = { + if (!ClassificationImpurities.nameToImpurityMap.contains(impurity)) { + throw new IllegalArgumentException(s"Bad impurity parameter for classification: $impurity" + + s" Supported values: ${DTClassifierParams.supportedImpurities.mkString(", ")}.") + } + this.impurity = impurity + } + +} + +@Experimental +object DTClassifierParams { + + /** + * List of supported impurity options. + */ + def supportedImpurities: List[String] = ClassificationImpurities.names + + /** + * Get list of supported quantileStrategy options. 
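The parameter classes use plain string-valued setters that validate against the supported-name maps. A hedged usage sketch, assuming the `DTClassifierParams` and `ClassificationImpurities` classes added in this patch are available on the classpath (everything else here is illustrative):

{% highlight scala %}
import org.apache.spark.mllib.tree.configuration.DTClassifierParams

object ParamsUsageSketch {
  def main(args: Array[String]): Unit = {
    val params = new DTClassifierParams()
    println(params.getImpurity)                      // gini (the default)
    println(DTClassifierParams.supportedImpurities)  // gini and entropy (order may vary)
    params.setImpurity("entropy")                    // accepted
    try {
      params.setImpurity("variance")                 // regression-only name, so this is rejected
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}
{% endhighlight %}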
+ */ + def supportedQuantileStrategies: List[String] = DTParams.supportedQuantileStrategies + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala new file mode 100644 index 0000000000000..7b3ae5897d2ee --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import scala.beans.BeanProperty + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Stores configuration options for DecisionTree construction. + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param maxBins maximum number of bins used for splitting features + * @param quantileStrategy algorithm for calculating quantiles + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. + */ +@Experimental +private[mllib] abstract class DTParams ( + @BeanProperty var maxDepth: Int, + @BeanProperty var maxBins: Int, + var quantileStrategy: String, + @BeanProperty var maxMemoryInMB: Int) extends Serializable { + + def getQuantileStrategy: String = this.quantileStrategy + + def setQuantileStrategy(quantileStrategy: String) = { + if (!QuantileStrategies.nameToStrategyMap.contains(quantileStrategy)) { + throw new IllegalArgumentException(s"Bad quantileStrategy parameter: $quantileStrategy." + + s" Supported values: ${DTParams.supportedQuantileStrategies.mkString(", ")}.") + } + this.quantileStrategy = quantileStrategy + } + +} + + +@Experimental +object DTParams { + + /** + * Get list of supported quantileStrategy options. + */ + def supportedQuantileStrategies: List[String] = QuantileStrategies.nameToStrategyMap.keys.toList + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala new file mode 100644 index 0000000000000..640a6af64b8a0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.impurity.RegressionImpurities + +/** + * :: Experimental :: + * Stores all the configuration options for DecisionTreeRegressor construction + * @param impurity Criterion used for information gain calculation. + * Currently supported: "variance" + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param maxBins maximum number of bins used for splitting features + * @param quantileStrategy algorithm for calculating quantiles + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. + */ +@Experimental +class DTRegressorParams ( + var impurity: String = "variance", + maxDepth: Int = 4, + maxBins: Int = 100, + quantileStrategy: String = "sort", + maxMemoryInMB: Int = 128) + extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { + + def getImpurity: String = this.impurity + + def setImpurity(impurity: String) = { + if (!RegressionImpurities.nameToImpurityMap.contains(impurity)) { + throw new IllegalArgumentException(s"Bad impurity parameter for regression: $impurity" + + s" Supported values: ${DTRegressorParams.supportedImpurities.mkString(", ")}.") + } + this.impurity = impurity + } + +} + +@Experimental +object DTRegressorParams { + + /** + * List of supported impurity options. + */ + def supportedImpurities: List[String] = RegressionImpurities.names + + /** + * Get list of supported quantileStrategy options. + */ + def supportedQuantileStrategies: List[String] = DTParams.supportedQuantileStrategies + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 7da976e55a722..c7f0ae3beea76 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -24,7 +24,36 @@ import org.apache.spark.annotation.Experimental * Enum for selecting the quantile calculation strategy */ @Experimental -object QuantileStrategy extends Enumeration { +private[mllib] object QuantileStrategy extends Enumeration { type QuantileStrategy = Value val Sort, MinMax, ApproxHist = Value } + +/** + * :: Experimental :: + * Factory for creating [[org.apache.spark.mllib.tree.configuration.QuantileStrategy]] instances. + */ +@Experimental +private[mllib] object QuantileStrategies { + + import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ + + /** + * Mapping used for strategy names. + * If you add a new strategy type, add it here. + */ + val nameToStrategyMap: Map[String, QuantileStrategy] = Map( + "sort" -> Sort) + + /** + * Given a string with the name of a quantile strategy, get the QuantileStrategy type. 
+ */ + def strategy(name: String): QuantileStrategy = { + if (nameToStrategyMap.contains(name)) { + nameToStrategyMap(name) + } else { + throw new IllegalArgumentException(s"Bad QuantileStrategy parameter: $name") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala deleted file mode 100644 index 7c027ac2fda6b..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.configuration - -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.impurity.Impurity -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ - -/** - * :: Experimental :: - * Stores all the configuration options for tree construction - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. - * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is - * 128 MB. 
- * - */ -@Experimental -class Strategy ( - val algo: Algo, - val impurity: Impurity, - val maxDepth: Int, - val numClassesForClassification: Int = 2, - val maxBins: Int = 100, - val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128) extends Serializable { - - require(numClassesForClassification >= 2) - val isMulticlassClassification = numClassesForClassification > 2 - val isMulticlassWithCategoricalFeatures - = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala new file mode 100644 index 0000000000000..d415f8b718d56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Factory class for constructing a [[org.apache.spark.mllib.tree.impurity.ClassificationImpurity]] + * type based on its name. + */ +@Experimental +private[mllib] object ClassificationImpurities { + + /** + * Mapping used for impurity names, used for parsing impurity settings. + * If you add a new impurity class, add it here. + */ + val nameToImpurityMap: Map[String, ClassificationImpurity] = Map( + "gini" -> Gini, + "entropy" -> Entropy) + + val names: List[String] = nameToImpurityMap.keys.toList + + /** + * Given impurity name, return type. + */ + def impurity(name: String): ClassificationImpurity = { + if (nameToImpurityMap.contains(name)) { + nameToImpurityMap(name) + } else { + throw new IllegalArgumentException(s"Bad impurity parameter for classification: $name") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala new file mode 100644 index 0000000000000..1658cc38b806b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Experimental} + +/** + * :: Experimental :: + * Trait for calculating information gain for classification. + */ +@Experimental +private[mllib] trait ClassificationImpurity extends Serializable { + + /** + * :: DeveloperApi :: + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value, or 0 if totalCount = 0 + */ + @DeveloperApi + def calculate(counts: Array[Double], totalCount: Double): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index a0e2d91762782..42d6eb4a08a44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -25,19 +25,22 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * binary classification. */ @Experimental -object Entropy extends Impurity { +private[mllib] object Entropy extends ClassificationImpurity { - private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + private def log2(x: Double) = scala.math.log(x) / scala.math.log(2) /** * :: DeveloperApi :: * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 0.0 var classIndex = 0 @@ -52,14 +55,4 @@ object Entropy extends Impurity { impurity } - /** - * :: DeveloperApi :: - * variance calculation - * @param count number of instances - * @param sum sum of labels - * @param sumSquares summation of squares of the labels - */ - @DeveloperApi - override def calculate(count: Double, sum: Double, sumSquares: Double): Double = - throw new UnsupportedOperationException("Entropy.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 48144b5e6d1e4..f91c4c08f4439 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -26,17 +26,20 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * during binary classification. 
*/ @Experimental -object Gini extends Impurity { +private[mllib] object Gini extends ClassificationImpurity { /** * :: DeveloperApi :: * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 1.0 var classIndex = 0 @@ -48,14 +51,4 @@ object Gini extends Impurity { impurity } - /** - * :: DeveloperApi :: - * variance calculation - * @param count number of instances - * @param sum sum of labels - * @param sumSquares summation of squares of the labels - */ - @DeveloperApi - override def calculate(count: Double, sum: Double, sumSquares: Double): Double = - throw new UnsupportedOperationException("Gini.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala new file mode 100644 index 0000000000000..7100026c35e8c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Factory class for constructing a [[org.apache.spark.mllib.tree.impurity.RegressionImpurity]] + * type based on its name. + */ +@Experimental +private[mllib] object RegressionImpurities { + + /** + * Mapping used for impurity names, used for parsing impurity settings. + * If you add a new impurity class, add it here. + */ + val nameToImpurityMap: Map[String, RegressionImpurity] = Map( + "variance" -> Variance) + + val names: List[String] = nameToImpurityMap.keys.toList + + /** + * Given impurity name, return type. 
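+   * @throws IllegalArgumentException if the given impurity name is not supported for regression.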
+ */ + def impurity(name: String): RegressionImpurity = { + if (nameToImpurityMap.contains(name)) { + nameToImpurityMap(name) + } else { + throw new IllegalArgumentException(s"Bad impurity parameter for regression: $name") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala similarity index 76% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala rename to mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala index 7b2a9320cc21d..71f075b02a22b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala @@ -24,17 +24,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * Trait for calculating information gain. */ @Experimental -trait Impurity extends Serializable { - - /** - * :: DeveloperApi :: - * information calculation for multiclass classification - * @param counts Array[Double] with counts for each label - * @param totalCount sum of counts for all labels - * @return information value - */ - @DeveloperApi - def calculate(counts: Array[Double], totalCount: Double): Double +private[mllib] trait RegressionImpurity extends Serializable { /** * :: DeveloperApi :: @@ -42,8 +32,9 @@ trait Impurity extends Serializable { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels - * @return information value + * @return information value, or 0 if count = 0 */ @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 97149a99ead59..0e561f6a681b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -24,28 +24,21 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * Class for calculating variance during regression */ @Experimental -object Variance extends Impurity { +private[mllib] object Variance extends RegressionImpurity { /** * :: DeveloperApi :: - * information calculation for multiclass classification - * @param counts Array[Double] with counts for each label - * @param totalCount sum of counts for all labels - * @return information value - */ - @DeveloperApi - override def calculate(counts: Array[Double], totalCount: Double): Double = - throw new UnsupportedOperationException("Variance.calculate") - - /** - * :: DeveloperApi :: - * variance calculation + * information calculation for regression * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return variance, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + if (count == 0) { + return 0 + } val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala new file mode 100644 index 0000000000000..4dc8661b7f144 --- /dev/null +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.Experimental + + +/** + * :: Experimental :: + * Decision tree model for classification. + * This model stores learned parameters. + * @param topNode root node + */ +@Experimental +class DecisionTreeClassifierModel(topNode: Node) extends DecisionTreeModel(topNode) { + + /** + * Print full model. + */ + override def toString: String = { + s"DecisionTreeClassifierModel\n" + topNode.toStringRecursive(2) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index bf692ca8c4bd7..5083a4df99935 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector @@ -26,10 +25,9 @@ import org.apache.spark.mllib.linalg.Vector * :: Experimental :: * Model to store the decision tree parameters * @param topNode root node - * @param algo algorithm type -- classification or regression */ @Experimental -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { +class DecisionTreeModel(val topNode: Node) extends Serializable { /** * Predict values for a single data point using the model trained. @@ -50,4 +48,20 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } + + /** + * Get number of nodes in tree, including leaf nodes. + */ + def numNodes: Int = { + topNode.numNodesRecursive + } + + /** + * Get depth of tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. 
+ */ + def depth: Int = { + topNode.depthRecursive + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala similarity index 70% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala rename to mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala index 79a01f58319e8..ebe4da5a7a81d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala @@ -15,16 +15,25 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.configuration +package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.Experimental + /** * :: Experimental :: - * Enum to select the algorithm for the decision tree + * Decision tree model for regression. + * This model stores learned parameters. + * @param topNode root node */ @Experimental -object Algo extends Enumeration { - type Algo = Value - val Classification, Regression = Value +class DecisionTreeRegressorModel(topNode: Node) extends DecisionTreeModel(topNode) { + + /** + * Print full model. + */ + override def toString: String = { + s"DecisionTreeRegressorModel\n" + topNode.toStringRecursive(2) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 682f213f411a7..1b4146fcf9459 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -26,7 +26,9 @@ import org.apache.spark.mllib.linalg.Vector * :: DeveloperApi :: * Node in a decision tree * @param id integer node id - * @param predict predicted value at the node + * @param predict Predicted value at the node. + * For classification, this is a class label in 0,1,.... + * For regression, this is a real value. * @param isLeaf whether the leaf is a node * @param split split to calculate left and right nodes * @param leftNode left child @@ -91,4 +93,59 @@ class Node ( } } } + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + def toStringRecursive(indentFactor: Int = 0): String = { + + def splitToString(split: Split, left: Boolean) : String = { + split.featureType match { + case Continuous => if (left) { + s"(feature ${split.feature} <= ${split.threshold})" + } else { + s"(feature ${split.feature} > ${split.threshold})" + } + case Categorical => if (left) { + s"(feature ${split.feature} in ${split.categories})" + } else { + s"(feature ${split.feature} not in ${split.categories})" + } + } + } + val prefix: String = " " * indentFactor + if (isLeaf) { + prefix + s"Predict: $predict\n" + } else { + prefix + s"If ${splitToString(split.get, left=true)}\n" + + leftNode.get.toStringRecursive(indentFactor + 1) + + prefix + s"Else ${splitToString(split.get, left=false)}\n" + + rightNode.get.toStringRecursive(indentFactor + 1) + } + } + + /** + * Get number of nodes in tree from this node, including leaf nodes. + */ + def numNodesRecursive: Int = { + if (isLeaf) { + 1 + } else { + 1 + leftNode.get.numNodesRecursive + rightNode.get.numNodesRecursive + } + } + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. 
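+   *       Depth 1 means this node has one split and two leaf children.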
+ */ + def depthRecursive: Int = { + if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.depthRecursive, rightNode.get.depthRecursive) + } + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5961a618c59d9..33055cac23f75 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,26 +17,57 @@ package org.apache.spark.mllib.tree +import scala.collection.JavaConversions._ + +import org.apache.spark.mllib.rdd.DatasetInfo import org.scalatest.FunSuite -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.configuration.{FeatureType, DTClassifierParams, DTRegressorParams} +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.model.DecisionTreeClassifierModel import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.model.Split -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { + private def getNumFeatures(data: Array[LabeledPoint]) = data.size match { + case 0 => 0 + case _ => data(0).features.size + } + + def validateModel( + model: DecisionTreeClassifierModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map { x => model.predict(x.features) } + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy) + } + + private def defaultClassifierParams: DTClassifierParams = { + new DTClassifierParams("gini", maxDepth = 2, maxBins = 100) + } + + private def defaultRegressorParams: DTRegressorParams = { + new DTRegressorParams("variance", maxDepth = 2, maxBins = 100) + } + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -47,14 +78,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - numClassesForClassification = 2, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) + + val dtLearner = new 
DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -127,14 +157,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - numClassesForClassification = 2, - maxBins = 100, + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr), categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) // Check splits. @@ -244,14 +273,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - numClassesForClassification = 100, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 100, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -338,14 +366,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() assert(arr.length === 3000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - numClassesForClassification = 100, - maxBins = 100, + val datasetInfo = new DatasetInfo( + numClasses = 100, + numFeatures = getNumFeatures(arr), categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) // 2^10 - 1 > 100, so categorical variables will be ordered @@ -393,15 +420,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - numClassesForClassification = 2, - maxDepth = 3, - maxBins = 100, + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr), categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(7), 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -421,14 +447,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = 
sc.parallelize(arr) - val strategy = new Strategy( - Regression, - Variance, - maxDepth = 3, - maxBins = 100, + val datasetInfo = new DatasetInfo( + numClasses = 0, + numFeatures = getNumFeatures(arr), categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + + val dtLearner = new DecisionTreeRegressor(defaultRegressorParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(7), 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -447,8 +473,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -456,7 +487,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(7), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -470,8 +501,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -479,7 +515,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, Array(0.0), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -494,8 +530,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtParams = defaultClassifierParams + dtParams.impurity = "entropy" + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) 
@@ -503,7 +546,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, Array(0.0), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -518,8 +561,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtParams = defaultClassifierParams + dtParams.impurity = "entropy" + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -527,7 +577,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, Array(0.0), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -542,8 +592,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtParams = defaultClassifierParams + dtParams.impurity = "entropy" + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -557,7 +614,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, parentImpurities, 1, filters, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -565,7 +622,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. 
- val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + val bestSplitsWithGroups = dtLearner.findBestSplits(rdd, datasetInfo, parentImpurities, 1, filters, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) @@ -586,12 +643,52 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val rdd = sc.parallelize(arr) + val datasetInfo = new DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.maxDepth = 4 + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1)) + assert(bestSplit.featureType === Categorical) + } + + test("stump with categorical variables for multiclass classification, with just enough bins") { + val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val rdd = sc.parallelize(arr) + val datasetInfo = new DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.maxDepth = 4 + dtParams.maxBins = maxBins + val dtLearner = new DecisionTreeClassifier(dtParams) + + val model = dtLearner.run(rdd, datasetInfo) + validateModel(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -600,18 +697,31 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.categories.length === 1) assert(bestSplit.categories.contains(1)) assert(bestSplit.featureType === Categorical) + val gain = bestSplits(0)._2 + assert(gain.leftImpurity == 0) + assert(gain.rightImpurity == 0) } test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClassesForClassification = 3) - assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val rdd = sc.parallelize(arr) + val datasetInfo = new 
DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.maxDepth = 4 + val dtLearner = new DecisionTreeClassifier(dtParams) + + val model = dtLearner.run(rdd, datasetInfo) + validateModel(model, arr, 0.9) + + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) + assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -624,12 +734,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) - assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val rdd = sc.parallelize(arr) + val datasetInfo = new DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 3)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.maxDepth = 4 + val dtLearner = new DecisionTreeClassifier(dtParams) + + val model = dtLearner.run(rdd, datasetInfo) + validateModel(model, arr, 0.9) + + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -643,12 +763,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) - assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val rdd = sc.parallelize(arr) + val datasetInfo = new DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.maxDepth = 4 + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -696,6 +823,14 @@ object DecisionTreeSuite { arr } + def generateCategoricalDataPointsAsList(): (java.util.List[LabeledPoint], DatasetInfo) = { + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = 2, + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) + (seqAsJavaList(generateCategoricalDataPoints()), datasetInfo) + } + def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { diff --git a/project/MimaExcludes.scala 
b/project/MimaExcludes.scala index 5ff88f0dd1cac..f138335ebb373 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -91,11 +91,21 @@ object MimaExcludes { MimaBuild.excludeSparkClass("storage.Entry") ++ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]( + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.mllib.tree.DecisionTree"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.log2"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.tree.configuration.Algo"), + ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.tree.impurity.Gini.calculate"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") ) case v if v.startsWith("1.0") =>