From 929f0e648962fd0e0529ac2f40452c7302eed733 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Fri, 18 Jul 2014 14:34:59 -0700
Subject: [PATCH 01/20] updating DT API

---
 .../spark/examples/mllib/DTRunnerJKB.scala    | 212 +++++
 .../examples/mllib/DecisionTreeRunner.scala   |  12 +-
 .../spark/mllib/rdd/DatasetMetadata.scala     |  51 ++
 .../spark/mllib/tree/DecisionTree.scala       | 827 +++++++-----------
 .../mllib/tree/DecisionTreeClassifier.scala   | 602 +++++++++++++
 .../mllib/tree/DecisionTreeRegressor.scala    | 278 ++++++
 .../configuration/DTClassifierParams.scala    |  48 +
 .../mllib/tree/configuration/DTParams.scala   |  40 +
 .../configuration/DTRegressorParams.scala     |  48 +
 .../impurity/ClassificationImpurity.scala     |  39 +
 .../spark/mllib/tree/impurity/Entropy.scala   |  32 +-
 .../spark/mllib/tree/impurity/Gini.scala      |  29 +-
 .../spark/mllib/tree/impurity/Impurity.scala  |  10 +-
 .../tree/impurity/RegressionImpurity.scala    |  39 +
 .../spark/mllib/tree/impurity/Variance.scala  |  13 +-
 .../apache/spark/mllib/tree/model/Bin.scala   |   2 +-
 .../model/DecisionTreeClassifierModel.scala   |  42 +
 .../mllib/tree/model/DecisionTreeModel.scala  |   5 +-
 .../model/DecisionTreeRegressorModel.scala    |  39 +
 .../tree/model/InformationGainStats.scala     |   8 +-
 .../apache/spark/mllib/tree/model/Node.scala  |  39 +-
 21 files changed, 1857 insertions(+), 558 deletions(-)
 create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala

diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala
new file mode 100644
index 0000000000000..d38e946aeed98
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree, impurity}
+import org.apache.spark.mllib.tree.configuration.{Algo, DTParams}
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+
+/**
+ * An example runner for decision trees. Run with
+ * {{{
+ * ./bin/spark-example org.apache.spark.examples.mllib.DTRunnerJKB [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DTRunnerJKB {
+
+  object ImpurityType extends Enumeration {
+    type ImpurityType = Value
+    val Gini, Entropy, Variance = Value
+  }
+
+  import ImpurityType._
+
+  case class Params(
+      input: String = null,
+      dataFormat: String = null,
+      algo: Algo = Classification,
+      maxDepth: Int = 5,
+      impurity: ImpurityType = Gini,
+      maxBins: Int = 100,
+      fracTest: Double = 0.2)
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("DTRunnerJKB") {
+      head("DTRunnerJKB: an example decision tree app.")
+      opt[String]("algo")
+        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+        .action((x, c) => c.copy(algo = Algo.withName(x)))
+      opt[String]("impurity")
+        .text(s"impurity type (${ImpurityType.values.mkString(",")}), " +
+          s"default: ${defaultParams.impurity}")
+        .action((x, c) => c.copy(impurity = ImpurityType.withName(x)))
+      opt[Int]("maxDepth")
+        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+        .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("maxBins")
+        .text(s"max number of bins, default: ${defaultParams.maxBins}")
+        .action((x, c) => c.copy(maxBins = x))
+      opt[Double]("fracTest")
+        .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
+        .action((x, c) => c.copy(fracTest = x))
+      arg[String]("<input>")
+        .text("input paths to labeled examples")
+        .required()
+        .action((x, c) => c.copy(input = x))
+      arg[String]("<dataFormat>")
+        .text("data format: dense/libsvm")
+        .required()
+        .action((x, c) => c.copy(dataFormat = x))
+      checkConfig { params =>
+        if (
+          (params.algo == Classification &&
+            !(params.impurity == Gini || params.impurity == Entropy)) ||
+          (params.algo == Regression && !(params.impurity == Variance))) {
+          failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+        } else if (params.fracTest < 0 || params.fracTest > 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+        } else {
+          success
+        }
+      }
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  def run(params: Params) {
+    val conf = new SparkConf().setAppName("DTRunnerJKB")
+    val sc = new SparkContext(conf)
+
+    // Load training data and cache it.
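+    // Input format note (assumed formats): "dense" uses MLUtils.loadLabeledData, which
+    // expects lines of the form "label,f1 f2 f3 ...", while "libsvm" uses the LIBSVM
+    // text format "label i1:v1 i2:v2 ..." with one-based feature indices.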
+    val origExamples = params.dataFormat match {
+      case "dense" => MLUtils.loadLabeledData(sc, params.input).cache()
+      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache()
+    }
+    val (examples, numClasses) = params.algo match {
+      case Classification => {
+        // classCounts: class --> # examples in class
+        val classCounts = origExamples.map(_.label).countByValue
+        val numClasses = classCounts.size
+        // Re-index classes if needed.
+        // classIndex: class --> index in 0,...,numClasses-1
+        val classIndex = {
+          if (classCounts.keySet != Set[Double](0.0, 1.0)) {
+            classCounts.keys.toList.sorted.zipWithIndex.toMap
+          } else {
+            Map[Double, Int]()
+          }
+        }
+        val examples = {
+          if (classIndex.isEmpty) {
+            origExamples
+          } else {
+            origExamples.map(lp => LabeledPoint(classIndex(lp.label), lp.features))
+          }
+        }
+        // Count once, rather than once per class in the loop below.
+        val numExamples = examples.count()
+        println(s"numClasses = $numClasses.")
+        println("Per-class example fractions, counts:")
+        println("Class\tFrac\tCount")
+        classCounts.keys.toList.sorted.foreach { c =>
+          val frac = classCounts(c).toDouble / numExamples
+          println(s"$c\t$frac\t${classCounts(c)}")
+        }
+        (examples, numClasses)
+      }
+      case Regression => {
+        (origExamples, 2)
+      }
+    }
+
+    // Split into training and test sets, holding out fracTest of the data for testing.
+    val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+    val training = splits(0).cache()
+    val test = splits(1).cache()
+    val numTraining = training.count()
+    val numTest = test.count()
+
+    println(s"numTraining = $numTraining, numTest = $numTest.")
+
+    examples.unpersist(blocking = false)
+
+    val impurityCalculator = params.impurity match {
+      case Gini => impurity.Gini
+      case Entropy => impurity.Entropy
+      case Variance => impurity.Variance
+    }
+
+    val strategy = new DTParams(
+      algo = params.algo,
+      impurity = impurityCalculator,
+      maxDepth = params.maxDepth,
+      maxBins = params.maxBins,
+      numClassesForClassification = numClasses)
+    val model = DecisionTree.train(training, strategy)
+    model.print()
+
+    if (params.algo == Classification) {
+      val accuracy = accuracyScore(model, test)
+      println(s"Test accuracy = $accuracy.")
+    }
+
+    if (params.algo == Regression) {
+      val mse = meanSquaredError(model, test)
+      println(s"Test mean squared error = $mse.")
+    }
+
+    sc.stop()
+  }
+
+  /**
+   * Calculates the classifier accuracy.
+   */
+  private def accuracyScore(
+      model: DecisionTreeModel,
+      data: RDD[LabeledPoint]): Double = {
+    val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
+    val count = data.count()
+    correctCount.toDouble / count
+  }
+
+  /**
+   * Calculates the mean squared error for regression.
+ */ + private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = tree.predict(y.features) - y.label + err * err + }.mean() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b3cc361154198..04f5e244052ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -24,7 +24,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree, impurity} -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.configuration.{Algo, DTParams} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils @@ -118,7 +118,17 @@ object DecisionTreeRunner { case Variance => impurity.Variance } +<<<<<<< HEAD val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins) +======= + val strategy + = new DTParams( + algo = params.algo, + impurity = impurityCalculator, + maxDepth = params.maxDepth, + maxBins = params.maxBins, + numClassesForClassification = params.numClassesForClassification) +>>>>>>> 8725f7b... updating DT API, but not done yet val model = DecisionTree.train(training, strategy) if (params.algo == Classification) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala new file mode 100644 index 0000000000000..ffabeb406d8f7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +/** + * :: Experimental :: + * A class for holding dataset metadata. + * @param numClasses Number of classes for classification. Values of 0 or 1 indicate regression. + * @param numFeatures Number of features. + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. 
+ */ +class DatasetMetadata (val numClasses: Int, + val numFeatures: Int, + val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) + extends Serializable { + + /** + * Indicates if this dataset's label is categorical with >2 categories. + */ + def isMulticlass(): Boolean = { + numClasses > 2 + } + + /** + * Indicates if this dataset's label is categorical with >2 categories, + * and there is at least one categorical feature. + */ + def isMulticlassWithCategoricalFeatures(): Boolean = { + isMulticlass() && categoricalFeaturesInfo.nonEmpty + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 74d5d7ba10960..ffef7b7a41423 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,46 +19,63 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Experimental import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.DatasetMetadata import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.DTParams import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom + /** * :: Experimental :: * A class that implements a decision tree algorithm for classification and regression. It * supports both continuous and categorical features. - * @param strategy The configuration parameters for the tree algorithm which specify the type + * @param params The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. */ @Experimental -class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { +private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected val params: DTParams) + extends Serializable with Logging { + + protected final val InvalidBinIndex = -1 /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @return a DecisionTreeModel that can be used for prediction + * @param numClasses number of classes for classification. + * Default value is 2, for binary classification. + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + * @return top node of a DecisionTreeModel */ - def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + protected def trainSub( + input: RDD[LabeledPoint], + numClasses: Int = 2, + categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()): Node = { // Cache input RDD for speedup during multiple passes. input.cache() - logDebug("algo = " + strategy.algo) + + // Collect input metadata. 
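+    // NOTE: assumes a non-empty input RDD; the first example determines numFeatures.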
+ val numFeatures = input.take(1)(0).features.size + val dsMeta = new DatasetMetadata(numClasses, numFeatures, categoricalFeaturesInfo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins) = findSplitsBins(input, dsMeta) val numBins = bins(0).length logDebug("numBins = " + numBins) // depth of the decision tree - val maxDepth = strategy.maxDepth + val maxDepth = params.maxDepth // the max number of nodes possible given the depth of the tree val maxNumNodes = math.pow(2, maxDepth).toInt - 1 // Initialize an array to hold filters applied to points for each node. @@ -69,27 +86,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) - // num features - val numFeatures = input.take(1)(0).features.size // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 + val maxMemoryUsage = params.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = - strategy.algo match { - case Classification => 2 * numBins * numFeatures - case Regression => 3 * numBins * numFeatures - } - + val numElementsPerNode = getElementsPerNode(dsMeta, numBins) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) // nodes at a level is 2^level. level is zero indexed. val maxLevelForSingleGroup = math.max( - (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) + (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, + 0) logDebug("max level for single group = " + maxLevelForSingleGroup) /* @@ -109,10 +120,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, + val splitsStatsForLevel = findBestSplits(input, dsMeta, parentImpurities, level, filters, splits, bins, maxLevelForSingleGroup) - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) // Extract info for nodes at the next lower level. @@ -140,9 +151,86 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) - new DecisionTreeModel(topNode, strategy.algo) + topNode } + //=========================================================================== + // Protected abstract methods + //=========================================================================== + + /** + * For a given categorical feature, use a subsample of the data + * to choose how to arrange possible splits. + * This examines each category and computes a centroid. 
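+   * For example, with binary labels a category's centroid could be the mean label of the
+   * samples in that category; the exact statistic is defined by each concrete subclass.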
+   * These centroids are later used to sort the possible splits.
+   * @return Mapping: category (for the given feature) --> centroid
+   */
+  protected def computeCentroidForCategories(
+      featureIndex: Int,
+      sampledInput: Array[LabeledPoint],
+      dsMeta: DatasetMetadata): Map[Double, Double]
+
+  /**
+   * Extracts left and right split aggregates.
+   * @param binData bin aggregates for one node: Array[Double] of size
+   *                getElementsPerNode(dsMeta, numBins)
+   * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\],
+   *         Array[Array[Array[Double\]\]\]) where each array is of size
+   *         (numFeatures, numBins - 1, numClasses)
+   */
+  protected def extractLeftRightNodeAggregates(
+      binData: Array[Double],
+      dsMeta: DatasetMetadata,
+      numBins: Int): (Array[Array[Array[Double]]], Array[Array[Array[Double]]])
+
+  /**
+   * Get the number of stats elements stored per node in bin aggregates.
+   */
+  protected def getElementsPerNode(
+      dsMeta: DatasetMetadata,
+      numBins: Int): Int
+
+  /**
+   * Performs a sequential aggregation of bin stats over a partition.
+   */
+  protected def binSeqOpSub(
+      agg: Array[Double],
+      arr: Array[Double],
+      dsMeta: DatasetMetadata,
+      numNodes: Int,
+      bins: Array[Array[Bin]]): Array[Double]
+
+  /**
+   * Calculates the information gain for all splits based upon left/right split aggregates.
+   * @param leftNodeAgg left node aggregates
+   * @param featureIndex feature index
+   * @param splitIndex split index
+   * @param rightNodeAgg right node aggregate
+   * @param topImpurity impurity of the parent node
+   * @return information gain and statistics for all splits
+   */
+  protected def calculateGainForSplit(
+      leftNodeAgg: Array[Array[Array[Double]]],
+      featureIndex: Int,
+      splitIndex: Int,
+      rightNodeAgg: Array[Array[Array[Double]]],
+      topImpurity: Double,
+      numClasses: Int,
+      level: Int): InformationGainStats
+
+  /**
+   * Get bin data for one node.
+   */
+  protected def getBinDataForNode(
+      node: Int,
+      binAggregates: Array[Double],
+      dsMeta: DatasetMetadata,
+      numNodes: Int,
+      numBins: Int): Array[Double]
+
+  //===========================================================================
+  // Protected (non-abstract) methods
+  //===========================================================================
+
   /**
    * Extract the decision tree node information for the given tree level and node index
    */
@@ -154,7 +242,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     val split = nodeSplitStats._1
     val stats = nodeSplitStats._2
     val nodeIndex = math.pow(2, level).toInt - 1 + index
-    val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
+    val isLeaf = (stats.gain <= 0) || (level == params.maxDepth - 1)
     val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
     logDebug("Node = " + node)
     nodes(nodeIndex) = node
@@ -194,85 +282,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       i += 1
     }
   }
-}
-
-object DecisionTree extends Serializable with Logging {
-
-  /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The method supports binary classification and regression. For the
-   * binary classification, the label for each instance should either be 0 or 1 to denote the two
-   * classes. The parameters for the algorithm are specified using the strategy parameter.
- * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree - * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, - * categorical), depth of the tree, quantile calculation strategy, etc. - * @return a DecisionTreeModel that can be used for prediction - */ - def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) - } - - /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. - * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @return a DecisionTreeModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int): DecisionTreeModel = { - val strategy = new Strategy(algo,impurity,maxDepth) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) - } - - - /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The decision tree method supports binary classification and - * regression. For the binary classification, the label for each instance should either be 0 or - * 1 to denote the two classes. The method also supports categorical features inputs where the - * number of categories can specified using the categoricalFeaturesInfo option. - * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data for DecisionTree - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @return a DecisionTreeModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int, - maxBins: Int, - quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, - categoricalFeaturesInfo) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) - } - - private val InvalidBinIndex = -1 /** * Returns an array of optimal splits for all nodes at a given level. Splits the task into @@ -280,9 +289,8 @@ object DecisionTree extends Serializable with Logging { * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree + * @param dsMeta Metadata for input. 
* @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -290,46 +298,46 @@ object DecisionTree extends Serializable with Logging { * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @return array of splits with best splits for all nodes at a given level. */ - protected[tree] def findBestSplits( + private def findBestSplits( input: RDD[LabeledPoint], + dsMeta: DatasetMetadata, parentImpurities: Array[Double], - strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, + val bestSplitsForGroup = findBestSplitsPerGroup(input, dsMeta, parentImpurities, level, filters, splits, bins, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + findBestSplitsPerGroup(input, dsMeta, parentImpurities, level, filters, splits, bins) } } - /** + /** * Returns an array of optimal splits for a group of nodes at a given level * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree + * @param dsMeta Metadata for input. * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -340,8 +348,8 @@ object DecisionTree extends Serializable with Logging { */ private def findBestSplitsPerGroup( input: RDD[LabeledPoint], + dsMeta: DatasetMetadata, parentImpurities: Array[Double], - strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], @@ -377,10 +385,16 @@ object DecisionTree extends Serializable with Logging { val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. 
-    val numFeatures = input.first().features.size
-    logDebug("numFeatures = " + numFeatures)
+    //val numFeatures = input.first().features.size
+    //logDebug("numFeatures = " + numFeatures)
     val numBins = bins(0).length
     logDebug("numBins = " + numBins)
+    //val numClasses = dsMeta.numClasses
+    //logDebug("numClasses = " + numClasses)
+    val isMulticlass = dsMeta.isMulticlass()
+    logDebug("isMulticlass = " + isMulticlass)
+    val isMulticlassWithCategoricalFeatures = dsMeta.isMulticlassWithCategoricalFeatures()
+    logDebug("isMulticlassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures)
 
     // shift when more than one group is used at deep tree level
     val groupShift = numNodes * groupIndex
@@ -439,7 +453,8 @@
     def findBin(
         featureIndex: Int,
         labeledPoint: LabeledPoint,
-        isFeatureContinuous: Boolean): Int = {
+        isFeatureContinuous: Boolean,
+        isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
       val binForFeatures = bins(featureIndex)
       val feature = labeledPoint.features(featureIndex)
 
@@ -467,17 +482,28 @@
         -1
       }
 
+      /**
+       * Sequential search helper method to find bin for categorical feature in multiclass
+       * classification. The category is returned since each category can belong to multiple
+       * splits. The actual left/right child allocation per split is performed in the
+       * sequential phase of the bin aggregate operation.
+       */
+      def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
+        labeledPoint.features(featureIndex).toInt
+      }
+
       /**
        * Sequential search helper method to find bin for categorical feature.
        */
-      def sequentialBinSearchForCategoricalFeature(): Int = {
-        val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
+      def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
+        val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex)
+        val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
         var binIndex = 0
         while (binIndex < numCategoricalBins) {
           val bin = bins(featureIndex)(binIndex)
-          val category = bin.category
+          val categories = bin.highSplit.categories
           val features = labeledPoint.features
-          if (category == features(featureIndex)) {
+          if (categories.contains(features(featureIndex))) {
             return binIndex
           }
           binIndex += 1
@@ -494,7 +520,13 @@
         binIndex
       } else {
         // Perform sequential search to find bin for categorical features.
-        val binIndex = sequentialBinSearchForCategoricalFeature()
+        val binIndex = {
+          if (isMulticlass && isSpaceSufficientForAllCategoricalSplits) {
+            sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
+          } else {
+            sequentialBinSearchForOrderedCategoricalFeatureInClassification()
+          }
+        }
         if (binIndex == -1){
           throw new UnknownError("no bin was found for categorical variable.")
         }
@@ -506,131 +538,59 @@
      * Finds bins for all nodes (and all features) at a given level.
      * For l nodes, k features the storage is as follows:
      * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
-     * where b_ij is an integer between 0 and numBins - 1.
+     * where b_ij is an integer between 0 and numBins - 1 for regression and binary
+     * classification, or the categorical feature value for multiclass classification.
     * Invalid sample is denoted by noting bin for feature 1 as -1.
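+     * For example (illustrative), with 2 nodes and 2 features, the array for one sample
+     * has the form Array(label, b_11, b_12, b_21, b_22).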
*/ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. - val arr = new Array[Double](1 + (numFeatures * numNodes)) + val arr = new Array[Double](1 + (dsMeta.numFeatures * numNodes)) + // First element of the array is the label of the instance. arr(0) = labeledPoint.label + // Iterate over nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node. val sampleValid = isSampleValid(parentFilters, labeledPoint) - val shift = 1 + numFeatures * nodeIndex + val shift = 1 + dsMeta.numFeatures * nodeIndex if (!sampleValid) { // Mark one bin as -1 is sufficient. arr(shift) = InvalidBinIndex } else { var featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) - featureIndex += 1 - } - } - nodeIndex += 1 - } - arr - } - - /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. - * - * @param agg Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures * numNodes for classification - */ - def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update the left or right count for one bin. - val aggShift = 2 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 - label match { - case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 - case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + while (featureIndex < dsMeta.numFeatures) { + val featureInfo = dsMeta.categoricalFeaturesInfo.get(featureIndex) + val isFeatureContinuous = featureInfo.isEmpty + if (isFeatureContinuous) { + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) + } else { + val featureCategories = featureInfo.get + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, + isSpaceSufficientForAllCategoricalSplits) } featureIndex += 1 } } nodeIndex += 1 } + arr } - /** - * Performs a sequential aggregation over a partition for regression. For l nodes, k features, - * the count, sum, sum of squares of one of the p bins is incremented. 
- * - * @param agg Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for regression - */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update count, sum, and sum^2 for one bin. - val aggShift = 3 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 - agg(aggIndex) = agg(aggIndex) + 1 - agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label*label - featureIndex += 1 - } - } - nodeIndex += 1 - } - } + // Find feature bins for all nodes at a level. + val binMappedRDD = input.map(x => findBinsForLevel(x)) - /** - * Performs a sequential aggregation over a partition. - */ + // Performs a sequential aggregation over a partition. def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { - strategy.algo match { - case Classification => classificationBinSeqOp(arr, agg) - case Regression => regressionBinSeqOp(arr, agg) - } - agg + binSeqOpSub(agg, arr, dsMeta, numNodes, bins) } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = strategy.algo match { - case Classification => 2 * numBins * numFeatures * numNodes - case Regression => 3 * numBins * numFeatures * numNodes - } + val binAggregateLength = numNodes * getElementsPerNode(dsMeta, numBins) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -649,242 +609,25 @@ object DecisionTree extends Serializable with Logging { combinedAggregate } - // Find feature bins for all nodes at a level. - val binMappedRDD = input.map(x => findBinsForLevel(x)) - // Calculate bin aggregates. val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) } logDebug("binAggregates.length = " + binAggregates.length) - /** - * Calculates the information gain for all splits based upon left/right split aggregates. 
- * @param leftNodeAgg left node aggregates - * @param featureIndex feature index - * @param splitIndex split index - * @param rightNodeAgg right node aggregate - * @param topImpurity impurity of the parent node - * @return information gain and statistics for all splits - */ - def calculateGainForSplit( - leftNodeAgg: Array[Array[Double]], - featureIndex: Int, - splitIndex: Int, - rightNodeAgg: Array[Array[Double]], - topImpurity: Double): InformationGainStats = { - strategy.algo match { - case Classification => - val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) - val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) - val leftCount = left0Count + left1Count - - val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) - val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) - val rightCount = right0Count + right1Count - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) - } - } - - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) - } - - val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) - - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - - val predict = (left1Count + right1Count) / (leftCount + rightCount) - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) - case Regression => - val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) - val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) - val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) - - val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex) - val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) - val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. 
- val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - strategy.impurity.calculate(count, sum, sumSquares) - } - } - - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity ,topImpurity, - Double.MinValue, leftSum / leftCount) - } - - val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) - - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) - } - } - - /** - * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], - * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) - */ - def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { - strategy.algo match { - case Classification => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 2 * featureIndex * numBins - - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(2 * (numBins - 2)) - = binData(shift + (2 * (numBins - 1))) - rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) - = binData(shift + (2 * (numBins - 1)) + 1) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2) - leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = - binData(shift + (2 *(numBins - 1 - splitIndex))) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = - binData(shift + (2* (numBins - 1 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) - - splitIndex += 1 - } - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - case Regression => - // Initialize left and right split aggregates. 
- val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 3 * featureIndex * numBins - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(2) = binData(shift + 2) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(3 * (numBins - 2)) = - binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = - binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = - binData(shift + (3 * (numBins - 1)) + 2) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3) - leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) - leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = - binData(shift + (3 * (numBins - 1 - splitIndex))) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = - binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = - binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) - - splitIndex += 1 - } - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - } - } - /** * Calculates information gain for all nodes splits. */ def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[Double]], - rightNodeAgg: Array[Array[Double]], + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], nodeImpurity: Double): Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) + val gains = Array.ofDim[InformationGainStats](dsMeta.numFeatures, numBins - 1) - for (featureIndex <- 0 until numFeatures) { + for (featureIndex <- 0 until dsMeta.numFeatures) { for (splitIndex <- 0 until numBins - 1) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, - splitIndex, rightNodeAgg, nodeImpurity) + splitIndex, rightNodeAgg, nodeImpurity, dsMeta.numClasses, level) } } gains @@ -903,7 +646,7 @@ object DecisionTree extends Serializable with Logging { logDebug("node impurity = " + nodeImpurity) // Extract left right node aggregates. - val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData, dsMeta, numBins) // Calculate gains for all splits. 
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
@@ -915,10 +658,25 @@
       var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0)
       // Iterate over features.
       var featureIndex = 0
-      while (featureIndex < numFeatures) {
+      while (featureIndex < dsMeta.numFeatures) {
         // Iterate over all splits.
         var splitIndex = 0
-        while (splitIndex < numBins - 1) {
+        val maxSplitIndex: Int = {
+          val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          if (isFeatureContinuous) {
+            numBins - 1
+          } else { // Categorical feature
+            val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex)
+            val isSpaceSufficientForAllCategoricalSplits =
+              numBins > math.pow(2, featureCategories.toInt - 1) - 1
+            if (isMulticlass && isSpaceSufficientForAllCategoricalSplits) {
+              math.pow(2.0, featureCategories - 1).toInt - 1
+            } else { // Ordered categories (binary classification or regression)
+              featureCategories
+            }
+          }
+        }
+        while (splitIndex < maxSplitIndex) {
           val gainStats = gains(featureIndex)(splitIndex)
           if (gainStats.gain > bestGainStats.gain) {
             bestGainStats = gainStats
@@ -938,36 +696,20 @@
       (splits(bestFeatureIndex)(bestSplitIndex), gainStats)
     }
 
-    /**
-     * Get bin data for one node.
-     */
-    def getBinDataForNode(node: Int): Array[Double] = {
-      strategy.algo match {
-        case Classification =>
-          val shift = 2 * node * numBins * numFeatures
-          val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
-          binsForNode
-        case Regression =>
-          val shift = 3 * node * numBins * numFeatures
-          val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
-          binsForNode
-      }
-    }
-
     // Calculate best splits for all nodes at a given level
     val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
     // Iterating over all nodes at this level
    var node = 0
     while (node < numNodes) {
       val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift
-      val binsForNode: Array[Double] = getBinDataForNode(node)
+      val binsForNode: Array[Double] =
+        getBinDataForNode(node, binAggregates, dsMeta, numNodes, numBins)
       logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
       val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
-      logDebug("node impurity = " + parentNodeImpurity)
+      logDebug("parent node impurity = " + parentNodeImpurity)
       bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
       node += 1
     }
-
    bestSplits
   }

@@ -975,34 +717,38 @@
   * Returns split and bins for decision tree calculation.
   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
   *              for DecisionTree
-   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
-   *                 parameters for construction the DecisionTree
   * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree
   *         .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache
   *         .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
   */
  protected[tree] def findSplitsBins(
      input: RDD[LabeledPoint],
-      strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
+      dsMeta: DatasetMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
+
    val count = input.count()
 
-    // Find the number of features by looking at the first sample
-    val numFeatures = input.take(1)(0).features.size
+    val numFeatures = dsMeta.numFeatures
 
-    val maxBins = strategy.maxBins
+    val maxBins = params.maxBins
    val numBins = if (maxBins <= count) maxBins else count.toInt
    logDebug("numBins = " + numBins)
+    val isMulticlass = dsMeta.isMulticlass()
+    logDebug("isMulticlass = " + isMulticlass)
+
    /*
-     * TODO: Add a require statement ensuring #bins is always greater than the categories.
+     * Ensure #bins is always greater than the number of categories. For multiclass
+     * classification, #bins should be greater than 2^(maxCategories - 1) - 1.
     * It's a limitation of the current implementation but a reasonable trade-off since features
     * with large number of categories get favored over continuous features.
     */
-    if (strategy.categoricalFeaturesInfo.size > 0) {
-      val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
-      require(numBins >= maxCategoriesForFeatures)
+    if (dsMeta.categoricalFeaturesInfo.size > 0) {
+      val maxCategoriesForFeatures = dsMeta.categoricalFeaturesInfo.maxBy(_._2)._2
+      require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
+        "in categorical features")
    }
+
    // Calculate the number of sample for approximate quantile calculation.
    val requiredSamples = numBins*numBins
    val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
@@ -1015,7 +761,7 @@
    val stride: Double = numSamples.toDouble / numBins
    logDebug("stride = " + stride)
 
-    strategy.quantileCalculationStrategy match {
+    params.quantileStrategy match {
      case Sort =>
        val splits = Array.ofDim[Split](numFeatures, numBins - 1)
        val bins = Array.ofDim[Bin](numFeatures, numBins)
@@ -1026,7 +772,7 @@
        var featureIndex = 0
        while (featureIndex < numFeatures){
          // Check whether the feature is continuous.
-          val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty
          if (isFeatureContinuous) {
            val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
            val stride: Double = numSamples.toDouble / numBins
@@ -1036,48 +782,76 @@
            val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
            splits(featureIndex)(index) = split
          }
-          } else {
-            val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
-            require(maxFeatureValue < numBins, "number of categories should be less than number " +
-              "of bins")
-
-            // For categorical variables, each bin is a category. The bins are sorted and they
-            // are ordered by calculating the centroid of their corresponding labels.
-            val centroidForCategories =
-              sampledInput.map(lp => (lp.features(featureIndex),lp.label))
-                .groupBy(_._1)
-                .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
-
-            // Check for missing categorical variables and putting them last in the sorted list.
-            val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
-            for (i <- 0 until maxFeatureValue) {
-              if (centroidForCategories.contains(i)) {
-                fullCentroidForCategories(i) = centroidForCategories(i)
-              } else {
-                fullCentroidForCategories(i) = Double.MaxValue
-              }
-            }
-
-            // bins sorted by centroids
-            val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
-
-            logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)
-
-            var categoriesForSplit = List[Double]()
-            categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
-              case ((key, value), index) =>
-                categoriesForSplit = key :: categoriesForSplit
-                splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
-                  categoriesForSplit)
+          } else { // Categorical feature
+            val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex)
+            val isSpaceSufficientForAllCategoricalSplits =
+              numBins > math.pow(2, featureCategories.toInt - 1) - 1
+
+            // Use a different bin/split calculation strategy for categorical features in
+            // multiclass classification that satisfy the space constraint.
+            if (isMulticlass && isSpaceSufficientForAllCategoricalSplits) {
+              // 2^(featureCategories - 1) - 1 combinations
+              var index = 0
+              while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
+                val categories: List[Double] =
+                  DecisionTree.extractMultiClassCategories(index + 1, featureCategories)
+                splits(featureIndex)(index) =
+                  new Split(featureIndex, Double.MinValue, Categorical, categories)
                bins(featureIndex)(index) = {
                  if (index == 0) {
-                    new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
-                      splits(featureIndex)(0), Categorical, key)
+                    new Bin(
+                      new DummyCategoricalSplit(featureIndex, Categorical),
+                      splits(featureIndex)(0),
+                      Categorical,
+                      Double.MinValue)
                  } else {
-                    new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
-                      Categorical, key)
+                    new Bin(
+                      splits(featureIndex)(index - 1),
+                      splits(featureIndex)(index),
+                      Categorical,
+                      Double.MinValue)
                  }
                }
+                index += 1
+              }
+            } else {
+
+              val centroidForCategories =
+                computeCentroidForCategories(featureIndex, sampledInput, dsMeta)
+
+              logDebug("centroid for categories = " + centroidForCategories.mkString(","))
+
+              // Check for missing categorical variables and put them last in the sorted list.
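+              // For example (illustrative): with categories {0, 1, 2} and computed
+              // centroids {0 -> 0.2, 2 -> 0.8}, the missing category 1 is assigned
+              // Double.MaxValue, giving the sorted category order (0, 2, 1).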
+ val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() + for (i <- 0 until featureCategories) { + if (centroidForCategories.contains(i)) { + fullCentroidForCategories(i) = centroidForCategories(i) + } else { + fullCentroidForCategories(i) = Double.MaxValue + } + } + + // bins sorted by centroids + val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) + + logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) + + var categoriesForSplit = List[Double]() + categoriesSortedByCentroid.iterator.zipWithIndex.foreach { + case ((key, value), index) => + categoriesForSplit = key :: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, + Categorical, categoriesForSplit) + bins(featureIndex)(index) = { + if (index == 0) { + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) + } else { + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) + } + } + } } } featureIndex += 1 @@ -1086,7 +860,7 @@ object DecisionTree extends Serializable with Logging { // Find all bins. featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) @@ -1107,4 +881,33 @@ object DecisionTree extends Serializable with Logging { throw new UnsupportedOperationException("approximate histogram not supported yet.") } } + +} + +object DecisionTree extends Serializable with Logging { + + /** + * Extract list of eligible categories given an index. + * It extracts the position of ones in a binary representation of the input. + * If binary representation of an number is 01101 (13), the output list should (3.0, 2.0, + * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones. + */ + protected[tree] def extractMultiClassCategories( + input: Int, + maxFeatureValue: Int): List[Double] = { + var categories = List[Double]() + var j = 0 + var bitShiftedInput = input + while (j < maxFeatureValue) { + if (bitShiftedInput % 2 != 0) { + // updating the list of categories. + categories = j.toDouble :: categories + } + // Right shift by one + bitShiftedInput = bitShiftedInput >> 1 + j += 1 + } + categories + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala new file mode 100644 index 0000000000000..7503371eede96 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala @@ -0,0 +1,602 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.annotation.Experimental +import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.DatasetMetadata +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.DTClassifierParams +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.impurity.{ClassificationImpurity, Entropy, Gini} +import org.apache.spark.mllib.tree.model.{InformationGainStats, Bin, DecisionTreeClassifierModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + + +/** + * :: Experimental :: + * A class that implements a decision tree algorithm for classification. + * It supports both continuous and categorical features. + * @param params The configuration parameters for the tree algorithm. + */ +@Experimental +class DecisionTreeClassifier (params: DTClassifierParams) + extends DecisionTree[DecisionTreeClassifierModel](params) { + + private val impurityFunctor : ClassificationImpurity = params.impurity match { + case "gini" => Gini + case "entropy" => Entropy + case _ => throw new IllegalArgumentException(s"Bad impurity parameter for classification: ${params.impurity}") + } + + /** + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * @param numClasses number of classes for classification. + * Default value is 2, for binary classification. + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + * @return a DecisionTreeClassifierModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + numClasses: Int = 2, + categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()): DecisionTreeClassifierModel = { + + logDebug("algo = Classification") + + val numClasses: Int = 2 + require(numClasses >= 2) + val isMulticlassClassification = numClasses > 2 + val isMulticlassWithCategoricalFeatures = + isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + + val topNode = super.trainSub(input, numClasses, categoricalFeaturesInfo) + + new DecisionTreeClassifierModel(topNode) + } + + //=========================================================================== + // Protected methods (abstract from DecisionTree) + //=========================================================================== + + protected def computeCentroidForCategories( + featureIndex: Int, + sampledInput: Array[LabeledPoint], + dsMeta: DatasetMetadata): Map[Double,Double] = { + if (dsMeta.isMulticlass()) { + // For categorical variables in multiclass classification, + // each bin is a category. 
The bins are sorted and they + // are ordered by calculating the impurity of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) + .map(x => (x._1, x._2.values.toArray)) + .map(x => (x._1, impurityFunctor.calculate(x._2,x._2.sum))) + } else { // binary classification + // For categorical variables in binary classification, + // each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + } + } + + /** + * Extracts left and right split aggregates. + * @param binData Array[Double] of size 2*numFeatures*numSplits + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], + * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, + * (numBins - 1), numClasses) + */ + protected def extractLeftRightNodeAggregates( + binData: Array[Double], + dsMeta: DatasetMetadata, + numBins: Int): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + + def findAggForOrderedFeatureClassification( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val shift = dsMeta.numClasses * featureIndex * numBins + + var classIndex = 0 + while (classIndex < dsMeta.numClasses) { + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(classIndex) + = binData(shift + (dsMeta.numClasses * (numBins - 1)) + classIndex) + classIndex += 1 + } + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + var innerClassIndex = 0 + while (innerClassIndex < dsMeta.numClasses) { + leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) + = binData(shift + dsMeta.numClasses * splitIndex + innerClassIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = + binData(shift + (dsMeta.numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) + innerClassIndex += 1 + } + splitIndex += 1 + } + } + + def findAggForUnorderedFeatureClassification( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + val rightChildShift = dsMeta.numClasses * numBins * dsMeta.numFeatures + var splitIndex = 0 + while (splitIndex < numBins - 1) { + var classIndex = 0 + while (classIndex < dsMeta.numClasses) { + // shift for this featureIndex + val shift = + dsMeta.numClasses * featureIndex * numBins + splitIndex * dsMeta.numClasses + val leftBinValue = binData(shift + classIndex) + val rightBinValue = binData(rightChildShift + shift + classIndex) + leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue + rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue + classIndex += 1 + } + splitIndex += 1 + } + } + + // Initialize left and right split aggregates. 
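    // Toy example of the binary-classification centroid ordering computed above (assumed
    // data): category 0.0 carries labels (0, 0, 1) and category 1.0 carries labels (1, 1),
    // giving centroids 1/3 and 1.0, so category 0.0 sorts ahead of category 1.0.
    {
      val samples = Seq((0.0, 0.0), (0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 1.0))
      val centroids = samples.groupBy(_._1).mapValues(xs => xs.map(_._2).sum / xs.length)
      // centroids == Map(0.0 -> 0.333..., 1.0 -> 1.0)
    }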
+ val leftNodeAgg = + Array.ofDim[Double](dsMeta.numFeatures, numBins - 1, dsMeta.numClasses) + val rightNodeAgg = + Array.ofDim[Double](dsMeta.numFeatures, numBins - 1, dsMeta.numClasses) + var featureIndex = 0 + while (featureIndex < dsMeta.numFeatures) { + if (dsMeta.isMulticlassWithCategoricalFeatures()){ + val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } + } + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } + featureIndex += 1 + } + + (leftNodeAgg, rightNodeAgg) + } + + protected def getElementsPerNode( + dsMeta: DatasetMetadata, + numBins: Int): Int = { + if (dsMeta.isMulticlassWithCategoricalFeatures()) { + 2 * dsMeta.numClasses * numBins * dsMeta.numFeatures + } else { + dsMeta.numClasses * numBins * dsMeta.numFeatures + } + } + + /** + * Performs a sequential aggregation over a partition for classification. + * For l nodes, k features, + * either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * @param agg Array[Double] storing aggregate calculation of size + * numClasses * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation, of size: + * 2 * numSplits * numFeatures * numNodes for ordered features, or + * 2 * numClasses * numSplits * numFeatures * numNodes for unordered features + */ + protected def binSeqOpSub( + agg: Array[Double], + arr: Array[Double], + dsMeta: DatasetMetadata, + numNodes: Int, + bins: Array[Array[Bin]]): Array[Double] = { + val numBins = bins(0).length + if(dsMeta.isMulticlassWithCategoricalFeatures()) { + unorderedClassificationBinSeqOp(arr, agg, dsMeta, numNodes, bins) + } else { + orderedClassificationBinSeqOp(arr, agg, dsMeta, numNodes, numBins) + } + agg + } + + /** + * Calculates the information gain for all splits based upon left/right split aggregates. 
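   * For intuition, a worked toy example (assumed counts, Gini impurity):
   * {{{
   *   left counts = (4, 0), right counts = (2, 4)
   *   parent impurity = 1 - 0.6^2 - 0.4^2 = 0.48
   *   gain = 0.48 - 0.4 * 0.0 - 0.6 * 0.4444 = 0.2133
   *   predict = argmax(4 + 2, 0 + 4) = class 0, prob = 6 / 10 = 0.6
   * }}}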
+ * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + protected def calculateGainForSplit( + leftNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Array[Double]]], + topImpurity: Double, + numClasses: Int, + level: Int): InformationGainStats = { + + var classIndex = 0 + val leftCounts: Array[Double] = new Array[Double](numClasses) + val rightCounts: Array[Double] = new Array[Double](numClasses) + var leftTotalCount = 0.0 + var rightTotalCount = 0.0 + while (classIndex < numClasses) { + val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) + val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) + leftCounts(classIndex) = leftClassCount + leftTotalCount += leftClassCount + rightCounts(classIndex) = rightClassCount + rightTotalCount += rightClassCount + classIndex += 1 + } + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val rootNodeCounts = new Array[Double](numClasses) + var classIndex = 0 + while (classIndex < numClasses) { + rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) + classIndex += 1 + } + impurityFunctor.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) + } + } + + if (leftTotalCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) + } + if (rightTotalCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) + } + + val leftImpurity = impurityFunctor.calculate(leftCounts, leftTotalCount) + val rightImpurity = impurityFunctor.calculate(rightCounts, rightTotalCount) + + val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) + val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val totalCount = leftTotalCount + rightTotalCount + + // Sum of count for each label + val leftRightCounts: Array[Double] + = leftCounts.zip(rightCounts) + .map{case (leftCount, rightCount) => leftCount + rightCount} + + def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) + else (maxIndex, maxValue, currentIndex + 1) + } + if (result._1 < 0) 0 else result._1 + } + + val predict = indexOfLargestArrayElement(leftRightCounts) + val prob = leftRightCounts(predict) / totalCount + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + } + + /** + * Get bin data for one node. 
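   * A worked size check (illustrative shapes, assumed): with numClasses = 3, numBins = 32
   * and numFeatures = 10, each node owns 3 * 32 * 10 = 960 consecutive doubles, so node 1
   * is extracted as binAggregates.slice(960, 1920); in the multiclass-with-categorical
   * case a matching right-child block is appended from rightChildShift onward, consistent
   * with getElementsPerNode above.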
+ */ + protected def getBinDataForNode( + node: Int, + binAggregates: Array[Double], + dsMeta: DatasetMetadata, + numNodes: Int, + numBins: Int): Array[Double] = { + if (dsMeta.isMulticlassWithCategoricalFeatures()) { + val shift = dsMeta.numClasses * node * numBins * dsMeta.numFeatures + val rightChildShift = dsMeta.numClasses * numBins * dsMeta.numFeatures * numNodes + val binsForNode = { + val leftChildData + = binAggregates.slice(shift, shift + dsMeta.numClasses * numBins * dsMeta.numFeatures) + val rightChildData + = binAggregates.slice(rightChildShift + shift, + rightChildShift + shift + dsMeta.numClasses * numBins * dsMeta.numFeatures) + leftChildData ++ rightChildData + } + binsForNode + } else { + val shift = dsMeta.numClasses * node * numBins * dsMeta.numFeatures + val binsForNode = binAggregates.slice(shift, shift + dsMeta.numClasses * numBins * dsMeta.numFeatures) + binsForNode + } + } + + //=========================================================================== + // Private methods + //=========================================================================== + + private def updateBinForOrderedFeature( + arr: Array[Double], + agg: Array[Double], + nodeIndex: Int, + label: Double, + featureIndex: Int, + dsMeta: DatasetMetadata, + numBins: Int) = { + + // Find the bin index for this feature. + val arrShift = 1 + dsMeta.numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = dsMeta.numClasses * numBins * dsMeta.numFeatures * nodeIndex + val aggIndex + = aggShift + dsMeta.numClasses * featureIndex * numBins + + arr(arrIndex).toInt * dsMeta.numClasses + val labelInt = label.toInt + agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 + } + + private def updateBinForUnorderedFeature( + nodeIndex: Int, + featureIndex: Int, + arr: Array[Double], + label: Double, + agg: Array[Double], + rightChildShift: Int, + dsMeta: DatasetMetadata, + numBins: Int, + bins: Array[Array[Bin]]) = { + + // Find the bin index for this feature. + val arrShift = 1 + dsMeta.numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = dsMeta.numClasses * numBins * dsMeta.numFeatures * nodeIndex + val aggIndex + = aggShift + dsMeta.numClasses * featureIndex * numBins + arr(arrIndex).toInt * dsMeta.numClasses + // Find all matching bins and increment their values + val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + var binIndex = 0 + while (binIndex < numCategoricalBins) { + val labelInt = label.toInt + if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { + agg(aggIndex + binIndex) + = agg(aggIndex + binIndex) + 1 + } else { + agg(rightChildShift + aggIndex + binIndex) + = agg(rightChildShift + aggIndex + binIndex) + 1 + } + binIndex += 1 + } + } + + /** + * Helper for binSeqOp + */ + private def orderedClassificationBinSeqOp( + arr: Array[Double], + agg: Array[Double], + dsMeta: DatasetMetadata, + numNodes: Int, + numBins: Int) = { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + dsMeta.numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. 
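      // Index arithmetic sketch for updateBinForOrderedFeature above (assumed shapes:
      // numClasses = 3, numBins = 32, numFeatures = 10; nodeIndex = 1, featureIndex = 2,
      // and the sample fell into bin 5 of that feature):
      //   aggShift = 3 * 32 * 10 * 1          = 960
      //   aggIndex = 960 + 3 * 2 * 32 + 5 * 3 = 1167
      // so agg(1167 + label.toInt) is the counter that gets incremented.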
+ var featureIndex = 0 + while (featureIndex < dsMeta.numFeatures) { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, dsMeta, numBins) + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Helper for binSeqOp + */ + private def unorderedClassificationBinSeqOp( + arr: Array[Double], + agg: Array[Double], + dsMeta: DatasetMetadata, + numNodes: Int, + bins: Array[Array[Bin]]) = { + val numBins = bins(0).length + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + dsMeta.numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + val rightChildShift = dsMeta.numClasses * numBins * dsMeta.numFeatures * numNodes + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < dsMeta.numFeatures) { + val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, dsMeta, numBins) + } else { + val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, + rightChildShift, dsMeta, numBins, bins) + } else { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, dsMeta, numBins) + } + } + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + +} + +object DecisionTreeClassifier extends Serializable with Logging { + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. The parameters for the algorithm are specified using the params parameter. + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param params The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + * @return a DecisionTreeClassifierModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint], params: DTParams): DecisionTreeClassifierModel = { + new DecisionTree(params).train(input) + } + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. 
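   * A usage sketch (assumes an RDD[LabeledPoint] named trainingData and an RDD[Vector]
   * named testFeatures are in scope):
   * {{{
   *   val params = new DTClassifierParams(maxDepth = 4, impurity = "entropy")
   *   val model = new DecisionTreeClassifier(params)
   *     .train(trainingData, numClasses = 3, categoricalFeaturesInfo = Map(0 -> 4))
   *   val predictions = model.predict(testFeatures)
   * }}}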
+ * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maxDepth maximum depth of the tree + * @return a DecisionTreeClassifierModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int): DecisionTreeClassifierModel = { + val params = new DTParams(algo, impurity, maxDepth) + new DecisionTree(params).train(input) + } + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maxDepth maximum depth of the tree + * @param numClasses number of classes for classification. Default value of 2. + * @return a DecisionTreeClassifierModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClasses: Int): DecisionTreeClassifierModel = { + val params = new DTParams(algo, impurity, maxDepth, numClasses) + new DecisionTree(params).train(input) + } + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The decision tree method supports binary classification and + * regression. For the binary classification, the label for each instance should either be 0 or + * 1 to denote the two classes. The method also supports categorical features inputs where the + * number of categories can specified using the categoricalFeaturesInfo option. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data for DecisionTree + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param numClasses number of classes for classification. Default value of 2. + * @param maxBins maximum number of bins used for splitting features + * @param quantileStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, + * an entry (n -> k) implies the feature n is categorical with k + * categories 0, 1, 2, ... , k-1. It's important to note that + * features are zero-indexed. 
+ * @return a DecisionTreeClassifierModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClasses: Int, + maxBins: Int, + quantileStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeClassifierModel = { + val params = new DTParams(algo, impurity, maxDepth, numClasses, maxBins, + quantileStrategy, categoricalFeaturesInfo) + new DecisionTree(params).train(input) + } + + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala new file mode 100644 index 0000000000000..c04646908154b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.rdd.DatasetMetadata +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.DTRegressorParams +import org.apache.spark.mllib.tree.impurity.{RegressionImpurity, Variance} +import org.apache.spark.mllib.tree.model.{InformationGainStats, Bin, DecisionTreeRegressorModel} +import org.apache.spark.rdd.RDD + + +/** + * :: Experimental :: + * A class that implements a decision tree algorithm for regression. + * It supports both continuous and categorical features. + * @param params The configuration parameters for the tree algorithm. + */ +@Experimental +class DecisionTreeRegressor (params: DTRegressorParams) + extends DecisionTree[DecisionTreeRegressorModel](params) { + + private val impurityFunctor = params.impurity match { + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Bad impurity parameter for regression: ${params.impurity}") + } + + /** + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. 
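   * A usage sketch (assumes an RDD[LabeledPoint] named trainingData is in scope):
   * {{{
   *   val params = new DTRegressorParams(maxDepth = 5, impurity = "variance")
   *   val model = new DecisionTreeRegressor(params)
   *     .train(trainingData, categoricalFeaturesInfo = Map(3 -> 8))
   * }}}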
+ * @return a DecisionTreeRegressorModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()): DecisionTreeRegressorModel = { + + logDebug("algo = Regression") + + val topNode = super.trainSub(input, 0, categoricalFeaturesInfo) + + new DecisionTreeRegressorModel(topNode) + } + + //=========================================================================== + // Protected methods (abstract from DecisionTree) + //=========================================================================== + + protected def computeCentroidForCategories( + featureIndex: Int, + sampledInput: Array[LabeledPoint], + dsMeta: DatasetMetadata): Map[Double,Double] = { + // For categorical variables in regression, each bin is a category. + // The bins are sorted and are ordered by calculating the centroid of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + } + + /** + * Extracts left and right split aggregates. + * @param binData Array[Double] of size 2*numFeatures*numSplits + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], + * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, + * (numBins - 1), numClasses) + */ + protected def extractLeftRightNodeAggregates( + binData: Array[Double], + dsMeta: DatasetMetadata, + numBins: Int): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](dsMeta.numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](dsMeta.numFeatures, numBins - 1, 3) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < dsMeta.numFeatures) { + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(numBins - 2)(2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. 
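      // Intuition for the (count, sum, sum-of-squares) triples merged below (toy labels
      // assumed): for labels (1, 2, 3) the triple is (3, 6, 14), and variance can be
      // recovered as sumSq / count - (sum / count)^2 = 14/3 - 4 ≈ 0.6667, assuming the
      // population-variance form used by the Variance impurity.
      {
        val labels = Seq(1.0, 2.0, 3.0)
        val (count, sum, sumSq) = (labels.size.toDouble, labels.sum, labels.map(x => x * x).sum)
        // sumSq / count - math.pow(sum / count, 2) ≈ 0.6667
      }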
+ var splitIndex = 1 + while (splitIndex < numBins - 1) { + var i = 0 // index for regression histograms + while (i < 3) { // count, sum, sum^2 + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) + + leftNodeAgg(featureIndex)(splitIndex - 1)(i) + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) = + binData(shift + (3 * (numBins - 1 - splitIndex) + i)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i) + i += 1 + } + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } + + protected def getElementsPerNode( + dsMeta: DatasetMetadata, + numBins: Int): Int = { + 3 * numBins * dsMeta.numFeatures + } + + /** + * Performs a sequential aggregation of bins stats over a partition for regression. + * For l nodes, k features, + * the count, sum, sum of squares of one of the p bins is incremented. + * + * @param agg Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for regression + */ + protected def binSeqOpSub( + agg: Array[Double], + arr: Array[Double], + dsMeta: DatasetMetadata, + numNodes: Int, + bins: Array[Array[Bin]]): Array[Double] = { + val numBins = bins(0).length + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + dsMeta.numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < dsMeta.numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + dsMeta.numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update count, sum, and sum^2 for one bin. + val aggShift = 3 * numBins * dsMeta.numFeatures * nodeIndex + val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 + agg(aggIndex) = agg(aggIndex) + 1 + agg(aggIndex + 1) = agg(aggIndex + 1) + label + agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + featureIndex += 1 + } + } + nodeIndex += 1 + } + agg + } + + /** + * Calculates the information gain for all splits based upon left/right split aggregates. 
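   * A worked toy example (assumed labels, population variance):
   * {{{
   *   left labels  = (1, 1): variance 0.0;  right labels = (3, 5): variance 1.0
   *   parent variance = var(1, 1, 3, 5) = 2.75
   *   gain = 2.75 - 0.5 * 0.0 - 0.5 * 1.0 = 2.25, predict = overall mean = 2.5
   * }}}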
+ * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + protected def calculateGainForSplit( + leftNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Array[Double]]], + topImpurity: Double, + numClasses: Int, + level: Int): InformationGainStats = { + + val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) + val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) + val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) + + val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) + val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) + val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + impurityFunctor.calculate(count, sum, sumSquares) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity ,topImpurity, + Double.MinValue, leftSum / leftCount) + } + + val leftImpurity = impurityFunctor.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = impurityFunctor.calculate(rightCount, rightSum, rightSumSquares) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + } + + /** + * Get bin data for one node. + */ + protected def getBinDataForNode( + node: Int, + binAggregates: Array[Double], + dsMeta: DatasetMetadata, + numNodes: Int, + numBins: Int): Array[Double] = { + val shift = 3 * node * numBins * dsMeta.numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * dsMeta.numFeatures) + binsForNode + } + + //=========================================================================== + // Protected methods + //=========================================================================== + + /** + * Performs a sequential aggregation over a partition for regression. + */ + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala new file mode 100644 index 0000000000000..ce98a395e1f82 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.configuration.DTParams + +/** + * :: Experimental :: + * Stores all the configuration options for DecisionTreeClassifier construction + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileStrategy algorithm for calculating quantiles + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. + * @param impurity criterion used for information gain calculation (e.g., "gini" or "entropy") + */ +@Experimental +class DTClassifierParams ( + maxDepth: Int = 5, + maxBins: Int = 100, + quantileStrategy: String = "sort", + maxMemoryInMB: Int = 128, + val impurity: String = "gini") + extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { + + /* + if (!List("gini", "entropy").contains(impurity)) { + throw new IllegalArgumentException(s"Bad impurity parameter for classification: $impurity") + } + */ + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala new file mode 100644 index 0000000000000..8721e431f7b3d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ + +/** + * :: Experimental :: + * Stores configuration options for DecisionTree construction. + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileStrategy algorithm for calculating quantiles + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. 
+ */ +@Experimental +class DTParams ( + val maxDepth: Int, + val maxBins: Int, + val quantileStrategy: String, + val maxMemoryInMB: Int) extends Serializable { + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala new file mode 100644 index 0000000000000..c2e0d537308e0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.configuration.DTParams + +/** + * :: Experimental :: + * Stores all the configuration options for DecisionTreeRegressor construction + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileStrategy algorithm for calculating quantiles + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. + * @param impurity criterion used for information gain calculation (e.g., "variance") + */ +@Experimental +class DTRegressorParams ( + maxDepth: Int = 5, + maxBins: Int = 100, + quantileStrategy: String = "sort", + maxMemoryInMB: Int = 128, + val impurity: String = "variance") + extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { + + /* + if (!List("variance").contains(impurity)) { + throw new IllegalArgumentException(s"Bad impurity parameter for regression: $impurity") + } + */ + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala new file mode 100644 index 0000000000000..3aade2eeaac72 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Experimental} + +/** + * :: Experimental :: + * Trait for calculating information gain for classification. + */ +@Experimental +private[mllib] trait ClassificationImpurity extends Serializable { + + /** + * :: DeveloperApi :: + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value + */ + @DeveloperApi + def calculate(counts: Array[Double], totalCount: Double): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 60f43e9278d2a..5674848dbaf05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -25,29 +25,31 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * binary classification. */ @Experimental -object Entropy extends Impurity { +private[mllib] object Entropy extends ClassificationImpurity { private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2) /** * :: DeveloperApi :: - * entropy calculation - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 - * @return entropy value + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value */ @DeveloperApi - override def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - -(f0 * log2(f0)) - (f1 * log2(f1)) + override def calculate(counts: Array[Double], totalCount: Double): Double = { + val numClasses = counts.length + var impurity = 0.0 + var classIndex = 0 + while (classIndex < numClasses) { + val classCount = counts(classIndex) + if (classCount != 0) { + val freq = classCount / totalCount + impurity -= freq * log2(freq) + } + classIndex += 1 } + impurity } - override def calculate(count: Double, sum: Double, sumSquares: Double): Double = - throw new UnsupportedOperationException("Entropy.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c51d76d9b4c5b..20ca09f4a0395 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -26,27 +26,26 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * during binary classification. 
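 * Worked values for the shared calculate(counts, totalCount) signature above, with toy
 * counts (4, 2) and totalCount 6 (assumed purely for illustration):
 * {{{
 *   entropy = -(2/3 * log2(2/3) + 1/3 * log2(1/3)) ≈ 0.9183
 *   gini    = 1 - (2/3)^2 - (1/3)^2 = 4/9 ≈ 0.4444
 * }}}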
*/ @Experimental -object Gini extends Impurity { +private[mllib] object Gini extends ClassificationImpurity { /** * :: DeveloperApi :: - * Gini coefficient calculation - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 - * @return Gini coefficient value + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value */ @DeveloperApi - override def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - 1 - f0 * f0 - f1 * f1 + override def calculate(counts: Array[Double], totalCount: Double): Double = { + val numClasses = counts.length + var impurity = 1.0 + var classIndex = 0 + while (classIndex < numClasses) { + val freq = counts(classIndex) / totalCount + impurity -= freq * freq + classIndex += 1 } + impurity } - override def calculate(count: Double, sum: Double, sumSquares: Double): Double = - throw new UnsupportedOperationException("Gini.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 8eab247cf0932..16b28c3471113 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -24,17 +24,17 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * Trait for calculating information gain. */ @Experimental -trait Impurity extends Serializable { +private[mllib] trait Impurity extends Serializable { /** * :: DeveloperApi :: - * information calculation for binary classification - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels * @return information value */ @DeveloperApi - def calculate(c0 : Double, c1 : Double): Double + def calculate(counts: Array[Double], totalCount: Double): Double /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala new file mode 100644 index 0000000000000..6e01b0334c9fe --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Experimental} + +/** + * :: Experimental :: + * Trait for calculating information gain. + */ +@Experimental +private[mllib] trait RegressionImpurity extends Serializable { + + /** + * :: DeveloperApi :: + * information calculation for regression + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return information value + */ + @DeveloperApi + def calculate(count: Double, sum: Double, sumSquares: Double): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 47d07122af30f..0feac65a574af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -24,8 +24,17 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * Class for calculating variance during regression */ @Experimental -object Variance extends Impurity { - override def calculate(c0: Double, c1: Double): Double = +private[mllib] object Variance extends Impurity { + + /** + * :: DeveloperApi :: + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value + */ + @DeveloperApi + override def calculate(counts: Array[Double], totalCount: Double): Double = throw new UnsupportedOperationException("Variance.calculate") /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 2d71e1e366069..c89c1e371a40e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin + * @param category categorical label value accepted in the bin for binary classification */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala new file mode 100644 index 0000000000000..3cd616820fbe7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.Experimental +//import org.apache.spark.mllib.tree.model.DecisionTreeModel +//import org.apache.spark.mllib.tree.model.Node + + +/** + * :: Experimental :: + * Decision tree model for classification. + * This model stores learned parameters. + * @param topNode root node + */ +@Experimental +class DecisionTreeClassifierModel(topNode: Node) extends DecisionTreeModel(topNode) { + + /** + * Print tree. + */ + def print() { + println(s"DecisionTreeClassifierModel") + topNode.print(" ") + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index bf692ca8c4bd7..9c9b22763eedc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector @@ -26,10 +25,9 @@ import org.apache.spark.mllib.linalg.Vector * :: Experimental :: * Model to store the decision tree parameters * @param topNode root node - * @param algo algorithm type -- classification or regression */ @Experimental -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { +class DecisionTreeModel(val topNode: Node) extends Serializable { /** * Predict values for a single data point using the model trained. @@ -50,4 +48,5 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala new file mode 100644 index 0000000000000..5a566ba1be1ce --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Decision tree model for regression.
+ * This model stores learned parameters.
+ * @param topNode root node
+ */
+@Experimental
+class DecisionTreeRegressorModel(topNode: Node) extends DecisionTreeModel(topNode) {
+
+  /**
+   * Print tree.
+   */
+  def print() {
+    println("DecisionTreeRegressorModel")
+    topNode.print("  ")
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index cc8a24cce9614..fb12298e0f5d3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -27,6 +27,7 @@ import org.apache.spark.annotation.DeveloperApi
 * @param leftImpurity left node impurity
 * @param rightImpurity right node impurity
 * @param predict predicted value
+ * @param prob probability of the label (classification only)
 */
@DeveloperApi
class InformationGainStats(
@@ -34,10 +35,11 @@ class InformationGainStats(
    val impurity: Double,
    val leftImpurity: Double,
    val rightImpurity: Double,
-    val predict: Double) extends Serializable {
+    val predict: Double,
+    val prob: Double = 0.0) extends Serializable {
 
  override def toString = {
-    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
-      .format(gain, impurity, leftImpurity, rightImpurity, predict)
+    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
+      .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
  }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 682f213f411a7..5705916eb6531 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -26,7 +26,9 @@ import org.apache.spark.mllib.linalg.Vector
 * :: DeveloperApi ::
 * Node in a decision tree
 * @param id integer node id
- * @param predict predicted value at the node
+ * @param predict Predicted value at the node.
+ *                For classification, this is a class label in 0,1,....
+ *                For regression, this is a real value.
- * @param isLeaf whether the leaf is a node
+ * @param isLeaf whether the node is a leaf
 * @param split split to calculate left and right nodes
 * @param leftNode left child
@@ -91,4 +93,39 @@ class Node (
      }
    }
  }
+
+  /**
+   * Recursive print function.
+   * @param prefix Prefix for each printed line (for spacing).
+   */
+  def print(prefix: String = "") {
+
+    def splitToString(split: Split, left: Boolean): String = {
+      split.featureType match {
+        case Continuous => {
+          if (left) {
+            s"(feature ${split.feature} <= ${split.threshold})"
+          } else {
+            s"(feature ${split.feature} > ${split.threshold})"
+          }
+        }
+        case Categorical => {
+          if (left) {
+            s"(feature ${split.feature} in ${split.categories})"
+          } else {
+            s"(feature ${split.feature} not in ${split.categories})"
+          }
+        }
+      }
+    }
+
+    if (isLeaf) {
+      println(prefix + s"Predict: $predict")
+    } else {
+      println(prefix + s"If ${splitToString(split.get, true)}")
+      leftNode.get.print(prefix + "  ")
+      println(prefix + s"Else ${splitToString(split.get, false)}")
+      rightNode.get.print(prefix + "  ")
+    }
+  }
}
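
Note: the RegressionImpurity trait introduced in this patch reduces a regression
impurity to three sufficient statistics per candidate split: the label count, sum,
and sum of squares. A minimal sketch of an implementation (hypothetical object
name, not part of this patch; it would have to live under
org.apache.spark.mllib.tree.impurity, since the trait is private[mllib]):

    private[mllib] object SketchVariance extends RegressionImpurity {
      // Var(X) = E[X^2] - (E[X])^2, computed from the sufficient statistics.
      override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
        if (count == 0) {
          0.0
        } else {
          val mean = sum / count
          sumSquares / count - mean * mean
        }
      }
    }

From 20fc8057e912c8cc1266cbb39ce0285907e7356b Mon Sep 17 00:00:00 2001
From: "Joseph K.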
Bradley" Date: Sat, 19 Jul 2014 16:37:50 -0700 Subject: [PATCH 02/20] Mostly done with DecisionTree API re-config. Still need to update DecisionTreeRegressor class,object, update docs, tests and examples. --- .../spark/examples/mllib/DTRunnerJKB.scala | 2 +- .../spark/mllib/rdd/DatasetMetadata.scala | 27 ++- .../spark/mllib/tree/DecisionTree.scala | 34 +-- .../mllib/tree/DecisionTreeClassifier.scala | 193 ++++++++---------- .../mllib/tree/DecisionTreeRegressor.scala | 101 +++++++++ .../configuration/DTClassifierParams.scala | 10 +- .../mllib/tree/configuration/DTParams.scala | 5 +- .../configuration/DTRegressorParams.scala | 10 +- .../tree/configuration/QuantileStrategy.scala | 12 ++ .../mllib/tree/configuration/Strategy.scala | 60 ------ 10 files changed, 245 insertions(+), 209 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala index d38e946aeed98..a90f3da10e7af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala @@ -172,7 +172,7 @@ object DTRunnerJKB { impurity = impurityCalculator, maxDepth = params.maxDepth, maxBins = params.maxBins, - numClassesForClassification = numClasses) + numClasses = numClasses) val model = DecisionTree.train(training, strategy) model.print() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala index ffabeb406d8f7..c6d996b0659ba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala @@ -28,15 +28,30 @@ package org.apache.spark.mllib.rdd * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. */ -class DatasetMetadata (val numClasses: Int, - val numFeatures: Int, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) +class DatasetMetadata ( + val numClasses: Int, + val numFeatures: Int, + val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable { + /** + * Indicates if this dataset's label is real-valued (numClasses < 2). + */ + def isRegression: Boolean = { + numClasses < 2 + } + + /** + * Indicates if this dataset's label is categorical (numClasses >= 2). + */ + def isClassification: Boolean = { + numClasses >= 2 + } + /** * Indicates if this dataset's label is categorical with >2 categories. */ - def isMulticlass(): Boolean = { + def isMulticlass: Boolean = { numClasses > 2 } @@ -44,8 +59,8 @@ class DatasetMetadata (val numClasses: Int, * Indicates if this dataset's label is categorical with >2 categories, * and there is at least one categorical feature. 
*/ - def isMulticlassWithCategoricalFeatures(): Boolean = { - isMulticlass() && categoricalFeaturesInfo.nonEmpty + def isMulticlassWithCategoricalFeatures: Boolean = { + isMulticlass && categoricalFeaturesInfo.nonEmpty } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ffef7b7a41423..36c4f24e1d47f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.DTParams import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -47,27 +47,16 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @param numClasses number of classes for classification. - * Default value is 2, for binary classification. - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * @param dsMeta Dataset metadata. * @return top node of a DecisionTreeModel */ protected def trainSub( input: RDD[LabeledPoint], - numClasses: Int = 2, - categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()): Node = { + dsMeta: DatasetMetadata): Node = { // Cache input RDD for speedup during multiple passes. input.cache() - // Collect input metadata. - val numFeatures = input.take(1)(0).features.size - val dsMeta = new DatasetMetadata(numClasses, numFeatures, categoricalFeaturesInfo) - // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. val (splits, bins) = findSplitsBins(input, dsMeta) @@ -384,16 +373,14 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va // common calculations for multiple nested methods val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) - // Find the number of features by looking at the first sample. 
- //val numFeatures = input.first().features.size //logDebug("numFeatures = " + numFeatures) val numBins = bins(0).length logDebug("numBins = " + numBins) //val numClasses = dsMeta.numClasses //logDebug("numClasses = " + numClasses) - val isMulticlass = dsMeta.isMulticlass() + val isMulticlass = dsMeta.isMulticlass logDebug("isMulticlass = " + isMulticlass) - val isMulticlassWithCategoricalFeatures = dsMeta.isMulticlassWithCategoricalFeatures() + val isMulticlassWithCategoricalFeatures = dsMeta.isMulticlassWithCategoricalFeatures logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level @@ -732,7 +719,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va val maxBins = params.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) - val isMulticlass = dsMeta.isMulticlass() + val isMulticlass = dsMeta.isMulticlass logDebug("isMulticlass = " + isMulticlass) @@ -755,14 +742,15 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va logDebug("fraction of data used for calculating quantiles = " + fraction) // sampled input for RDD calculation - val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val sampledInput = + input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) params.quantileStrategy match { - case Sort => + case QuantileStrategy.Sort => val splits = Array.ofDim[Split](numFeatures, numBins - 1) val bins = Array.ofDim[Bin](numFeatures, numBins) @@ -875,9 +863,9 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va featureIndex += 1 } (splits,bins) - case MinMax => + case QuantileStrategy.MinMax => throw new UnsupportedOperationException("minmax not supported yet.") - case ApproxHist => + case QuantileStrategy.ApproxHist => throw new UnsupportedOperationException("approximate histogram not supported yet.") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala index 7503371eede96..fca8725947888 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala @@ -41,39 +41,22 @@ import org.apache.spark.util.random.XORShiftRandom class DecisionTreeClassifier (params: DTClassifierParams) extends DecisionTree[DecisionTreeClassifierModel](params) { - private val impurityFunctor : ClassificationImpurity = params.impurity match { - case "gini" => Gini - case "entropy" => Entropy - case _ => throw new IllegalArgumentException(s"Bad impurity parameter for classification: ${params.impurity}") - } + private val impurityFunctor = params.impurity /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @param numClasses number of classes for classification. - * Default value is 2, for binary classification. - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. 
It's important to note that features are - * zero-indexed. + * @param dsMeta Dataset metadata specifying number of classes, features, etc. * @return a DecisionTreeClassifierModel that can be used for prediction */ def train( input: RDD[LabeledPoint], - numClasses: Int = 2, - categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()): DecisionTreeClassifierModel = { + dsMeta: DatasetMetadata): DecisionTreeClassifierModel = { + require(dsMeta.isClassification) logDebug("algo = Classification") - val numClasses: Int = 2 - require(numClasses >= 2) - val isMulticlassClassification = numClasses > 2 - val isMulticlassWithCategoricalFeatures = - isMulticlassClassification && (categoricalFeaturesInfo.size > 0) - - val topNode = super.trainSub(input, numClasses, categoricalFeaturesInfo) - + val topNode = super.trainSub(input, dsMeta) new DecisionTreeClassifierModel(topNode) } @@ -85,7 +68,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) featureIndex: Int, sampledInput: Array[LabeledPoint], dsMeta: DatasetMetadata): Map[Double,Double] = { - if (dsMeta.isMulticlass()) { + if (dsMeta.isMulticlass) { // For categorical variables in multiclass classification, // each bin is a category. The bins are sorted and they // are ordered by calculating the impurity of their corresponding labels. @@ -183,7 +166,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) Array.ofDim[Double](dsMeta.numFeatures, numBins - 1, dsMeta.numClasses) var featureIndex = 0 while (featureIndex < dsMeta.numFeatures) { - if (dsMeta.isMulticlassWithCategoricalFeatures()){ + if (dsMeta.isMulticlassWithCategoricalFeatures){ val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) @@ -209,7 +192,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) protected def getElementsPerNode( dsMeta: DatasetMetadata, numBins: Int): Int = { - if (dsMeta.isMulticlassWithCategoricalFeatures()) { + if (dsMeta.isMulticlassWithCategoricalFeatures) { 2 * dsMeta.numClasses * numBins * dsMeta.numFeatures } else { dsMeta.numClasses * numBins * dsMeta.numFeatures @@ -235,7 +218,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) numNodes: Int, bins: Array[Array[Bin]]): Array[Double] = { val numBins = bins(0).length - if(dsMeta.isMulticlassWithCategoricalFeatures()) { + if(dsMeta.isMulticlassWithCategoricalFeatures) { unorderedClassificationBinSeqOp(arr, agg, dsMeta, numNodes, bins) } else { orderedClassificationBinSeqOp(arr, agg, dsMeta, numNodes, numBins) @@ -343,7 +326,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) dsMeta: DatasetMetadata, numNodes: Int, numBins: Int): Array[Double] = { - if (dsMeta.isMulticlassWithCategoricalFeatures()) { + if (dsMeta.isMulticlassWithCategoricalFeatures) { val shift = dsMeta.numClasses * node * numBins * dsMeta.numFeatures val rightChildShift = dsMeta.numClasses * numBins * dsMeta.numFeatures * numNodes val binsForNode = { @@ -500,103 +483,97 @@ class DecisionTreeClassifier (params: DTClassifierParams) object DecisionTreeClassifier extends Serializable with Logging { /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. The parameters for the algorithm are specified using the params parameter. 
+ * Train a decision tree model for binary or multiclass classification. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree - * @param params The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, - * categorical), depth of the tree, quantile calculation strategy, etc. - * @return a DecisionTreeClassifierModel that can be used for prediction + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param dsMeta Dataset metadata (number of features, number of classes, etc.) + * @param params The configuration parameters for the tree learning algorithm + * (tree depth, quantile calculation strategy, etc.) + * @return DecisionTreeClassifierModel which can be used for prediction */ - def train(input: RDD[LabeledPoint], params: DTParams): DecisionTreeClassifierModel = { - new DecisionTree(params).train(input) + def train( + input: RDD[LabeledPoint], + dsMeta: DatasetMetadata, + params: DTClassifierParams = new DTClassifierParams()): DecisionTreeClassifierModel = { + require(dsMeta.numClasses >= 2) + new DecisionTreeClassifier(params).train(input, dsMeta) } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Train a decision tree model for binary or multiclass classification. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @return a DecisionTreeClassifierModel that can be used for prediction + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClasses Number of classes (label types) for classification. + * Default = 2 (binary classification). + * @param categoricalFeaturesInfo A map from each categorical variable to the + * number of discrete values it takes. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It is important to note that features are + * zero-indexed. + * Default = treat all features as continuous. + * @param params The configuration parameters for the tree learning algorithm + * (tree depth, quantile calculation strategy, etc.) + * @return DecisionTreeClassifierModel which can be used for prediction */ def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int): DecisionTreeClassifierModel = { - val params = new DTParams(algo, impurity, maxDepth) - new DecisionTree(params).train(input) + input: RDD[LabeledPoint], + numClasses: Int = 2, + categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + params: DTClassifierParams = new DTClassifierParams()): DecisionTreeClassifierModel = { + + // Find the number of features by looking at the first sample. 
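+    // Note: this assumes a non-empty input RDD; first() triggers a small Spark job.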
+ val numFeatures = input.first().features.size + val dsMeta = new DatasetMetadata(numClasses, numFeatures, categoricalFeaturesInfo) + + train(input, dsMeta, params) } - /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. - * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @param numClasses number of classes for classification. Default value of 2. - * @return a DecisionTreeClassifierModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int, - numClasses: Int): DecisionTreeClassifierModel = { - val params = new DTParams(algo, impurity, maxDepth, numClasses) - new DecisionTree(params).train(input) + // TODO: Move elsewhere! + protected def getImpurity(impurityName: String): ClassificationImpurity = { + impurityName match { + case "gini" => Gini + case "entropy" => Entropy + case _ => throw new IllegalArgumentException( + s"Bad impurity parameter for classification: $impurityName") + } } + // TODO: Add various versions of train() function below. + /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The decision tree method supports binary classification and - * regression. For the binary classification, the label for each instance should either be 0 or - * 1 to denote the two classes. The method also supports categorical features inputs where the - * number of categories can specified using the categoricalFeaturesInfo option. + * Train a decision tree model for binary or multiclass classification. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data for DecisionTree - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree - * @param numClasses number of classes for classification. Default value of 2. - * @param maxBins maximum number of bins used for splitting features - * @param quantileStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. - * @return a DecisionTreeClassifierModel that can be used for prediction + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClasses Number of classes (label types) for classification. + * @param categoricalFeaturesInfo A map from each categorical variable to the + * number of discrete values it takes. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It is important to note that features are + * zero-indexed. 
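+   * Example (an illustrative sketch; the argument values are hypothetical):
+   * {{{
+   *   val model = DecisionTreeClassifier.train(input, 2, Map.empty[Int, Int],
+   *     "gini", 5, 100, "sort", 128)
+   * }}}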
+   * @param impurityName Criterion used for information gain calculation
+   * @param maxDepth Maximum depth of the tree
+   * @param maxBins Maximum number of bins used for splitting features
+   * @param quantileStrategyName Algorithm for calculating quantiles
+   * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation
+   * @return DecisionTreeClassifierModel which can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      numClasses: Int,
+      categoricalFeaturesInfo: Map[Int, Int],
+      impurityName: String,
+      maxDepth: Int,
+      maxBins: Int,
+      quantileStrategyName: String,
+      maxMemoryInMB: Int): DecisionTreeClassifierModel = {
+
+    val impurity = getImpurity(impurityName)
+    val quantileStrategy = getQuantileStrategy(quantileStrategyName)
+    val params =
+      new DTClassifierParams(impurity, maxDepth, maxBins, quantileStrategy, maxMemoryInMB)
+    train(input, numClasses, categoricalFeaturesInfo, params)
+  }
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala
index c04646908154b..11d5aada5a549 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.Logging
 import org.apache.spark.mllib.rdd.DatasetMetadata
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.DTRegressorParams
@@ -36,10 +37,13 @@ import org.apache.spark.rdd.RDD
 class DecisionTreeRegressor (params: DTRegressorParams)
  extends DecisionTree[DecisionTreeRegressorModel](params) {
 
-  private val impurityFunctor = params.impurity match {
-    case "variance" => Variance
-    case _ => throw new IllegalArgumentException(s"Bad impurity parameter for regression: ${params.impurity}")
-  }
+  private val impurityFunctor = params.impurity
 
  /**
   * Method to train a decision tree model over an RDD
@@ -276,3 +280,100 @@ class DecisionTreeRegressor (params: DTRegressorParams)
  }
}
+
+
+object DecisionTreeRegressor extends Serializable with Logging {
+
+  /**
+   * Train a decision tree model for regression.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real-valued.
+   * @param dsMeta Dataset metadata (number of features, number of classes, etc.)
+   * @param params The configuration parameters for the tree learning algorithm
+   *               (tree depth, quantile calculation strategy, etc.)
+   * @return DecisionTreeRegressorModel which can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      dsMeta: DatasetMetadata,
+      params: DTRegressorParams = new DTRegressorParams()): DecisionTreeRegressorModel = {
+    require(dsMeta.isRegression)
+    new DecisionTreeRegressor(params).train(input, dsMeta)
+  }
+
+  /**
+   * Train a decision tree model for regression.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real-valued.
+   * @param numClasses Number of label classes; must be < 2, which indicates
+   *                   regression (real-valued labels). Default = 0.
+   * @param categoricalFeaturesInfo A map from each categorical variable to the
+   *                                number of discrete values it takes. For example, an entry (n ->
+   *                                k) implies the feature n is categorical with k categories 0,
+   *                                1, 2, ... , k-1. It is important to note that features are
+   *                                zero-indexed.
+   *                                Default = treat all features as continuous.
+   * @param params The configuration parameters for the tree learning algorithm
+   *               (tree depth, quantile calculation strategy, etc.)
+   * @return DecisionTreeRegressorModel which can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      numClasses: Int = 0,
+      categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+      params: DTRegressorParams = new DTRegressorParams()): DecisionTreeRegressorModel = {
+
+    // Find the number of features by looking at the first sample.
+    val numFeatures = input.first().features.size
+    val dsMeta = new DatasetMetadata(numClasses, numFeatures, categoricalFeaturesInfo)
+
+    train(input, dsMeta, params)
+  }
+
+  // TODO: Move elsewhere!
+  protected def getImpurity(impurityName: String): RegressionImpurity = {
+    impurityName match {
+      case "variance" => Variance
+      case _ => throw new IllegalArgumentException(
+        s"Bad impurity parameter for regression: $impurityName")
+    }
+  }
+
+  /**
+   * Train a decision tree model for regression.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real-valued.
+   * @param numClasses Number of label classes; must be < 2, which indicates regression (use 0).
+   * @param categoricalFeaturesInfo A map from each categorical variable to the
+   *                                number of discrete values it takes. For example, an entry (n ->
+   *                                k) implies the feature n is categorical with k categories 0,
+   *                                1, 2, ... , k-1. It is important to note that features are
+   *                                zero-indexed.
+ * @param impurityName Criterion used for information gain calculation + * @param maxDepth Maximum depth of the tree + * @param maxBins Maximum number of bins used for splitting features + * @param quantileStrategyName Algorithm for calculating quantiles + * @return DecisionTreeRegressorModel which can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + numClasses: Int, + categoricalFeaturesInfo: Map[Int, Int], + impurityName: String, + maxDepth: Int, + maxBins: Int, + quantileStrategyName: String, + maxMemoryInMB: Int): DecisionTreeRegressorModel = { + + val impurity = getImpurity(impurityName) + val quantileStrategy = getQuantileStrategy(quantileStrategyName) + val params = + new DTRegressorParams(impurity, maxDepth, maxBins, quantileStrategy, maxMemoryInMB) + train(input, numClasses, categoricalFeaturesInfo, params) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala index ce98a395e1f82..52013622ce1c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -19,24 +19,26 @@ package org.apache.spark.mllib.tree.configuration import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.DTParams +import org.apache.spark.mllib.tree.impurity.{ClassificationImpurity, Gini} +import org.apache.spark.mllib.tree.configuration.QuantileStrategy /** * :: Experimental :: * Stores all the configuration options for DecisionTreeClassifier construction + * @param impurity criterion used for information gain calculation (e.g., Gini or Entropy) * @param maxDepth maximum depth of the tree * @param maxBins maximum number of bins used for splitting features * @param quantileStrategy algorithm for calculating quantiles * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. 
- * @param impurity criterion used for information gain calculation (e.g., "gini" or "entropy") */ @Experimental class DTClassifierParams ( + val impurity: ClassificationImpurity = Gini, maxDepth: Int = 5, maxBins: Int = 100, - quantileStrategy: String = "sort", - maxMemoryInMB: Int = 128, - val impurity: String = "gini") + quantileStrategy: QuantileStrategy.QuantileStrategy = QuantileStrategy.Sort, + maxMemoryInMB: Int = 128) extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { /* diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala index 8721e431f7b3d..6287eeb7ebd0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -18,8 +18,7 @@ package org.apache.spark.mllib.tree.configuration import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.impurity.Impurity -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy /** * :: Experimental :: @@ -34,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ class DTParams ( val maxDepth: Int, val maxBins: Int, - val quantileStrategy: String, + val quantileStrategy: QuantileStrategy.QuantileStrategy, val maxMemoryInMB: Int) extends Serializable { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala index c2e0d537308e0..bd76d2d794ee7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.configuration import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.DTParams +import org.apache.spark.mllib.tree.impurity.{RegressionImpurity, Variance} +import org.apache.spark.mllib.tree.configuration.QuantileStrategy /** * :: Experimental :: @@ -28,15 +30,15 @@ import org.apache.spark.mllib.tree.configuration.DTParams * @param quantileStrategy algorithm for calculating quantiles * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. 
- * @param impurity criterion used for information gain calculation (e.g., "variance") + * @param impurity criterion used for information gain calculation (e.g., Variance) */ @Experimental class DTRegressorParams ( + val impurity: RegressionImpurity = Variance, maxDepth: Int = 5, maxBins: Int = 100, - quantileStrategy: String = "sort", - maxMemoryInMB: Int = 128, - val impurity: String = "variance") + quantileStrategy: QuantileStrategy.QuantileStrategy = QuantileStrategy.Sort, + maxMemoryInMB: Int = 128) extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { /* diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 7da976e55a722..052ef6c148d33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -27,4 +27,16 @@ import org.apache.spark.annotation.Experimental object QuantileStrategy extends Enumeration { type QuantileStrategy = Value val Sort, MinMax, ApproxHist = Value + + /** + * Given a string with the name of a quantile strategy, get the QuantileStrategy type. + */ + def getQuantileStrategy(strategyName: String): QuantileStrategy = { + strategyName match { + case "sort" => Sort + case _ => throw new IllegalArgumentException( + s"Bad QuantileStrategy parameter: $strategyName") + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala deleted file mode 100644 index 7c027ac2fda6b..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.configuration - -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.impurity.Impurity -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ - -/** - * :: Experimental :: - * Stores all the configuration options for tree construction - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree - * @param numClassesForClassification number of classes for classification. 
Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. - * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is - * 128 MB. - * - */ -@Experimental -class Strategy ( - val algo: Algo, - val impurity: Impurity, - val maxDepth: Int, - val numClassesForClassification: Int = 2, - val maxBins: Int = 100, - val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128) extends Serializable { - - require(numClassesForClassification >= 2) - val isMulticlassClassification = numClassesForClassification > 2 - val isMulticlassWithCategoricalFeatures - = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) - -} From 0ced13a5773e2973042a580a03ab4a9457fe3fe8 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 23 Jul 2014 13:07:55 -0700 Subject: [PATCH 03/20] Major changes to DecisionTree API and internals. Unit tests work. Still need to update documentation. Split classes: * DecisionTree --> DecisionTreeClassifier and DecisionTreeRegressor * DecisionTreeModel --> DecisionTreeClassifierModel, DecisionTreeRegressorModel * Super-classes DecisionTree, DecisionTreeModel are private to mllib. Included print() function for human-readable model descriptions * For: DecisionTreeClassifierModel, DecisionTreeRegressorModel, Node parameters (used to be named Strategy) * Split into: DTParams, DTClassifierParams, DTRegressorParams. * Added defaultParams() method to DecisionTreeClassifier/Regressor. * impurity ** Made private to mllib package. ** Split Impurity into ClassifierImpurity, RegressorImpurity ** Added factories: ClassifierImpurities, RegressorImpurities * QuantileStrategy: Added factory QuantileStrategies * maxDepth: Changed meaning by 1. Previously, depth = 1 meant 1 leaf node; now it means 1 internal and 2 leaf nodes. This matches scikit-learn and rpart. train() functions: * Changed to use DatasetInfo class for metadata. * Eliminated many of the static train() functions to prevent users from needing to remember the order of long lists of parameters. 
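
For reference, the new calling convention looks roughly like this (an
illustrative sketch; the dataset and variable names are hypothetical):

    val datasetInfo = new DatasetInfo(numClasses = 2, numFeatures = 4)
    val params = DecisionTreeClassifier.defaultParams()
    params.maxDepth = 4
    val model = new DecisionTreeClassifier(params).train(examples, datasetInfo)
    val predictions = model.predict(testFeatures)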
DecisionTree internals: * renamed numSplits to numBins (since it was a duplicate name) --- .../spark/examples/mllib/DTRunnerJKB.scala | 212 ------------ .../examples/mllib/DecisionTreeRunner.scala | 189 +++++++---- ...atasetMetadata.scala => DatasetInfo.scala} | 2 +- .../spark/mllib/tree/DecisionTree.scala | 176 ++++++---- .../mllib/tree/DecisionTreeClassifier.scala | 309 +++++++++--------- .../mllib/tree/DecisionTreeRegressor.scala | 163 +++------ .../spark/mllib/tree/configuration/Algo.scala | 30 -- .../configuration/DTClassifierParams.scala | 13 +- .../mllib/tree/configuration/DTParams.scala | 9 +- .../configuration/DTRegressorParams.scala | 10 +- .../tree/configuration/QuantileStrategy.scala | 19 +- .../impurity/ClassificationImpurities.scala | 66 ++++ .../tree/impurity/RegressionImpurities.scala | 65 ++++ .../spark/mllib/tree/impurity/Variance.scala | 15 +- .../apache/spark/mllib/tree/model/Node.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 246 +++++++++----- 16 files changed, 735 insertions(+), 791 deletions(-) delete mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala rename mllib/src/main/scala/org/apache/spark/mllib/rdd/{DatasetMetadata.scala => DatasetInfo.scala} (98%) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala deleted file mode 100644 index a90f3da10e7af..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DTRunnerJKB.scala +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib - -import scopt.OptionParser - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree, impurity} -import org.apache.spark.mllib.tree.configuration.{Algo, DTParams} -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD - -/** - * An example runner for decision tree. Run with - * {{{ - * ./bin/spark-example org.apache.spark.examples.mllib.DTRunnerJKB [options] - * }}} - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
- */ -object DTRunnerJKB { - - object ImpurityType extends Enumeration { - type ImpurityType = Value - val Gini, Entropy, Variance = Value - } - - import ImpurityType._ - - case class Params( - input: String = null, - dataFormat: String = null, - algo: Algo = Classification, - maxDepth: Int = 5, - impurity: ImpurityType = Gini, - maxBins: Int = 100, - fracTest: Double = 0.2) - - def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("DTRunnerJKB") { - head("DTRunnerJKB: an example decision tree app.") - opt[String]("algo") - .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") - .action((x, c) => c.copy(algo = Algo.withName(x))) - opt[String]("impurity") - .text(s"impurity type (${ImpurityType.values.mkString(",")}), " + - s"default: ${defaultParams.impurity}") - .action((x, c) => c.copy(impurity = ImpurityType.withName(x))) - opt[Int]("maxDepth") - .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") - .action((x, c) => c.copy(maxDepth = x)) - opt[Int]("maxBins") - .text(s"max number of bins, default: ${defaultParams.maxBins}") - .action((x, c) => c.copy(maxBins = x)) - opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") - .action((x, c) => c.copy(fracTest = x)) - arg[String]("") - .text("input paths to labeled examples") - .required() - .action((x, c) => c.copy(input = x)) - arg[String]("") - .text("data format: dense/libsvm") - .required() - .action((x, c) => c.copy(dataFormat = x)) - checkConfig { params => - if ( - (params.algo == Classification && - !(params.impurity == Gini || params.impurity == Entropy)) || - (params.algo == Regression && !(params.impurity == Variance))) { - failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") - } - if (params.fracTest < 0 || params.fracTest > 1) { - failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") - } - success - } - } - - parser.parse(args, defaultParams).map { params => - run(params) - }.getOrElse { - sys.exit(1) - } - } - - def run(params: Params) { - val conf = new SparkConf().setAppName("DTRunnerJKB") - val sc = new SparkContext(conf) - - // Load training data and cache it. - val origExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledData(sc, params.input).cache() - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache() - } - val (examples, numClasses) = params.algo match { - case Classification => { - // classCounts: class --> # examples in class - val classCounts = origExamples.map(_.label).countByValue - val numClasses = classCounts.size - // Re-index classes if needed. - // classIndex: class --> index in 0,...,numClasses-1 - val classIndex = { - if (classCounts.keySet != Set[Double](0.0, 1.0)) { - classCounts.keys.toList.sorted.zipWithIndex.toMap - } else { - Map[Double, Int]() - } - } - val examples = { - if (classIndex.isEmpty) { - origExamples - } else { - origExamples.map(lp => LabeledPoint(classIndex(lp.label), lp.features)) - } - } - println(s"numClasses = $numClasses.") - println(s"Per-class example fractions, counts:") - println(s"Class\tFrac\tCount") - classCounts.keys.toList.sorted.foreach(c => { - val frac = classCounts(c) / (0.0 + examples.count()) - println(s"$c\t$frac\t${classCounts(c)}") - }) - (examples, numClasses) - } - case Regression => { - (origExamples, 2) - } - } - - // Split into training, test. 
- val splits = examples.randomSplit(Array(0.8, 0.2)) - val training = splits(0).cache() - val test = splits(1).cache() - val numTraining = training.count() - val numTest = test.count() - - println(s"numTraining = $numTraining, numTest = $numTest.") - - examples.unpersist(blocking = false) - - val impurityCalculator = params.impurity match { - case Gini => impurity.Gini - case Entropy => impurity.Entropy - case Variance => impurity.Variance - } - - val strategy - = new DTParams( - algo = params.algo, - impurity = impurityCalculator, - maxDepth = params.maxDepth, - maxBins = params.maxBins, - numClasses = numClasses) - val model = DecisionTree.train(training, strategy) - model.print() - - if (params.algo == Classification) { - val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy.") - } - - if (params.algo == Regression) { - val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse.") - } - - sc.stop() - } - - /** - * Calculates the classifier accuracy. - */ - private def accuracyScore( - model: DecisionTreeModel, - data: RDD[LabeledPoint]): Double = { - val correctCount = data.filter(y => model.predict(y.features) == y.label).count() - val count = data.count() - correctCount.toDouble / count - } - - /** - * Calculates the mean squared error for regression. - */ - private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = tree.predict(y.features) - y.label - err * err - }.mean() - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 04f5e244052ed..cc1b2e94bb6bc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -21,11 +21,11 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree, impurity} -import org.apache.spark.mllib.tree.configuration.{Algo, DTParams} -import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.{DecisionTreeClassifier, DecisionTreeRegressor} +import org.apache.spark.mllib.tree.configuration.{DTClassifierParams, DTRegressorParams} +import org.apache.spark.mllib.tree.impurity.{ClassificationImpurities, RegressionImpurities} import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -36,54 +36,73 @@ import org.apache.spark.rdd.RDD * ./bin/spark-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify + * [[org.apache.spark.mllib.rdd.DatasetInfo.categoricalFeaturesInfo]]. 
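+ * For example (sketch): a dataset whose feature 3 is categorical with 4 values
+ * would pass the map below as the third DatasetInfo constructor argument:
+ * {{{
+ *   new DatasetInfo(numClasses, numFeatures, Map(3 -> 4))
+ * }}}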
 */
object DecisionTreeRunner {
 
-  object ImpurityType extends Enumeration {
-    type ImpurityType = Value
-    val Gini, Entropy, Variance = Value
-  }
-
-  import ImpurityType._
-
  case class Params(
      input: String = null,
-      algo: Algo = Classification,
+      dataFormat: String = null,
+      algo: String = "classification",
+      impurity: Option[String] = None,
      maxDepth: Int = 5,
-      impurity: ImpurityType = Gini,
-      maxBins: Int = 100)
+      maxBins: Int = 100,
+      fracTest: Double = 0.2)
 
  def main(args: Array[String]) {
    val defaultParams = Params()
+    val defaultCImpurity = ClassificationImpurities.impurityName(new DTClassifierParams().impurity)
+    val defaultRImpurity = RegressionImpurities.impurityName(new DTRegressorParams().impurity)
 
    val parser = new OptionParser[Params]("DecisionTreeRunner") {
      head("DecisionTreeRunner: an example decision tree app.")
      opt[String]("algo")
-        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
-        .action((x, c) => c.copy(algo = Algo.withName(x)))
+        .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+        .action((x, c) => c.copy(algo = x))
      opt[String]("impurity")
-        .text(s"impurity type (${ImpurityType.values.mkString(",")}), " +
-          s"default: ${defaultParams.impurity}")
-        .action((x, c) => c.copy(impurity = ImpurityType.withName(x)))
+        .text(
+          s"impurity type\n" +
+          s"\tFor classification: ${ClassificationImpurities.names.mkString(",")}\n" +
+          s"\t  default: $defaultCImpurity\n" +
+          s"\tFor regression: ${RegressionImpurities.names.mkString(",")}\n" +
+          s"\t  default: $defaultRImpurity")
+        .action((x, c) => c.copy(impurity = Some(x)))
      opt[Int]("maxDepth")
        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
        .action((x, c) => c.copy(maxDepth = x))
      opt[Int]("maxBins")
        .text(s"max number of bins, default: ${defaultParams.maxBins}")
        .action((x, c) => c.copy(maxBins = x))
+      opt[Double]("fracTest")
+        .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
+        .action((x, c) => c.copy(fracTest = x))
      arg[String]("<input>")
-        .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
+        .text("input paths to labeled examples")
        .required()
        .action((x, c) => c.copy(input = x))
+      arg[String]("<dataFormat>")
+        .text("data format: dense/libsvm")
+        .required()
+        .action((x, c) => c.copy(dataFormat = x))
      checkConfig { params =>
-        if (params.algo == Classification &&
-            (params.impurity == Gini || params.impurity == Entropy)) {
-          success
-        } else if (params.algo == Regression && params.impurity == Variance) {
-          success
-        } else {
-          failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+        if (!List("classification", "regression").contains(params.algo)) {
+          failure(s"Did not recognize Algo: ${params.algo}")
+        } else if (params.impurity.isDefined &&
+            ((params.algo == "classification" &&
+              !ClassificationImpurities.names.contains(params.impurity.get)) ||
+             (params.algo == "regression" &&
+              !RegressionImpurities.names.contains(params.impurity.get)))) {
+          failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity.get}.")
+        } else if (params.fracTest < 0 || params.fracTest > 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+        } else {
+          success
        }
      }
    }
@@ -99,12 +118,52 @@ object DecisionTreeRunner {
    val sc = new SparkContext(conf)
 
    // Load training data and cache it.
-    val examples = MLUtils.loadLabeledPoints(sc, params.input).cache()
+    val origExamples = params.dataFormat match {
+      case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
+      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache()
+    }
+    // For classification, re-index classes if needed.
+    val (examples, numClasses) = params.algo match {
+      case "classification" => {
+        // classCounts: class --> # examples in class
+        val classCounts = origExamples.map(_.label).countByValue
+        val numClasses = classCounts.size
+        // classIndex: class --> index in 0,...,numClasses-1
+        val classIndex = {
+          if (classCounts.keySet != Set[Double](0.0, 1.0)) {
+            classCounts.keys.toList.sorted.zipWithIndex.toMap
+          } else {
+            Map[Double, Int]()
+          }
+        }
+        val examples = {
+          if (classIndex.isEmpty) {
+            origExamples
+          } else {
+            origExamples.map(lp => LabeledPoint(classIndex(lp.label), lp.features))
+          }
+        }
+        println(s"numClasses = $numClasses.")
+        println(s"Per-class example fractions, counts:")
+        println(s"Class\tFrac\tCount")
+        classCounts.keys.toList.sorted.foreach(c => {
+          val frac = classCounts(c) / (0.0 + examples.count())
+          println(s"$c\t$frac\t${classCounts(c)}")
+        })
+        (examples, numClasses)
+      }
+      case "regression" => {
+        (origExamples, 0)
+      }
+      case _ => {
+        throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+      }
+    }
+
    // Split into training, test.
    val splits = examples.randomSplit(Array(0.8, 0.2))
    val training = splits(0).cache()
    val test = splits(1).cache()
-
    val numTraining = training.count()
    val numTest = test.count()
@@ -112,33 +171,39 @@ object DecisionTreeRunner {
 
    examples.unpersist(blocking = false)
 
-    val impurityCalculator = params.impurity match {
-      case Gini => impurity.Gini
-      case Entropy => impurity.Entropy
-      case Variance => impurity.Variance
-    }
-
-<<<<<<< HEAD
-    val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
-=======
-    val strategy
-      = new DTParams(
-        algo = params.algo,
-        impurity = impurityCalculator,
-        maxDepth = params.maxDepth,
-        maxBins = params.maxBins,
-        numClassesForClassification = params.numClassesForClassification)
->>>>>>> 8725f7b... updating DT API, but not done yet
-    val model = DecisionTree.train(training, strategy)
-
-    if (params.algo == Classification) {
-      val accuracy = accuracyScore(model, test)
-      println(s"Test accuracy = $accuracy.")
-    }
-
-    if (params.algo == Regression) {
-      val mse = meanSquaredError(model, test)
-      println(s"Test mean squared error = $mse.")
+    val numFeatures = examples.take(1)(0).features.size
+    val datasetInfo = new DatasetInfo(numClasses, numFeatures)
+
+    params.algo match {
+      case "classification" => {
+        val dtParams = DecisionTreeClassifier.defaultParams()
+        dtParams.maxDepth = params.maxDepth
+        dtParams.maxBins = params.maxBins
+        if (params.impurity.isDefined) {
+          dtParams.impurity = ClassificationImpurities.impurity(params.impurity.get)
+        }
+        val dtLearner = new DecisionTreeClassifier(dtParams)
+        val model = dtLearner.train(training, datasetInfo)
+        model.print()
+        val accuracy = accuracyScore(model, test)
+        println(s"Test accuracy = $accuracy.")
+      }
+      case "regression" => {
+        val dtParams = DecisionTreeRegressor.defaultParams()
+        dtParams.maxDepth = params.maxDepth
+        dtParams.maxBins = params.maxBins
+        if (params.impurity.isDefined) {
+          dtParams.impurity = RegressionImpurities.impurity(params.impurity.get)
+        }
+        val dtLearner = new DecisionTreeRegressor(dtParams)
+        val model = dtLearner.train(training, datasetInfo)
+        model.print()
+        val mse = meanSquaredError(model, test)
+        println(s"Test mean squared error = $mse.")
+      }
+      case _ => {
+        throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+      }
    }
 
    sc.stop()
  }
@@ -149,12 +214,8 @@ object DecisionTreeRunner {
   */
  private def accuracyScore(
      model: DecisionTreeModel,
-      data: RDD[LabeledPoint],
-      threshold: Double = 0.5): Double = {
-    def predictedValue(features: Vector): Double = {
-      if (model.predict(features) < threshold) 0.0 else 1.0
-    }
-    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+      data: RDD[LabeledPoint]): Double = {
+    val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
    val count = data.count()
    correctCount.toDouble / count
  }
@@ -162,9 +223,11 @@ object DecisionTreeRunner {
  /**
   * Calculates the mean squared error for regression.
   */
-  private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
+  private def meanSquaredError(
+      model: DecisionTreeModel,
+      data: RDD[LabeledPoint]): Double = {
    data.map { y =>
-      val err = tree.predict(y.features) - y.label
+      val err = model.predict(y.features) - y.label
      err * err
    }.mean()
  }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetInfo.scala
similarity index 98%
rename from mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala
rename to mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetInfo.scala
index c6d996b0659ba..b6643b7aae6ea 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetInfo.scala
@@ -28,7 +28,7 @@
 * 1, 2, ... , k-1. It's important to note that features are
 * zero-indexed.
*/ -class DatasetMetadata ( +class DatasetInfo ( val numClasses: Int, val numFeatures: Int, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 36c4f24e1d47f..e467d10d36581 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,10 +19,9 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.mllib.rdd.DatasetMetadata +import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.DTParams -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy import org.apache.spark.mllib.tree.model._ @@ -34,9 +33,7 @@ import org.apache.spark.util.random.XORShiftRandom * :: Experimental :: * A class that implements a decision tree algorithm for classification and regression. It * supports both continuous and categorical features. - * @param params The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, - * categorical), depth of the tree, quantile calculation strategy, etc. + * @param params The configuration parameters for the tree algorithm. */ @Experimental private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected val params: DTParams) @@ -47,19 +44,19 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @param dsMeta Dataset metadata. + * @param datasetInfo Dataset metadata. * @return top node of a DecisionTreeModel */ protected def trainSub( input: RDD[LabeledPoint], - dsMeta: DatasetMetadata): Node = { + datasetInfo: DatasetInfo): Node = { // Cache input RDD for speedup during multiple passes. input.cache() // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = findSplitsBins(input, dsMeta) + val (splits, bins) = findSplitsBins(input, datasetInfo) val numBins = bins(0).length logDebug("numBins = " + numBins) @@ -81,7 +78,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va // Max memory usage for aggregates val maxMemoryUsage = params.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = getElementsPerNode(dsMeta, numBins) + val numElementsPerNode = getElementsPerNode(datasetInfo, numBins) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) @@ -109,7 +106,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va logDebug("#####################################") // Find best split for all nodes at a level. 
- val splitsStatsForLevel = findBestSplits(input, dsMeta, parentImpurities, + val splitsStatsForLevel = findBestSplits(input, datasetInfo, parentImpurities, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { @@ -157,34 +154,44 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va protected def computeCentroidForCategories( featureIndex: Int, sampledInput: Array[LabeledPoint], - dsMeta: DatasetMetadata): Map[Double,Double] + datasetInfo: DatasetInfo): Map[Double,Double] /** * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2*numFeatures*numSplits + * @param binData Array[Double] of size 2 * numFeatures * numBins * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, * (numBins - 1), numClasses) */ protected def extractLeftRightNodeAggregates( binData: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numBins: Int): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) /** * Get the number of stats elements stored per node in bin aggregates. */ protected def getElementsPerNode( - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numBins: Int): Int /** * Performs a sequential aggregation of bins stats over a partition. + * + * @param agg Array[Double] storing aggregate calculation of size + * numClasses * numBins * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @param datasetInfo Dataset metadata. + * @param numNodes Number of nodes in this (level of tree, group), + * where nodes at deeper (larger) levels may be divided into groups. + * @param bins Number of bins = 1 + number of possible splits. + * @return agg */ protected def binSeqOpSub( agg: Array[Double], arr: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numNodes: Int, bins: Array[Array[Bin]]): Array[Double] @@ -203,7 +210,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va splitIndex: Int, rightNodeAgg: Array[Array[Array[Double]]], topImpurity: Double, - numClasses: Int, + datasetInfo: DatasetInfo, level: Int): InformationGainStats /** @@ -212,7 +219,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va protected def getBinDataForNode( node: Int, binAggregates: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numNodes: Int, numBins: Int): Array[Double] @@ -223,7 +230,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va /** * Extract the decision tree node information for the given tree level and node index */ - private def extractNodeInfo( + protected def extractNodeInfo( nodeSplitStats: (Split, InformationGainStats), level: Int, index: Int, @@ -240,7 +247,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va /** * Extract the decision tree node information for the children of the node */ - private def extractInfoForLowerLevels( + protected def extractInfoForLowerLevels( level: Int, index: Int, maxDepth: Int, @@ -278,7 +285,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree - * @param dsMeta Metadata for input. 
+ * @param datasetInfo Metadata for input. * @param parentImpurities Impurities for all parent nodes for the current level * @param level Level of the tree * @param filters Filters for all nodes at a given level @@ -287,9 +294,9 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @return array of splits with best splits for all nodes at a given level. */ - private def findBestSplits( + protected[tree] def findBestSplits( input: RDD[LabeledPoint], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, parentImpurities: Array[Double], level: Int, filters: Array[List[Filter]], @@ -309,14 +316,14 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, dsMeta, parentImpurities, level, + val bestSplitsForGroup = findBestSplitsPerGroup(input, datasetInfo, parentImpurities, level, filters, splits, bins, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, dsMeta, parentImpurities, level, filters, splits, bins) + findBestSplitsPerGroup(input, datasetInfo, parentImpurities, level, filters, splits, bins) } } @@ -325,7 +332,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree - * @param dsMeta Metadata for input. + * @param datasetInfo Metadata for input. * @param parentImpurities Impurities for all parent nodes for the current level * @param level Level of the tree * @param filters Filters for all nodes at a given level @@ -335,9 +342,9 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. */ - private def findBestSplitsPerGroup( + protected def findBestSplitsPerGroup( input: RDD[LabeledPoint], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, parentImpurities: Array[Double], level: Int, filters: Array[List[Filter]], @@ -370,17 +377,18 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va * drastically reduce the communication overhead. */ - // common calculations for multiple nested methods + // Common calculations for multiple nested methods: + + // Number of nodes to handle for each group in this level. 
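+    // For example, at level 2 with numGroups = 2, each call below handles
+    // math.pow(2, 2).toInt / 2 = 2 of the level's 4 candidate nodes.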
val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) - //logDebug("numFeatures = " + numFeatures) + logDebug("numFeatures = " + datasetInfo.numFeatures) val numBins = bins(0).length logDebug("numBins = " + numBins) - //val numClasses = dsMeta.numClasses - //logDebug("numClasses = " + numClasses) - val isMulticlass = dsMeta.isMulticlass + logDebug("numClasses = " + datasetInfo.numClasses) + val isMulticlass = datasetInfo.isMulticlass logDebug("isMulticlass = " + isMulticlass) - val isMulticlassWithCategoricalFeatures = dsMeta.isMulticlassWithCategoricalFeatures + val isMulticlassWithCategoricalFeatures = datasetInfo.isMulticlassWithCategoricalFeatures logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level @@ -435,7 +443,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va } /** - * Find bin for one feature. + * Find bin for one (labeledPoint, feature). */ def findBin( featureIndex: Int, @@ -483,7 +491,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va * Sequential search helper method to find bin for categorical feature. */ def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { - val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex) + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { @@ -528,10 +536,16 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va * where b_ij is an integer between 0 and numBins - 1 for regressions and binary * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. + * + * @return Array of size 1 + numFeatures * numNodes, where + * arr(0) = label for labeledPoint, and + * arr(1 + numFeatures * nodeIndex + featureIndex) = + * bin index for this labeledPoint + * (or InvalidBinIndex if labeledPoint is not handled by this node). */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. - val arr = new Array[Double](1 + (dsMeta.numFeatures * numNodes)) + val arr = new Array[Double](1 + (datasetInfo.numFeatures * numNodes)) // First element of the array is the label of the instance. arr(0) = labeledPoint.label // Iterate over nodes. @@ -540,14 +554,14 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node. val sampleValid = isSampleValid(parentFilters, labeledPoint) - val shift = 1 + dsMeta.numFeatures * nodeIndex + val shift = 1 + datasetInfo.numFeatures * nodeIndex if (!sampleValid) { // Mark one bin as -1 is sufficient. 
arr(shift) = InvalidBinIndex } else { var featureIndex = 0 - while (featureIndex < dsMeta.numFeatures) { - val featureInfo = dsMeta.categoricalFeaturesInfo.get(featureIndex) + while (featureIndex < datasetInfo.numFeatures) { + val featureInfo = datasetInfo.categoricalFeaturesInfo.get(featureIndex) val isFeatureContinuous = featureInfo.isEmpty if (isFeatureContinuous) { arr(shift + featureIndex) @@ -558,7 +572,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va = numBins > math.pow(2, featureCategories.toInt - 1) - 1 arr(shift + featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isSpaceSufficientForAllCategoricalSplits) + isSpaceSufficientForAllCategoricalSplits) } featureIndex += 1 } @@ -573,11 +587,11 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va // Performs a sequential aggregation over a partition. def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { - binSeqOpSub(agg, arr, dsMeta, numNodes, bins) + binSeqOpSub(agg, arr, datasetInfo, numNodes, bins) } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = numNodes * getElementsPerNode(dsMeta, numBins) + val binAggregateLength = numNodes * getElementsPerNode(datasetInfo, numBins) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -609,12 +623,12 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], nodeImpurity: Double): Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[InformationGainStats](dsMeta.numFeatures, numBins - 1) + val gains = Array.ofDim[InformationGainStats](datasetInfo.numFeatures, numBins - 1) - for (featureIndex <- 0 until dsMeta.numFeatures) { + for (featureIndex <- 0 until datasetInfo.numFeatures) { for (splitIndex <- 0 until numBins - 1) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, - splitIndex, rightNodeAgg, nodeImpurity, dsMeta.numClasses, level) + splitIndex, rightNodeAgg, nodeImpurity, datasetInfo, level) } } gains @@ -622,7 +636,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va /** * Find the best split for a node. - * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param binData Array[Double] of size 2 * numBins * numFeatures * @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ @@ -633,7 +647,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va logDebug("node impurity = " + nodeImpurity) // Extract left right node aggregates. - val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData, dsMeta, numBins) + val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData, datasetInfo, numBins) // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) @@ -645,15 +659,15 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) // Iterate over features. var featureIndex = 0 - while (featureIndex < dsMeta.numFeatures) { + while (featureIndex < datasetInfo.numFeatures) { // Iterate over all splits. 
var splitIndex = 0 val maxSplitIndex : Double = { - val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { numBins - 1 } else { // Categorical feature - val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex) + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val isSpaceSufficientForAllCategoricalSplits = numBins > math.pow(2, featureCategories.toInt - 1) - 1 if (isMulticlass && isSpaceSufficientForAllCategoricalSplits) { @@ -690,7 +704,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va while (node < numNodes) { val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift val binsForNode: Array[Double] - = getBinDataForNode(node, binAggregates, dsMeta, numNodes, numBins) + = getBinDataForNode(node, binAggregates, datasetInfo, numNodes, numBins) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) logDebug("parent node impurity = " + parentNodeImpurity) @@ -701,25 +715,46 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va } /** - * Returns split and bins for decision tree calculation. - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree - * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree - * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache - * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + * Returns splits and bins for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * There are 2^(maxFeatureValue - 1) - 1 splits. + * (b) For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one split per category. + * + * Categorical case (a) features are called unordered features. + * Other cases are called ordered features. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return A tuple of (splits,bins). + * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numBins - 1). + * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). 
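+   *
+   *         For example, with maxBins = 100 and a categorical feature of
+   *         arity 3 in multiclass classification, case (a) applies since
+   *         numBins = 100 > math.pow(2, 3 - 1).toInt - 1 = 3, giving 3
+   *         unordered splits with one bin per split.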
*/ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - dsMeta: DatasetMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { + datasetInfo: DatasetInfo): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() - val numFeatures = dsMeta.numFeatures + val numFeatures = datasetInfo.numFeatures val maxBins = params.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) - val isMulticlass = dsMeta.isMulticlass + val isMulticlass = datasetInfo.isMulticlass logDebug("isMulticlass = " + isMulticlass) @@ -729,8 +764,8 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va * It's a limitation of the current implementation but a reasonable trade-off since features * with large number of categories get favored over continuous features. */ - if (dsMeta.categoricalFeaturesInfo.size > 0) { - val maxCategoriesForFeatures = dsMeta.categoricalFeaturesInfo.maxBy(_._2)._2 + if (datasetInfo.categoricalFeaturesInfo.size > 0) { + val maxCategoriesForFeatures = datasetInfo.categoricalFeaturesInfo.maxBy(_._2)._2 require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + "in categorical features") } @@ -760,7 +795,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va var featureIndex = 0 while (featureIndex < numFeatures){ // Check whether the feature is continuous. - val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride: Double = numSamples.toDouble / numBins @@ -771,14 +806,14 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va splits(featureIndex)(index) = split } } else { // Categorical feature - val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex) + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val isSpaceSufficientForAllCategoricalSplits = numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass // classification that satisfy the space constraint if (isMulticlass && isSpaceSufficientForAllCategoricalSplits) { - // 2^(maxFeatureValue- 1) - 1 combinations + // 2^(maxFeatureValue - 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { val categories: List[Double] = @@ -803,9 +838,8 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va index += 1 } } else { - val centroidForCategories = - computeCentroidForCategories(featureIndex, sampledInput, dsMeta) + computeCentroidForCategories(featureIndex, sampledInput, datasetInfo) logDebug("centroid for categories = " + centroidForCategories.mkString(",")) @@ -835,7 +869,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va new Bin(new DummyCategoricalSplit(featureIndex, Categorical), splits(featureIndex)(0), Categorical, key) } else { - new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + new Bin(splits(featureIndex)(index - 1), splits(featureIndex)(index), Categorical, key) } } @@ -848,16 +882,16 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va // Find all bins. 
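+    // For a continuous feature, the numBins bins below are
+    // [dummyLow, split 0], [split 0, split 1], ..., [split numBins-2, dummyHigh].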
featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) for (index <- 1 until numBins - 1){ - val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + val bin = new Bin(splits(featureIndex)(index - 1), splits(featureIndex)(index), Continuous, Double.MinValue) bins(featureIndex)(index) = bin } - bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), + bins(featureIndex)(numBins - 1) = new Bin(splits(featureIndex)(numBins - 2), new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) } featureIndex += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala index fca8725947888..963d2ce5ea5ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala @@ -19,16 +19,12 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.mllib.rdd.DatasetMetadata +import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.DTClassifierParams -import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.impurity.{ClassificationImpurity, Entropy, Gini} +//import org.apache.spark.mllib.tree.impurity.{ClassificationImpurity, ClassificationImpurities} import org.apache.spark.mllib.tree.model.{InformationGainStats, Bin, DecisionTreeClassifierModel} import org.apache.spark.rdd.RDD -import org.apache.spark.util.random.XORShiftRandom /** @@ -46,17 +42,17 @@ class DecisionTreeClassifier (params: DTClassifierParams) /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @param dsMeta Dataset metadata specifying number of classes, features, etc. + * @param datasetInfo Dataset metadata specifying number of classes, features, etc. * @return a DecisionTreeClassifierModel that can be used for prediction */ def train( input: RDD[LabeledPoint], - dsMeta: DatasetMetadata): DecisionTreeClassifierModel = { + datasetInfo: DatasetInfo): DecisionTreeClassifierModel = { - require(dsMeta.isClassification) + require(datasetInfo.isClassification) logDebug("algo = Classification") - val topNode = super.trainSub(input, dsMeta) + val topNode = super.trainSub(input, datasetInfo) new DecisionTreeClassifierModel(topNode) } @@ -67,8 +63,8 @@ class DecisionTreeClassifier (params: DTClassifierParams) protected def computeCentroidForCategories( featureIndex: Int, sampledInput: Array[LabeledPoint], - dsMeta: DatasetMetadata): Map[Double,Double] = { - if (dsMeta.isMulticlass) { + datasetInfo: DatasetInfo): Map[Double,Double] = { + if (datasetInfo.isMulticlass) { // For categorical variables in multiclass classification, // each bin is a category. 
The bins are sorted and they // are ordered by calculating the impurity of their corresponding labels. @@ -89,14 +85,14 @@ class DecisionTreeClassifier (params: DTClassifierParams) /** * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2*numFeatures*numSplits + * @param binData Array[Double] of size 2 * numFeatures * numBins * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, * (numBins - 1), numClasses) */ protected def extractLeftRightNodeAggregates( binData: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numBins: Int): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { def findAggForOrderedFeatureClassification( @@ -105,15 +101,15 @@ class DecisionTreeClassifier (params: DTClassifierParams) featureIndex: Int) { // shift for this featureIndex - val shift = dsMeta.numClasses * featureIndex * numBins + val shift = datasetInfo.numClasses * featureIndex * numBins var classIndex = 0 - while (classIndex < dsMeta.numClasses) { + while (classIndex < datasetInfo.numClasses) { // left node aggregate for the lowest split leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) // right node aggregate for the highest split rightNodeAgg(featureIndex)(numBins - 2)(classIndex) - = binData(shift + (dsMeta.numClasses * (numBins - 1)) + classIndex) + = binData(shift + (datasetInfo.numClasses * (numBins - 1)) + classIndex) classIndex += 1 } @@ -123,12 +119,12 @@ class DecisionTreeClassifier (params: DTClassifierParams) // calculating left node aggregate for a split as a sum of left node aggregate of a // lower split and the left bin aggregate of a bin where the split is a high split var innerClassIndex = 0 - while (innerClassIndex < dsMeta.numClasses) { + while (innerClassIndex < datasetInfo.numClasses) { leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) - = binData(shift + dsMeta.numClasses * splitIndex + innerClassIndex) + + = binData(shift + datasetInfo.numClasses * splitIndex + innerClassIndex) + leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = - binData(shift + (dsMeta.numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + + binData(shift + (datasetInfo.numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) innerClassIndex += 1 } @@ -141,14 +137,14 @@ class DecisionTreeClassifier (params: DTClassifierParams) rightNodeAgg: Array[Array[Array[Double]]], featureIndex: Int) { - val rightChildShift = dsMeta.numClasses * numBins * dsMeta.numFeatures + val rightChildShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures var splitIndex = 0 while (splitIndex < numBins - 1) { var classIndex = 0 - while (classIndex < dsMeta.numClasses) { + while (classIndex < datasetInfo.numClasses) { // shift for this featureIndex val shift = - dsMeta.numClasses * featureIndex * numBins + splitIndex * dsMeta.numClasses + datasetInfo.numClasses * featureIndex * numBins + splitIndex * datasetInfo.numClasses val leftBinValue = binData(shift + classIndex) val rightBinValue = binData(rightChildShift + shift + classIndex) leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue @@ -161,19 +157,19 @@ class DecisionTreeClassifier (params: DTClassifierParams) // Initialize left and right split aggregates. 
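+    // Both aggregates have shape (numFeatures, numBins - 1, numClasses):
+    // entry (f, s, c) counts class c on the left (resp. right) side of
+    // split s of feature f.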
val leftNodeAgg = - Array.ofDim[Double](dsMeta.numFeatures, numBins - 1, dsMeta.numClasses) + Array.ofDim[Double](datasetInfo.numFeatures, numBins - 1, datasetInfo.numClasses) val rightNodeAgg = - Array.ofDim[Double](dsMeta.numFeatures, numBins - 1, dsMeta.numClasses) + Array.ofDim[Double](datasetInfo.numFeatures, numBins - 1, datasetInfo.numClasses) var featureIndex = 0 - while (featureIndex < dsMeta.numFeatures) { - if (dsMeta.isMulticlassWithCategoricalFeatures){ - val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty + while (featureIndex < datasetInfo.numFeatures) { + if (datasetInfo.isMulticlassWithCategoricalFeatures){ + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } else { - val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 if (isSpaceSufficientForAllCategoricalSplits) { findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } else { @@ -189,13 +185,19 @@ class DecisionTreeClassifier (params: DTClassifierParams) (leftNodeAgg, rightNodeAgg) } + /** + * Get number of values to be stored per node in the bin aggregate counts. + * @param datasetInfo Dataset metadata + * @param numBins Number of bins = 1 + number of possible splits. + * @return + */ protected def getElementsPerNode( - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numBins: Int): Int = { - if (dsMeta.isMulticlassWithCategoricalFeatures) { - 2 * dsMeta.numClasses * numBins * dsMeta.numFeatures + if (datasetInfo.isMulticlassWithCategoricalFeatures) { + 2 * datasetInfo.numClasses * numBins * datasetInfo.numFeatures } else { - dsMeta.numClasses * numBins * dsMeta.numFeatures + datasetInfo.numClasses * numBins * datasetInfo.numFeatures } } @@ -204,24 +206,26 @@ class DecisionTreeClassifier (params: DTClassifierParams) * For l nodes, k features, * either the left count or the right count of one of the p bins is * incremented based upon whether the feature is classified as 0 or 1. - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation, of size: - * 2 * numSplits * numFeatures * numNodes for ordered features, or - * 2 * numClasses * numSplits * numFeatures * numNodes for unordered features + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes + * TODO: FIX DOC + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). 
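+   *            (The agg size above holds for ordered features only; for
+   *            unordered features, agg holds separate left/right-child halves
+   *            and is twice that size, matching getElementsPerNode.)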
+ * @return Array storing aggregate calculation, of size: + * 2 * numBins * numFeatures * numNodes for ordered features, or + * 2 * numClasses * numBins * numFeatures * numNodes for unordered features */ protected def binSeqOpSub( agg: Array[Double], arr: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numNodes: Int, bins: Array[Array[Bin]]): Array[Double] = { val numBins = bins(0).length - if(dsMeta.isMulticlassWithCategoricalFeatures) { - unorderedClassificationBinSeqOp(arr, agg, dsMeta, numNodes, bins) + if(datasetInfo.isMulticlassWithCategoricalFeatures) { + unorderedClassificationBinSeqOp(arr, agg, datasetInfo, numNodes, bins) } else { - orderedClassificationBinSeqOp(arr, agg, dsMeta, numNodes, numBins) + orderedClassificationBinSeqOp(arr, agg, datasetInfo, numNodes, numBins) } agg } @@ -241,14 +245,16 @@ class DecisionTreeClassifier (params: DTClassifierParams) splitIndex: Int, rightNodeAgg: Array[Array[Array[Double]]], topImpurity: Double, - numClasses: Int, + datasetInfo: DatasetInfo, level: Int): InformationGainStats = { - var classIndex = 0 + val numClasses = datasetInfo.numClasses + val leftCounts: Array[Double] = new Array[Double](numClasses) val rightCounts: Array[Double] = new Array[Double](numClasses) var leftTotalCount = 0.0 var rightTotalCount = 0.0 + var classIndex = 0 while (classIndex < numClasses) { val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) @@ -323,24 +329,26 @@ class DecisionTreeClassifier (params: DTClassifierParams) protected def getBinDataForNode( node: Int, binAggregates: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numNodes: Int, numBins: Int): Array[Double] = { - if (dsMeta.isMulticlassWithCategoricalFeatures) { - val shift = dsMeta.numClasses * node * numBins * dsMeta.numFeatures - val rightChildShift = dsMeta.numClasses * numBins * dsMeta.numFeatures * numNodes + if (datasetInfo.isMulticlassWithCategoricalFeatures) { + val shift = datasetInfo.numClasses * node * numBins * datasetInfo.numFeatures + val rightChildShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures * numNodes val binsForNode = { val leftChildData - = binAggregates.slice(shift, shift + dsMeta.numClasses * numBins * dsMeta.numFeatures) + = binAggregates.slice(shift, shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) val rightChildData = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + dsMeta.numClasses * numBins * dsMeta.numFeatures) + rightChildShift + shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) leftChildData ++ rightChildData } binsForNode } else { - val shift = dsMeta.numClasses * node * numBins * dsMeta.numFeatures - val binsForNode = binAggregates.slice(shift, shift + dsMeta.numClasses * numBins * dsMeta.numFeatures) + val shift = datasetInfo.numClasses * node * numBins * datasetInfo.numFeatures + val binsForNode = binAggregates.slice( + shift, + shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) binsForNode } } @@ -349,57 +357,68 @@ class DecisionTreeClassifier (params: DTClassifierParams) // Private methods //=========================================================================== + /** + * Increment aggregate in location for (node, feature, bin, label) + * to indicate that, for this (example, + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). 
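+   *            Only arr(1 + numFeatures * nodeIndex + featureIndex), the bin
+   *            of this feature at this node, is read here; the incremented
+   *            position in agg is
+   *            numClasses * (numBins * (numFeatures * nodeIndex + featureIndex) + bin) + label.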
+ * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. + */ private def updateBinForOrderedFeature( arr: Array[Double], agg: Array[Double], nodeIndex: Int, label: Double, featureIndex: Int, - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numBins: Int) = { // Find the bin index for this feature. - val arrShift = 1 + dsMeta.numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex + val arrIndex = 1 + datasetInfo.numFeatures * nodeIndex + featureIndex // Update the left or right count for one bin. - val aggShift = dsMeta.numClasses * numBins * dsMeta.numFeatures * nodeIndex - val aggIndex - = aggShift + dsMeta.numClasses * featureIndex * numBins - + arr(arrIndex).toInt * dsMeta.numClasses - val labelInt = label.toInt - agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 + val aggShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures * nodeIndex + val aggIndex = aggShift + datasetInfo.numClasses * featureIndex * numBins + + arr(arrIndex).toInt * datasetInfo.numClasses + agg(aggIndex + label.toInt) += 1 } + /** + * + * @param arr Size numNodes * numFeatures + 1. + * Indexed by (node, feature) where feature is the least significant bit, + * shifted by 1. + * @param agg Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param rightChildShift + * @param bins + */ private def updateBinForUnorderedFeature( + arr: Array[Double], + agg: Array[Double], nodeIndex: Int, featureIndex: Int, - arr: Array[Double], label: Double, - agg: Array[Double], rightChildShift: Int, - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numBins: Int, bins: Array[Array[Bin]]) = { // Find the bin index for this feature. - val arrShift = 1 + dsMeta.numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex + val arrIndex = 1 + datasetInfo.numFeatures * nodeIndex + featureIndex // Update the left or right count for one bin. - val aggShift = dsMeta.numClasses * numBins * dsMeta.numFeatures * nodeIndex - val aggIndex - = aggShift + dsMeta.numClasses * featureIndex * numBins + arr(arrIndex).toInt * dsMeta.numClasses + val aggShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures * nodeIndex + val aggIndex = aggShift + datasetInfo.numClasses * featureIndex * numBins + + arr(arrIndex).toInt * datasetInfo.numClasses // Find all matching bins and increment their values - val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex) + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { - val labelInt = label.toInt - if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + 1 + if (bins(featureIndex)(binIndex).highSplit.categories.contains(label.toInt)) { + agg(aggIndex + binIndex) += 1 } else { - agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + 1 + agg(rightChildShift + aggIndex + binIndex) += 1 } binIndex += 1 } @@ -407,26 +426,33 @@ class DecisionTreeClassifier (params: DTClassifierParams) /** * Helper for binSeqOp + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). 
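+   *            binSeqOpSub dispatches here when the dataset has no unordered
+   *            features. The filled agg feeds the level-wise aggregation,
+   *            sketched below with binMappedRDD and binCombOp as assumed names
+   *            for the bin-mapped input and the element-wise combiner:
+   *            {{{
+   *              binMappedRDD.aggregate(new Array[Double](binAggregateLength))(binSeqOp, binCombOp)
+   *            }}}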
+ * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes + * @param datasetInfo + * @param numNodes + * @param numBins */ private def orderedClassificationBinSeqOp( arr: Array[Double], agg: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numNodes: Int, numBins: Int) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + dsMeta.numFeatures * nodeIndex + val validSignalIndex = 1 + datasetInfo.numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { // actual class label val label = arr(0) // Iterate over all features. var featureIndex = 0 - while (featureIndex < dsMeta.numFeatures) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, dsMeta, numBins) + while (featureIndex < datasetInfo.numFeatures) { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, numBins) featureIndex += 1 } } @@ -435,12 +461,22 @@ class DecisionTreeClassifier (params: DTClassifierParams) } /** - * Helper for binSeqOp + * Helper for binSeqOp. + * + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation of size + * numClasses * numBins * numFeatures * numNodes + * // Size set by getElementsPerNode(): + * // 2 * numClasses * numBins * numFeatures * numNodes + * @param datasetInfo Dataset metadata. + * @param numNodes Number of nodes in this (level, group). + * @param bins */ private def unorderedClassificationBinSeqOp( arr: Array[Double], agg: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numNodes: Int, bins: Array[Array[Bin]]) = { val numBins = bins(0).length @@ -448,27 +484,27 @@ class DecisionTreeClassifier (params: DTClassifierParams) var nodeIndex = 0 while (nodeIndex < numNodes) { // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + dsMeta.numFeatures * nodeIndex + val validSignalIndex = 1 + datasetInfo.numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { - val rightChildShift = dsMeta.numClasses * numBins * dsMeta.numFeatures * numNodes + val rightChildShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures * numNodes // actual class label val label = arr(0) // Iterate over all features. 
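+          // Continuous and high-arity categorical features take the ordered
+          // update below; only low-arity categorical features, whose
+          // 2^(featureCategories - 1) - 1 subset splits all fit in numBins,
+          // take the unordered update.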
var featureIndex = 0 - while (featureIndex < dsMeta.numFeatures) { - val isFeatureContinuous = dsMeta.categoricalFeaturesInfo.get(featureIndex).isEmpty + while (featureIndex < datasetInfo.numFeatures) { + val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, dsMeta, numBins) + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, numBins) } else { - val featureCategories = dsMeta.categoricalFeaturesInfo(featureIndex) + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val isSpaceSufficientForAllCategoricalSplits = numBins > math.pow(2, featureCategories.toInt - 1) - 1 if (isSpaceSufficientForAllCategoricalSplits) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, - rightChildShift, dsMeta, numBins, bins) + updateBinForUnorderedFeature(arr, agg, nodeIndex, featureIndex, label, + rightChildShift, datasetInfo, numBins, bins) } else { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, dsMeta, numBins) + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, numBins) } } featureIndex += 1 @@ -483,97 +519,44 @@ class DecisionTreeClassifier (params: DTClassifierParams) object DecisionTreeClassifier extends Serializable with Logging { /** - * Train a decision tree model for binary or multiclass classification. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * Labels should take values {0, 1, ..., numClasses-1}. - * @param dsMeta Dataset metadata (number of features, number of classes, etc.) - * @param params The configuration parameters for the tree learning algorithm - * (tree depth, quantile calculation strategy, etc.) - * @return DecisionTreeClassifierModel which can be used for prediction + * Get a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTreeClassifier]]. */ - def train( - input: RDD[LabeledPoint], - dsMeta: DatasetMetadata, - params: DTClassifierParams = new DTClassifierParams()): DecisionTreeClassifierModel = { - require(dsMeta.numClasses >= 2) - new DecisionTreeClassifier(params).train(input, dsMeta) + def defaultParams(): DTClassifierParams = { + new DTClassifierParams() } /** - * Train a decision tree model for binary or multiclass classification. + * Train a decision tree model for binary or multiclass classification, + * using the default set of learning parameters. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses Number of classes (label types) for classification. - * Default = 2 (binary classification). - * @param categoricalFeaturesInfo A map from each categorical variable to the - * number of discrete values it takes. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It is important to note that features are - * zero-indexed. - * Default = treat all features as continuous. - * @param params The configuration parameters for the tree learning algorithm - * (tree depth, quantile calculation strategy, etc.) + * @param datasetInfo Dataset metadata (number of features, number of classes, etc.) 
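+   *
+   * For example, for binary classification:
+   * {{{
+   *   val numFeatures = input.take(1)(0).features.size
+   *   val datasetInfo = new DatasetInfo(numClasses = 2, numFeatures = numFeatures)
+   *   val model = DecisionTreeClassifier.train(input, datasetInfo)
+   * }}}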
* @return DecisionTreeClassifierModel which can be used for prediction */ def train( input: RDD[LabeledPoint], - numClasses: Int = 2, - categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - params: DTClassifierParams = new DTClassifierParams()): DecisionTreeClassifierModel = { - - // Find the number of features by looking at the first sample. - val numFeatures = input.first().features.size - val dsMeta = new DatasetMetadata(numClasses, numFeatures, categoricalFeaturesInfo) - - train(input, dsMeta, params) + datasetInfo: DatasetInfo): DecisionTreeClassifierModel = { + require(datasetInfo.numClasses >= 2) + new DecisionTreeClassifier(new DTClassifierParams()).train(input, datasetInfo) } - // TODO: Move elsewhere! - protected def getImpurity(impurityName: String): ClassificationImpurity = { - impurityName match { - case "gini" => Gini - case "entropy" => Entropy - case _ => throw new IllegalArgumentException( - s"Bad impurity parameter for classification: $impurityName") - } - } - - // TODO: Add various versions of train() function below. - /** * Train a decision tree model for binary or multiclass classification. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses Number of classes (label types) for classification. - * @param categoricalFeaturesInfo A map from each categorical variable to the - * number of discrete values it takes. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It is important to note that features are - * zero-indexed. - * @param impurityName Criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree - * @param maxBins Maximum number of bins used for splitting features - * @param quantileStrategyName Algorithm for calculating quantiles + * @param datasetInfo Dataset metadata (number of features, number of classes, etc.) + * @param params The configuration parameters for the tree learning algorithm + * (tree depth, quantile calculation strategy, etc.) 
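+   *
+   * For example, overriding a few defaults:
+   * {{{
+   *   val params = DecisionTreeClassifier.defaultParams()
+   *   params.maxDepth = 10
+   *   params.maxBins = 64
+   *   val model = DecisionTreeClassifier.train(input, datasetInfo, params)
+   * }}}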
* @return DecisionTreeClassifierModel which can be used for prediction */ def train( input: RDD[LabeledPoint], - numClasses: Int, - categoricalFeaturesInfo: Map[Int, Int], - impurityName: String, - maxDepth: Int, - maxBins: Int, - quantileStrategyName: String, - maxMemoryInMB: Int): DecisionTreeClassifierModel = { - - val impurity = getImpurity(impurityName) - val quantileStrategy = getQuantileStrategy(quantileStrategyName) - val params = - new DTClassifierParams(impurity, maxDepth, maxBins, quantileStrategy, maxMemoryInMB) - train(input, numClasses, categoricalFeaturesInfo, params) + datasetInfo: DatasetInfo, + params: DTClassifierParams): DecisionTreeClassifierModel = { + require(datasetInfo.numClasses >= 2) + new DecisionTreeClassifier(params).train(input, datasetInfo) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala index 11d5aada5a549..2957afc271826 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.mllib.rdd.DatasetMetadata +import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.DTRegressorParams -import org.apache.spark.mllib.tree.impurity.{RegressionImpurity, Variance} +import org.apache.spark.mllib.tree.configuration.{QuantileStrategies, DTRegressorParams} +import org.apache.spark.mllib.tree.impurity.{RegressionImpurities, RegressionImpurity} import org.apache.spark.mllib.tree.model.{InformationGainStats, Bin, DecisionTreeRegressorModel} import org.apache.spark.rdd.RDD @@ -38,30 +38,21 @@ class DecisionTreeRegressor (params: DTRegressorParams) extends DecisionTree[DecisionTreeRegressorModel](params) { private val impurityFunctor = params.impurity - /* - private val impurityFunctor = params.impurity match { - case "variance" => Variance - case _ => throw new IllegalArgumentException(s"Bad impurity parameter for regression: ${params.impurity}") - } - */ /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * @param datasetInfo Dataset metadata specifying number of classes, features, etc. 
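+   *
+   * For example (numClasses = 0 marks the dataset as regression, as in the
+   * example runner):
+   * {{{
+   *   val datasetInfo = new DatasetInfo(numClasses = 0, numFeatures = numFeatures)
+   *   val model = new DecisionTreeRegressor(DecisionTreeRegressor.defaultParams())
+   *     .train(input, datasetInfo)
+   * }}}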
* @return a DecisionTreeRegressorModel that can be used for prediction */ def train( input: RDD[LabeledPoint], - categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()): DecisionTreeRegressorModel = { + datasetInfo: DatasetInfo): DecisionTreeRegressorModel = { + require(datasetInfo.isRegression) logDebug("algo = Regression") - val topNode = super.trainSub(input, 0, categoricalFeaturesInfo) + val topNode = super.trainSub(input, datasetInfo) new DecisionTreeRegressorModel(topNode) } @@ -73,7 +64,7 @@ class DecisionTreeRegressor (params: DTRegressorParams) protected def computeCentroidForCategories( featureIndex: Int, sampledInput: Array[LabeledPoint], - dsMeta: DatasetMetadata): Map[Double,Double] = { + datasetInfo: DatasetInfo): Map[Double,Double] = { // For categorical variables in regression, each bin is a category. // The bins are sorted and are ordered by calculating the centroid of their corresponding labels. sampledInput.map(lp => (lp.features(featureIndex), lp.label)) @@ -83,22 +74,22 @@ class DecisionTreeRegressor (params: DTRegressorParams) /** * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2*numFeatures*numSplits + * @param binData Array[Double] of size 2 * numFeatures * numBins * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, - * (numBins - 1), numClasses) + * (numBins - 1), 3) */ protected def extractLeftRightNodeAggregates( binData: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numBins: Int): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](dsMeta.numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](dsMeta.numFeatures, numBins - 1, 3) + val leftNodeAgg = Array.ofDim[Double](datasetInfo.numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](datasetInfo.numFeatures, numBins - 1, 3) // Iterate over all features. var featureIndex = 0 - while (featureIndex < dsMeta.numFeatures) { + while (featureIndex < datasetInfo.numFeatures) { // shift for this featureIndex val shift = 3 * featureIndex * numBins // left node aggregate for the lowest split @@ -138,9 +129,9 @@ class DecisionTreeRegressor (params: DTRegressorParams) } protected def getElementsPerNode( - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numBins: Int): Int = { - 3 * numBins * dsMeta.numFeatures + 3 * numBins * datasetInfo.numFeatures } /** @@ -149,15 +140,15 @@ class DecisionTreeRegressor (params: DTRegressorParams) * the count, sum, sum of squares of one of the p bins is incremented. * * @param agg Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for classification + * 3 * numBins * numFeatures * numNodes for classification * @param arr Array[Double] of size 1 + (numFeatures * numNodes) * @return Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for regression + * 3 * numBins * numFeatures * numNodes for regression */ protected def binSeqOpSub( agg: Array[Double], arr: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numNodes: Int, bins: Array[Array[Bin]]): Array[Double] = { val numBins = bins(0).length @@ -165,19 +156,19 @@ class DecisionTreeRegressor (params: DTRegressorParams) var nodeIndex = 0 while (nodeIndex < numNodes) { // Check whether the instance was valid for this nodeIndex. 
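+      // arr(1 + numFeatures * nodeIndex) == InvalidBinIndex marks a point
+      // filtered out of this node by its ancestors' splits (see
+      // findBinsForLevel).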
- val validSignalIndex = 1 + dsMeta.numFeatures * nodeIndex + val validSignalIndex = 1 + datasetInfo.numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { // actual class label val label = arr(0) // Iterate over all features. var featureIndex = 0 - while (featureIndex < dsMeta.numFeatures) { + while (featureIndex < datasetInfo.numFeatures) { // Find the bin index for this feature. - val arrShift = 1 + dsMeta.numFeatures * nodeIndex + val arrShift = 1 + datasetInfo.numFeatures * nodeIndex val arrIndex = arrShift + featureIndex // Update count, sum, and sum^2 for one bin. - val aggShift = 3 * numBins * dsMeta.numFeatures * nodeIndex + val aggShift = 3 * numBins * datasetInfo.numFeatures * nodeIndex val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label @@ -205,7 +196,7 @@ class DecisionTreeRegressor (params: DTRegressorParams) splitIndex: Int, rightNodeAgg: Array[Array[Array[Double]]], topImpurity: Double, - numClasses: Int, + datasetInfo: DatasetInfo, level: Int): InformationGainStats = { val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) @@ -261,119 +252,55 @@ class DecisionTreeRegressor (params: DTRegressorParams) protected def getBinDataForNode( node: Int, binAggregates: Array[Double], - dsMeta: DatasetMetadata, + datasetInfo: DatasetInfo, numNodes: Int, numBins: Int): Array[Double] = { - val shift = 3 * node * numBins * dsMeta.numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * dsMeta.numFeatures) + val shift = 3 * node * numBins * datasetInfo.numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * datasetInfo.numFeatures) binsForNode } - //=========================================================================== - // Protected methods - //=========================================================================== - - /** - * Performs a sequential aggregation over a partition for regression. - */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { - } - } object DecisionTreeRegressor extends Serializable with Logging { /** - * Train a decision tree model for regression. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * Labels should take values {0, 1, ..., numClasses-1}. - * @param dsMeta Dataset metadata (number of features, number of classes, etc.) - * @param params The configuration parameters for the tree learning algorithm - * (tree depth, quantile calculation strategy, etc.) - * @return DecisionTreeRegressorModel which can be used for prediction + * Get a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTreeRegressor]]. */ - def train( - input: RDD[LabeledPoint], - dsMeta: DatasetMetadata, - params: DTRegressorParams = new DTRegressorParams()): DecisionTreeRegressorModel = { - require(dsMeta.numClasses >= 2) - new DecisionTreeRegressor(params).train(input, dsMeta) + def defaultParams(): DTRegressorParams = { + new DTRegressorParams() } /** - * Train a decision tree model for binary or multiclass regression. + * Train a decision tree model for regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses Number of classes (label types) for regression. - * Default = 2 (binary regression). 
- * @param categoricalFeaturesInfo A map from each categorical variable to the - * number of discrete values it takes. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It is important to note that features are - * zero-indexed. - * Default = treat all features as continuous. - * @param params The configuration parameters for the tree learning algorithm - * (tree depth, quantile calculation strategy, etc.) + * Labels should be real values. + * @param datasetInfo Dataset metadata (number of features, number of classes, etc.) * @return DecisionTreeRegressorModel which can be used for prediction */ def train( - input: RDD[LabeledPoint], - numClasses: Int = 2, - categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - params: DTRegressorParams = new DTRegressorParams()): DecisionTreeRegressorModel = { - - // Find the number of features by looking at the first sample. - val numFeatures = input.first().features.size - val dsMeta = new DatasetMetadata(numClasses, numFeatures, categoricalFeaturesInfo) - - train(input, dsMeta, params) - } - - // TODO: Move elsewhere! - protected def getImpurity(impurityName: String): RegressionImpurity = { - impurityName match { - case "gini" => Gini - case "entropy" => Entropy - case _ => throw new IllegalArgumentException( - s"Bad impurity parameter for regression: $impurityName") - } + input: RDD[LabeledPoint], + datasetInfo: DatasetInfo): DecisionTreeRegressorModel = { + new DecisionTreeRegressor(new DTRegressorParams()).train(input, datasetInfo) } /** - * Train a decision tree model for binary or multiclass regression. + * Train a decision tree model for regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses Number of classes (label types) for regression. - * @param categoricalFeaturesInfo A map from each categorical variable to the - * number of discrete values it takes. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It is important to note that features are - * zero-indexed. - * @param impurityName Criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree - * @param maxBins Maximum number of bins used for splitting features - * @param quantileStrategyName Algorithm for calculating quantiles + * Labels should be real values. + * @param datasetInfo Dataset metadata (number of features, number of classes, etc.) + * @param params The configuration parameters for the tree learning algorithm + * (tree depth, quantile calculation strategy, etc.) 
* @return DecisionTreeRegressorModel which can be used for prediction */ def train( - input: RDD[LabeledPoint], - numClasses: Int, - categoricalFeaturesInfo: Map[Int, Int], - impurityName: String, - maxDepth: Int, - maxBins: Int, - quantileStrategyName: String, - maxMemoryInMB: Int): DecisionTreeRegressorModel = { - - val impurity = getImpurity(impurityName) - val quantileStrategy = getQuantileStrategy(quantileStrategyName) - val params = - new DTRegressorParams(impurity, maxDepth, maxBins, quantileStrategy, maxMemoryInMB) - train(input, numClasses, categoricalFeaturesInfo, params) + input: RDD[LabeledPoint], + datasetInfo: DatasetInfo, + params: DTRegressorParams = new DTRegressorParams()): DecisionTreeRegressorModel = { + new DecisionTreeRegressor(params).train(input, datasetInfo) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala deleted file mode 100644 index 79a01f58319e8..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.tree.configuration - -import org.apache.spark.annotation.Experimental - -/** - * :: Experimental :: - * Enum to select the algorithm for the decision tree - */ -@Experimental -object Algo extends Enumeration { - type Algo = Value - val Classification, Regression = Value -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala index 52013622ce1c2..d8d54c6c4fb17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -18,9 +18,8 @@ package org.apache.spark.mllib.tree.configuration import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.configuration.DTParams import org.apache.spark.mllib.tree.impurity.{ClassificationImpurity, Gini} -import org.apache.spark.mllib.tree.configuration.QuantileStrategy +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ /** * :: Experimental :: @@ -34,17 +33,11 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy */ @Experimental class DTClassifierParams ( - val impurity: ClassificationImpurity = Gini, + var impurity: ClassificationImpurity = Gini, maxDepth: Int = 5, maxBins: Int = 100, - quantileStrategy: QuantileStrategy.QuantileStrategy = QuantileStrategy.Sort, + quantileStrategy: QuantileStrategy = Sort, maxMemoryInMB: Int = 128) extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { - /* - if (!List("gini", "entropy").contains(impurity)) { - throw new IllegalArgumentException(s"Bad impurity parameter for classification: $impurity") - } - */ - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala index 6287eeb7ebd0d..3098e2b66ce46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.tree.configuration import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.configuration.QuantileStrategy /** * :: Experimental :: @@ -31,9 +30,9 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy */ @Experimental class DTParams ( - val maxDepth: Int, - val maxBins: Int, - val quantileStrategy: QuantileStrategy.QuantileStrategy, - val maxMemoryInMB: Int) extends Serializable { + var maxDepth: Int, + var maxBins: Int, + var quantileStrategy: QuantileStrategy.QuantileStrategy, + var maxMemoryInMB: Int) extends Serializable { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala index bd76d2d794ee7..f38f9ab6ae7fd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -18,9 +18,7 @@ package org.apache.spark.mllib.tree.configuration import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.configuration.DTParams import org.apache.spark.mllib.tree.impurity.{RegressionImpurity, Variance} -import org.apache.spark.mllib.tree.configuration.QuantileStrategy /** * :: 
Experimental :: @@ -34,17 +32,11 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy */ @Experimental class DTRegressorParams ( - val impurity: RegressionImpurity = Variance, + var impurity: RegressionImpurity = Variance, maxDepth: Int = 5, maxBins: Int = 100, quantileStrategy: QuantileStrategy.QuantileStrategy = QuantileStrategy.Sort, maxMemoryInMB: Int = 128) extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { - /* - if (!List("variance").contains(impurity)) { - throw new IllegalArgumentException(s"Bad impurity parameter for regression: $impurity") - } - */ - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 052ef6c148d33..1103457086f22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -27,16 +27,23 @@ import org.apache.spark.annotation.Experimental object QuantileStrategy extends Enumeration { type QuantileStrategy = Value val Sort, MinMax, ApproxHist = Value +} + + +/** + * :: Experimental :: + * Factory for creating [[org.apache.spark.mllib.tree.configuration.QuantileStrategy]] instances. + */ +@Experimental +private[mllib] object QuantileStrategies { /** * Given a string with the name of a quantile strategy, get the QuantileStrategy type. */ - def getQuantileStrategy(strategyName: String): QuantileStrategy = { - strategyName match { - case "sort" => Sort - case _ => throw new IllegalArgumentException( - s"Bad QuantileStrategy parameter: $strategyName") - } + def strategy(strategyName: String): QuantileStrategy.QuantileStrategy = strategyName match { + case "sort" => QuantileStrategy.Sort + case _ => throw new IllegalArgumentException( + s"Bad QuantileStrategy parameter: $strategyName") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala new file mode 100644 index 0000000000000..f5646e98c4c55 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Factory class for constructing a [[org.apache.spark.mllib.tree.impurity.ClassificationImpurity]] + * type based on its name. + */ +@Experimental +object ClassificationImpurities { + + /** + * Mapping used for impurity names, used for parsing impurity settings. 
+ * If you add a new impurity class, add it here. + */ + val impurityToNameMap: Map[ClassificationImpurity, String] = Map( + Gini -> "gini", + Entropy -> "entropy") + + val nameToImpurityMap: Map[String, ClassificationImpurity] = impurityToNameMap.map(_.swap) + + val names: List[String] = nameToImpurityMap.keys.toList + + /** + * Given impurity name, return type. + */ + def impurity(name: String): ClassificationImpurity = { + if (nameToImpurityMap.contains(name)) { + nameToImpurityMap(name) + } else { + throw new IllegalArgumentException(s"Bad impurity parameter for classification: $name") + } + } + + /** + * Given impurity type, return name. + */ + def impurityName(impurity: ClassificationImpurity): String = { + if (impurityToNameMap.contains(impurity)) { + impurityToNameMap(impurity) + } else { + throw new IllegalArgumentException( + s"ClassificationImpurity type ${impurity.toString}" + + " not registered in ClassificationImpurities factory.") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala new file mode 100644 index 0000000000000..b27622e68d9e5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Factory class for constructing a [[org.apache.spark.mllib.tree.impurity.RegressionImpurity]] + * type based on its name. + */ +@Experimental +object RegressionImpurities { + + /** + * Mapping used for impurity names, used for parsing impurity settings. + * If you add a new impurity class, add it here. + */ + val impurityToNameMap: Map[RegressionImpurity, String] = Map( + Variance -> "variance") + + val nameToImpurityMap: Map[String, RegressionImpurity] = impurityToNameMap.map(_.swap) + + val names: List[String] = nameToImpurityMap.keys.toList + + /** + * Given impurity name, return type. + */ + def impurity(name: String): RegressionImpurity = { + if (nameToImpurityMap.contains(name)) { + nameToImpurityMap(name) + } else { + throw new IllegalArgumentException(s"Bad impurity parameter for regression: $name") + } + } + + /** + * Given impurity type, return name. 
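+   * For example, impurityName(Variance) returns "variance", per impurityToNameMap above.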
+ */ + def impurityName(impurity: RegressionImpurity): String = { + if (impurityToNameMap.contains(impurity)) { + impurityToNameMap(impurity) + } else { + throw new IllegalArgumentException( + s"RegressionImpurity type ${impurity.toString}" + + " not registered in RegressionImpurities factory.") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 0feac65a574af..c6d674ddb7857 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -24,22 +24,11 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * Class for calculating variance during regression */ @Experimental -private[mllib] object Variance extends Impurity { +private[mllib] object Variance extends RegressionImpurity { /** * :: DeveloperApi :: - * information calculation for multiclass classification - * @param counts Array[Double] with counts for each label - * @param totalCount sum of counts for all labels - * @return information value - */ - @DeveloperApi - override def calculate(counts: Array[Double], totalCount: Double): Double = - throw new UnsupportedOperationException("Variance.calculate") - - /** - * :: DeveloperApi :: - * variance calculation + * information calculation for regression * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 5705916eb6531..bd9ef46928714 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -125,7 +125,7 @@ class Node ( println(prefix + s"If ${splitToString(split.get, true)}") leftNode.get.print(prefix + " ") println(prefix + s"Else ${splitToString(split.get, false)}") - tNode.get.print(prefix + " ") + rightNode.get.print(prefix + " ") } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5961a618c59d9..83e5c9a8a050c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.tree +import org.apache.spark.mllib.rdd.DatasetInfo import org.scalatest.FunSuite import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.model.Split -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} -import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.{FeatureType, DTClassifierParams, DTRegressorParams} import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext @@ -31,12 +31,27 @@ import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { + private def getNumFeatures(data: Array[LabeledPoint]) = data.size match { + case 0 => 0 + case _ => data(0).features.size + } + + private val defaultClassifierParams = + new DTClassifierParams(Gini, maxDepth=3, maxBins=100) + + 
private val defaultRegressorParams = + new DTRegressorParams(Variance, maxDepth=3, maxBins=100) + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -47,14 +62,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - numClassesForClassification = 2, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -127,14 +141,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - numClassesForClassification = 2, - maxBins = 100, + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr), categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) // Check splits. 
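    // With numClasses = 2, each 3-ary categorical feature is treated as an ordered
    // feature (one bin per category); see the ordered/unordered discussion in the
    // DecisionTree scaladoc later in this series.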
@@ -244,14 +257,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - numClassesForClassification = 100, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 100, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -338,14 +350,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() assert(arr.length === 3000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - numClassesForClassification = 100, - maxBins = 100, + val datasetInfo = new DatasetInfo( + numClasses = 100, + numFeatures = getNumFeatures(arr), categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) // 2^10 - 1 > 100, so categorical variables will be ordered @@ -393,15 +404,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - numClassesForClassification = 2, - maxDepth = 3, - maxBins = 100, + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr), categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(7), 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -421,14 +431,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Regression, - Variance, - maxDepth = 3, - maxBins = 100, + val datasetInfo = new DatasetInfo( + numClasses = 0, + numFeatures = getNumFeatures(arr), categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + + val dtLearner = new DecisionTreeRegressor(defaultRegressorParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(7), 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -447,8 +457,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val 
strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -456,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(7), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -470,8 +485,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtLearner = new DecisionTreeClassifier(defaultClassifierParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -479,7 +499,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, Array(0.0), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -494,8 +514,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtParams = defaultClassifierParams + dtParams.impurity = Entropy + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -503,7 +530,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, Array(0.0), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -518,8 +545,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtParams = defaultClassifierParams + 
dtParams.impurity = Entropy + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -527,7 +561,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, Array(0.0), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -542,8 +576,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = getNumFeatures(arr)) + + val dtParams = defaultClassifierParams + dtParams.impurity = Entropy + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -557,7 +598,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, parentImpurities, 1, filters, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -565,7 +606,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. 
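    // Grouping bounds memory use: when a level's histogram aggregates would exceed
    // maxMemoryInMB, the nodes at that level are processed in several groups
    // (multiple aggregation passes); maxLevelForSingleGroup = 0 forces that path.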
- val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + val bestSplitsWithGroups = dtLearner.findBestSplits(rdd, datasetInfo, parentImpurities, 1, filters, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) @@ -586,12 +627,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val rdd = sc.parallelize(arr) + val datasetInfo = new DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.impurity = Entropy + dtParams.maxDepth = 5 + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -604,12 +653,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClassesForClassification = 3) - assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val rdd = sc.parallelize(arr) + val datasetInfo = new DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.maxDepth = 5 + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -624,12 +678,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) - assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val rdd = sc.parallelize(arr) + val datasetInfo = new DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 3)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.maxDepth = 5 + val dtLearner = new DecisionTreeClassifier(dtParams) + val 
(splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -643,12 +704,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) - assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val rdd = sc.parallelize(arr) + val datasetInfo = new DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.maxDepth = 5 + val dtLearner = new DecisionTreeClassifier(dtParams) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) From a853bfc1929e9d1fb56d955241c827fd2a5c1351 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 23 Jul 2014 13:25:05 -0700 Subject: [PATCH 04/20] Last non-merge commit said it changed the maxDepth meaning, but it did not. This one implements this change: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit maxDepth: Changed meaning by 1. Previously, depth = 1 meant 1 leaf node; now it means 1 internal and 2 leaf nodes. This matches scikit-learn and rpart. Internally, this meant replacing: maxDepth <— maxDepth+1. In tests, decremented maxDepth by 1. 
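Worked example of the new convention: the node budget in DecisionTree.scala is

    val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1

so maxDepth = 2 now allows up to 2^3 - 1 = 7 nodes (3 internal + 4 leaf nodes),
a tree that the old convention reached only at maxDepth = 3.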
--- .../spark/examples/mllib/DecisionTreeRunner.scala | 2 +- .../org/apache/spark/mllib/tree/DecisionTree.scala | 8 ++++---- .../tree/configuration/DTClassifierParams.scala | 5 +++-- .../spark/mllib/tree/configuration/DTParams.scala | 3 ++- .../mllib/tree/configuration/DTRegressorParams.scala | 5 +++-- .../apache/spark/mllib/tree/DecisionTreeSuite.scala | 12 ++++++------ 6 files changed, 19 insertions(+), 16 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index cc1b2e94bb6bc..a502f067e8981 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -48,7 +48,7 @@ object DecisionTreeRunner { dataFormat: String = null, algo: String = "classification", impurity: Option[String] = None, - maxDepth: Int = 5, + maxDepth: Int = 4, maxBins: Int = 100, fracTest: Double = 0.2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index e467d10d36581..225d1f7e83c86 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -63,7 +63,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va // depth of the decision tree val maxDepth = params.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -99,7 +99,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va var level = 0 var break = false - while (level < maxDepth && !break) { + while (level <= maxDepth && !break) { logDebug("#####################################") logDebug("level = " + level) @@ -238,7 +238,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = math.pow(2, level).toInt - 1 + index - val isLeaf = (stats.gain <= 0) || (level == params.maxDepth - 1) + val isLeaf = (stats.gain <= 0) || (level == params.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node @@ -259,7 +259,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. 
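      // Nodes are stored in breadth-first (heap) order: level l starts at array
      // index 2^l - 1, so child i (0 = left, 1 = right) of the index-th node at
      // this level lands at 2^(l+1) - 1 + 2 * index + i; e.g., the root's children
      // (level 0, index 0) land at indices 1 and 2.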
val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth - 1) { + if (level < maxDepth) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala index d8d54c6c4fb17..35d4166898490 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * :: Experimental :: * Stores all the configuration options for DecisionTreeClassifier construction * @param impurity criterion used for information gain calculation (e.g., Gini or Entropy) - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param maxBins maximum number of bins used for splitting features * @param quantileStrategy algorithm for calculating quantiles * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is @@ -34,7 +35,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @Experimental class DTClassifierParams ( var impurity: ClassificationImpurity = Gini, - maxDepth: Int = 5, + maxDepth: Int = 4, maxBins: Int = 100, quantileStrategy: QuantileStrategy = Sort, maxMemoryInMB: Int = 128) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala index 3098e2b66ce46..6d43f6098d166 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -22,7 +22,8 @@ import org.apache.spark.annotation.Experimental /** * :: Experimental :: * Stores configuration options for DecisionTree construction. - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param maxBins maximum number of bins used for splitting features * @param quantileStrategy algorithm for calculating quantiles * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala index f38f9ab6ae7fd..368beb8c0ad87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -23,7 +23,8 @@ import org.apache.spark.mllib.tree.impurity.{RegressionImpurity, Variance} /** * :: Experimental :: * Stores all the configuration options for DecisionTreeRegressor construction - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param maxBins maximum number of bins used for splitting features * @param quantileStrategy algorithm for calculating quantiles * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. 
Default value is @@ -33,7 +34,7 @@ import org.apache.spark.mllib.tree.impurity.{RegressionImpurity, Variance} @Experimental class DTRegressorParams ( var impurity: RegressionImpurity = Variance, - maxDepth: Int = 5, + maxDepth: Int = 4, maxBins: Int = 100, quantileStrategy: QuantileStrategy.QuantileStrategy = QuantileStrategy.Sort, maxMemoryInMB: Int = 128) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 83e5c9a8a050c..d4b6b8264b5b5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -37,10 +37,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } private val defaultClassifierParams = - new DTClassifierParams(Gini, maxDepth=3, maxBins=100) + new DTClassifierParams(Gini, maxDepth = 2, maxBins = 100) private val defaultRegressorParams = - new DTRegressorParams(Variance, maxDepth=3, maxBins=100) + new DTRegressorParams(Variance, maxDepth = 2, maxBins = 100) test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() @@ -636,7 +636,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val dtParams = defaultClassifierParams dtParams.impurity = Entropy - dtParams.maxDepth = 5 + dtParams.maxDepth = 4 val dtLearner = new DecisionTreeClassifier(dtParams) val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) @@ -660,7 +660,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(datasetInfo.isMulticlass) val dtParams = defaultClassifierParams - dtParams.maxDepth = 5 + dtParams.maxDepth = 4 val dtLearner = new DecisionTreeClassifier(dtParams) val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, @@ -686,7 +686,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(datasetInfo.isMulticlass) val dtParams = defaultClassifierParams - dtParams.maxDepth = 5 + dtParams.maxDepth = 4 val dtLearner = new DecisionTreeClassifier(dtParams) val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) @@ -712,7 +712,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(datasetInfo.isMulticlass) val dtParams = defaultClassifierParams - dtParams.maxDepth = 5 + dtParams.maxDepth = 4 val dtLearner = new DecisionTreeClassifier(dtParams) val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) From 45068442dbcf36548d32001d60f9d4bda68c6a87 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 23 Jul 2014 18:06:36 -0700 Subject: [PATCH 05/20] Changed all config/impurity classes/objects to be private[mllib]. Changed Params classes to take strings instead of special types. Made impurity names lists publicly accessible via Params classes. Simplified impurity factories. 
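Usage sketch with the new string-based API (names as introduced in this patch):

    val params = DecisionTreeClassifier.defaultParams()
    params.setImpurity("entropy")  // validated against DTClassifierParams.supportedImpurities
    params.setMaxDepth(4)
    val model = new DecisionTreeClassifier(params).run(trainingData, datasetInfo)

where trainingData is an RDD[LabeledPoint] and datasetInfo a DatasetInfo, as in the
example runner; bad impurity or quantileStrategy strings are rejected by the setters.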
--- .../examples/mllib/DecisionTreeRunner.scala | 36 +++----- .../spark/mllib/tree/DecisionTree.scala | 26 ++++-- .../mllib/tree/DecisionTreeClassifier.scala | 10 +- .../mllib/tree/DecisionTreeRegressor.scala | 8 +- .../configuration/DTClassifierParams.scala | 25 ++++- .../mllib/tree/configuration/DTParams.scala | 28 +++++- .../configuration/DTRegressorParams.scala | 24 ++++- .../tree/configuration/QuantileStrategy.scala | 22 +++-- .../impurity/ClassificationImpurities.scala | 23 +---- .../spark/mllib/tree/impurity/Entropy.scala | 2 +- .../spark/mllib/tree/impurity/Impurity.scala | 49 ---------- .../tree/impurity/RegressionImpurities.scala | 21 +---- .../tree/impurity/RegressionImpurity.scala | 1 + .../mllib/tree/JavaDecisionTreeSuite.java | 91 +++++++++++++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 22 +++-- 15 files changed, 239 insertions(+), 149 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala create mode 100644 mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index a502f067e8981..51f9e6d196414 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -25,7 +25,6 @@ import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTreeClassifier, DecisionTreeRegressor} import org.apache.spark.mllib.tree.configuration.{DTClassifierParams, DTRegressorParams} -import org.apache.spark.mllib.tree.impurity.{ClassificationImpurities, RegressionImpurities} import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -52,10 +51,11 @@ object DecisionTreeRunner { maxBins: Int = 100, fracTest: Double = 0.2) + private val defaultCImpurity = new DTClassifierParams().impurity + private val defaultRImpurity = new DTRegressorParams().impurity + def main(args: Array[String]) { val defaultParams = Params() - val defaultCImpurity = ClassificationImpurities.impurityName(new DTClassifierParams().impurity) - val defaultRImpurity = RegressionImpurities.impurityName(new DTRegressorParams().impurity) val parser = new OptionParser[Params]("DecisionTreeRunner") { head("DecisionTreeRunner: an example decision tree app.") @@ -65,9 +65,9 @@ object DecisionTreeRunner { opt[String]("impurity") .text( s"impurity type\n" + - s"\tFor classification: ${ClassificationImpurities.names.mkString(",")}\n" + + s"\tFor classification: ${DTClassifierParams.supportedImpurities.mkString(",")}\n" + s"\t default: $defaultCImpurity" + - s"\tFor regression: ${RegressionImpurities.names.mkString(",")}\n" + + s"\tFor regression: ${DTRegressorParams.supportedImpurities.mkString(",")}\n" + s"\t default: $defaultRImpurity") .action((x, c) => c.copy(impurity = Some(x))) opt[Int]("maxDepth") @@ -91,14 +91,6 @@ object DecisionTreeRunner { if (!List("classification", "regression").contains(params.algo)) { failure(s"Did not recognize Algo: ${params.algo}") } - if (params.impurity != None) { - if ((params.algo == "classification" && - !ClassificationImpurities.names.contains(params.impurity)) || - (params.algo == "regression" && - !RegressionImpurities.names.contains(params.impurity))) { - failure(s"Algo 
${params.algo} is not compatible with impurity ${params.impurity}.") - } - } if (params.fracTest < 0 || params.fracTest > 1) { failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") } @@ -167,7 +159,7 @@ object DecisionTreeRunner { val numTraining = training.count() val numTest = test.count() - println(s"numTraining = $numTraining, numTest = $numTest.") + println(s"numTraining = $numTraining, numTest = $numTest") examples.unpersist(blocking = false) @@ -179,27 +171,27 @@ object DecisionTreeRunner { val dtParams = DecisionTreeClassifier.defaultParams() dtParams.maxDepth = params.maxDepth dtParams.maxBins = params.maxBins - if (params.impurity != None) { - dtParams.impurity = ClassificationImpurities.impurity(params.impurity.get) + if (params.impurity == None) { + dtParams.impurity = defaultCImpurity } val dtLearner = new DecisionTreeClassifier(dtParams) - val model = dtLearner.train(training, datasetInfo) + val model = dtLearner.run(training, datasetInfo) model.print() val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy.") + println(s"Test accuracy = $accuracy") } case "regression" => { val dtParams = DecisionTreeRegressor.defaultParams() dtParams.maxDepth = params.maxDepth dtParams.maxBins = params.maxBins - if (params.impurity != None) { - dtParams.impurity = RegressionImpurities.impurity(params.impurity.get) + if (params.impurity == None) { + dtParams.impurity = defaultRImpurity } val dtLearner = new DecisionTreeRegressor(dtParams) - val model = dtLearner.train(training, datasetInfo) + val model = dtLearner.run(training, datasetInfo) model.print() val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse.") + println(s"Test mean squared error = $mse") } case _ => { throw new IllegalArgumentException("Algo ${params.algo} not supported.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 225d1f7e83c86..4cd5ebcaa86e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -23,6 +23,7 @@ import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.DTParams import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategies import org.apache.spark.mllib.tree.configuration.QuantileStrategy import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -36,12 +37,22 @@ import org.apache.spark.util.random.XORShiftRandom * @param params The configuration parameters for the tree algorithm. 
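 * Concrete subclasses ([[org.apache.spark.mllib.tree.DecisionTreeClassifier]] and
 * [[org.apache.spark.mllib.tree.DecisionTreeRegressor]]) provide the impurity-specific
 * bin aggregation and information-gain computations.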
*/ @Experimental -private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected val params: DTParams) +private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTParams) extends Serializable with Logging { protected final val InvalidBinIndex = -1 - /** + // depth of the decision tree + protected val maxDepth: Int = params.maxDepth + + protected val maxBins: Int = params.maxBins + + protected val quantileStrategy: QuantileStrategy.QuantileStrategy = + QuantileStrategies.strategy(params.quantileStrategy) + + protected val maxMemoryInMB: Int = params.maxMemoryInMB + + /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * @param datasetInfo Dataset metadata. @@ -60,8 +71,6 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va val numBins = bins(0).length logDebug("numBins = " + numBins) - // depth of the decision tree - val maxDepth = params.maxDepth // the max number of nodes possible given the depth of the tree val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 // Initialize an array to hold filters applied to points for each node. @@ -76,7 +85,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = params.maxMemoryInMB * 1024 * 1024 + val maxMemoryUsage = maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") val numElementsPerNode = getElementsPerNode(datasetInfo, numBins) logDebug("numElementsPerNode = " + numElementsPerNode) @@ -238,7 +247,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = math.pow(2, level).toInt - 1 + index - val isLeaf = (stats.gain <= 0) || (level == params.maxDepth) + val isLeaf = (stats.gain <= 0) || (level == maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node @@ -732,7 +741,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va * (b) For regression and binary classification, * and for multiclass classification with a high-arity feature, * there is one split per category. - * + * Categorical case (a) features are called unordered features. * Other cases are called ordered features. 
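 * For example, a 3-category feature in the unordered case yields 2^(3-1) - 1 = 3
 * candidate splits (category subsets), while in the ordered case it yields one
 * bin per category; the multiclass tests in this series expect exactly
 * 2^2 - 1 = 3 bins/splits for such a feature.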
* @@ -751,7 +760,6 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va val numFeatures = datasetInfo.numFeatures - val maxBins = params.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) val isMulticlass = datasetInfo.isMulticlass @@ -784,7 +792,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (protected va val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) - params.quantileStrategy match { + quantileStrategy match { case QuantileStrategy.Sort => val splits = Array.ofDim[Split](numFeatures, numBins - 1) val bins = Array.ofDim[Bin](numFeatures, numBins) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala index 963d2ce5ea5ed..6987a11ddd2b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala @@ -22,6 +22,8 @@ import org.apache.spark.Logging import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.DTClassifierParams +import org.apache.spark.mllib.tree.impurity.ClassificationImpurities + //import org.apache.spark.mllib.tree.impurity.{ClassificationImpurity, ClassificationImpurities} import org.apache.spark.mllib.tree.model.{InformationGainStats, Bin, DecisionTreeClassifierModel} import org.apache.spark.rdd.RDD @@ -37,7 +39,7 @@ import org.apache.spark.rdd.RDD class DecisionTreeClassifier (params: DTClassifierParams) extends DecisionTree[DecisionTreeClassifierModel](params) { - private val impurityFunctor = params.impurity + private val impurityFunctor = ClassificationImpurities.impurity(params.impurity) /** * Method to train a decision tree model over an RDD @@ -45,7 +47,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) * @param datasetInfo Dataset metadata specifying number of classes, features, etc. 
* @return a DecisionTreeClassifierModel that can be used for prediction */ - def train( + def run( input: RDD[LabeledPoint], datasetInfo: DatasetInfo): DecisionTreeClassifierModel = { @@ -538,7 +540,7 @@ object DecisionTreeClassifier extends Serializable with Logging { input: RDD[LabeledPoint], datasetInfo: DatasetInfo): DecisionTreeClassifierModel = { require(datasetInfo.numClasses >= 2) - new DecisionTreeClassifier(new DTClassifierParams()).train(input, datasetInfo) + new DecisionTreeClassifier(new DTClassifierParams()).run(input, datasetInfo) } /** @@ -556,7 +558,7 @@ object DecisionTreeClassifier extends Serializable with Logging { datasetInfo: DatasetInfo, params: DTClassifierParams): DecisionTreeClassifierModel = { require(datasetInfo.numClasses >= 2) - new DecisionTreeClassifier(params).train(input, datasetInfo) + new DecisionTreeClassifier(params).run(input, datasetInfo) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala index 2957afc271826..5b5ac139f7d87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala @@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD class DecisionTreeRegressor (params: DTRegressorParams) extends DecisionTree[DecisionTreeRegressorModel](params) { - private val impurityFunctor = params.impurity + private val impurityFunctor = RegressionImpurities.impurity(params.impurity) /** * Method to train a decision tree model over an RDD @@ -45,7 +45,7 @@ class DecisionTreeRegressor (params: DTRegressorParams) * @param datasetInfo Dataset metadata specifying number of classes, features, etc. * @return a DecisionTreeRegressorModel that can be used for prediction */ - def train( + def run( input: RDD[LabeledPoint], datasetInfo: DatasetInfo): DecisionTreeRegressorModel = { @@ -283,7 +283,7 @@ object DecisionTreeRegressor extends Serializable with Logging { def train( input: RDD[LabeledPoint], datasetInfo: DatasetInfo): DecisionTreeRegressorModel = { - new DecisionTreeRegressor(new DTRegressorParams()).train(input, datasetInfo) + new DecisionTreeRegressor(new DTRegressorParams()).run(input, datasetInfo) } /** @@ -300,7 +300,7 @@ object DecisionTreeRegressor extends Serializable with Logging { input: RDD[LabeledPoint], datasetInfo: DatasetInfo, params: DTRegressorParams = new DTRegressorParams()): DecisionTreeRegressorModel = { - new DecisionTreeRegressor(params).train(input, datasetInfo) + new DecisionTreeRegressor(params).run(input, datasetInfo) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala index 35d4166898490..d43b53d815d25 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -18,13 +18,12 @@ package org.apache.spark.mllib.tree.configuration import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.impurity.{ClassificationImpurity, Gini} -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.ClassificationImpurities /** * :: Experimental :: * Stores all the configuration options for DecisionTreeClassifier construction - * @param impurity criterion used for 
information gain calculation (e.g., Gini or Entropy) + * @param impurity criterion used for information gain calculation (e.g., "gini" or "entropy") * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param maxBins maximum number of bins used for splitting features @@ -34,11 +33,27 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ */ @Experimental class DTClassifierParams ( - var impurity: ClassificationImpurity = Gini, + var impurity: String = "gini", maxDepth: Int = 4, maxBins: Int = 100, - quantileStrategy: QuantileStrategy = Sort, + quantileStrategy: String = "sort", maxMemoryInMB: Int = 128) extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { + def setImpurity(impurity: String) = { + if (!ClassificationImpurities.nameToImpurityMap.contains(impurity)) { + throw new IllegalArgumentException(s"Bad impurity parameter for classification: $impurity") + } + this.impurity = impurity + } + } + +object DTClassifierParams { + + /** + * List of supported impurity options. + */ + final val supportedImpurities: List[String] = ClassificationImpurities.names + +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala index 6d43f6098d166..1eba6da327ebd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -33,7 +33,33 @@ import org.apache.spark.annotation.Experimental class DTParams ( var maxDepth: Int, var maxBins: Int, - var quantileStrategy: QuantileStrategy.QuantileStrategy, + var quantileStrategy: String, var maxMemoryInMB: Int) extends Serializable { + def setMaxDepth(maxDepth: Int) = { + this.maxDepth = maxDepth + } + + def setMaxBins(maxBins: Int) = { + this.maxBins = maxBins + } + + def setQuantileStrategy(quantileStrategy: String) = { + if (!QuantileStrategies.nameToStrategyMap.contains(quantileStrategy)) { + throw new IllegalArgumentException(s"Bad QuantileStrategy parameter: $quantileStrategy") + } + this.quantileStrategy = quantileStrategy + } + + def setMaxMemoryInMB(maxMemoryInMB: Int) = { + this.maxMemoryInMB = maxMemoryInMB + } + + /** + * Get list of supported quantileStrategy options. + */ + def supportedQuantileStrategies(): List[String] = { + QuantileStrategies.nameToStrategyMap.keys.toList + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala index 368beb8c0ad87..acdf69675eff6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -18,26 +18,42 @@ package org.apache.spark.mllib.tree.configuration import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.impurity.{RegressionImpurity, Variance} +import org.apache.spark.mllib.tree.impurity.RegressionImpurities /** * :: Experimental :: * Stores all the configuration options for DecisionTreeRegressor construction + * @param impurity criterion used for information gain calculation (e.g., Variance) * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
* @param maxBins maximum number of bins used for splitting features * @param quantileStrategy algorithm for calculating quantiles * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. - * @param impurity criterion used for information gain calculation (e.g., Variance) */ @Experimental class DTRegressorParams ( - var impurity: RegressionImpurity = Variance, + var impurity: String = "variance", maxDepth: Int = 4, maxBins: Int = 100, - quantileStrategy: QuantileStrategy.QuantileStrategy = QuantileStrategy.Sort, + quantileStrategy: String = "sort", maxMemoryInMB: Int = 128) extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { + def setImpurity(impurity: String) = { + if (!RegressionImpurities.nameToImpurityMap.contains(impurity)) { + throw new IllegalArgumentException(s"Bad impurity parameter for regression: $impurity") + } + this.impurity = impurity + } + } + +object DTRegressorParams { + + /** + * List of supported impurity options. + */ + final val supportedImpurities: List[String] = RegressionImpurities.names + +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 1103457086f22..c7f0ae3beea76 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -24,12 +24,11 @@ import org.apache.spark.annotation.Experimental * Enum for selecting the quantile calculation strategy */ @Experimental -object QuantileStrategy extends Enumeration { +private[mllib] object QuantileStrategy extends Enumeration { type QuantileStrategy = Value val Sort, MinMax, ApproxHist = Value } - /** * :: Experimental :: * Factory for creating [[org.apache.spark.mllib.tree.configuration.QuantileStrategy]] instances. @@ -37,13 +36,24 @@ object QuantileStrategy extends Enumeration { @Experimental private[mllib] object QuantileStrategies { + import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ + + /** + * Mapping used for strategy names. + * If you add a new strategy type, add it here. + */ + val nameToStrategyMap: Map[String, QuantileStrategy] = Map( + "sort" -> Sort) + /** * Given a string with the name of a quantile strategy, get the QuantileStrategy type. */ - def strategy(strategyName: String): QuantileStrategy.QuantileStrategy = strategyName match { - case "sort" => QuantileStrategy.Sort - case _ => throw new IllegalArgumentException( - s"Bad QuantileStrategy parameter: $strategyName") + def strategy(name: String): QuantileStrategy = { + if (nameToStrategyMap.contains(name)) { + nameToStrategyMap(name) + } else { + throw new IllegalArgumentException(s"Bad QuantileStrategy parameter: $name") + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala index f5646e98c4c55..d415f8b718d56 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurities.scala @@ -25,17 +25,15 @@ import org.apache.spark.annotation.Experimental * type based on its name. 
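 * A hedged usage sketch (the failure behavior for unregistered names is an
 * assumption, mirroring QuantileStrategies.strategy above):
 * {{{
 *   val gini = ClassificationImpurities.impurity("gini")  // registered name
 *   // ClassificationImpurities.impurity("variance")      // not a classification
 *   //   impurity; assumed to throw IllegalArgumentException
 * }}}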
*/ @Experimental -object ClassificationImpurities { +private[mllib] object ClassificationImpurities { /** * Mapping used for impurity names, used for parsing impurity settings. * If you add a new impurity class, add it here. */ - val impurityToNameMap: Map[ClassificationImpurity, String] = Map( - Gini -> "gini", - Entropy -> "entropy") - - val nameToImpurityMap: Map[String, ClassificationImpurity] = impurityToNameMap.map(_.swap) + val nameToImpurityMap: Map[String, ClassificationImpurity] = Map( + "gini" -> Gini, + "entropy" -> Entropy) val names: List[String] = nameToImpurityMap.keys.toList @@ -50,17 +48,4 @@ object ClassificationImpurities { } } - /** - * Given impurity type, return name. - */ - def impurityName(impurity: ClassificationImpurity): String = { - if (impurityToNameMap.contains(impurity)) { - impurityToNameMap(impurity) - } else { - throw new IllegalArgumentException( - s"ClassificationImpurity type ${impurity.toString}" - + " not registered in ClassificationImpurities factory.") - } - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 5674848dbaf05..452a88c5ca69e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -27,7 +27,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} @Experimental private[mllib] object Entropy extends ClassificationImpurity { - private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + private def log2(x: Double) = scala.math.log(x) / scala.math.log(2) /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala deleted file mode 100644 index 16b28c3471113..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impurity - -import org.apache.spark.annotation.{DeveloperApi, Experimental} - -/** - * :: Experimental :: - * Trait for calculating information gain. 
- */ -@Experimental -private[mllib] trait Impurity extends Serializable { - - /** - * :: DeveloperApi :: - * information calculation for multiclass classification - * @param counts Array[Double] with counts for each label - * @param totalCount sum of counts for all labels - * @return information value - */ - @DeveloperApi - def calculate(counts: Array[Double], totalCount: Double): Double - - /** - * :: DeveloperApi :: - * information calculation for regression - * @param count number of instances - * @param sum sum of labels - * @param sumSquares summation of squares of the labels - * @return information value - */ - @DeveloperApi - def calculate(count: Double, sum: Double, sumSquares: Double): Double -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala index b27622e68d9e5..7100026c35e8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurities.scala @@ -25,16 +25,14 @@ import org.apache.spark.annotation.Experimental * type based on its name. */ @Experimental -object RegressionImpurities { +private[mllib] object RegressionImpurities { /** * Mapping used for impurity names, used for parsing impurity settings. * If you add a new impurity class, add it here. */ - val impurityToNameMap: Map[RegressionImpurity, String] = Map( - Variance -> "variance") - - val nameToImpurityMap: Map[String, RegressionImpurity] = impurityToNameMap.map(_.swap) + val nameToImpurityMap: Map[String, RegressionImpurity] = Map( + "variance" -> Variance) val names: List[String] = nameToImpurityMap.keys.toList @@ -49,17 +47,4 @@ object RegressionImpurities { } } - /** - * Given impurity type, return name. - */ - def impurityName(impurity: RegressionImpurity): String = { - if (impurityToNameMap.contains(impurity)) { - impurityToNameMap(impurity) - } else { - throw new IllegalArgumentException( - s"RegressionImpurity type ${impurity.toString}" - + " not registered in RegressionImpurities factory.") - } - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala index 6e01b0334c9fe..86cf1648931f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala @@ -36,4 +36,5 @@ private[mllib] trait RegressionImpurity extends Serializable { */ @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double + } diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java new file mode 100644 index 0000000000000..120ac228b447b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.rdd.DatasetInfo; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.configuration.DTClassifierParams; +import org.apache.spark.mllib.tree.model.DecisionTreeClassifierModel; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.Serializable; +import java.util.List; + +public class JavaDecisionTreeSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + int validatePrediction(List validationData, DecisionTreeClassifierModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + if (prediction == point.label()) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runDTUsingConstructor() { + scala.Tuple2, DatasetInfo> arr_datasetInfo = + DecisionTreeSuite.generateCategoricalDataPointsAsList(); + JavaRDD rdd = sc.parallelize(arr_datasetInfo._1()); + DatasetInfo datasetInfo = arr_datasetInfo._2(); + + DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams(); + dtParams.setMaxBins(200); + dtParams.setImpurity("entropy"); + DecisionTreeClassifier dtLearner = new DecisionTreeClassifier(dtParams); + DecisionTreeClassifierModel model = dtLearner.run(rdd.rdd(), datasetInfo); + + int numAccurate = validatePrediction(arr_datasetInfo._1(), model); + Assert.assertTrue(numAccurate == rdd.count()); + } + + @Test + public void runDTUsingStaticMethods() { + scala.Tuple2, DatasetInfo> arr_datasetInfo = + DecisionTreeSuite.generateCategoricalDataPointsAsList(); + JavaRDD rdd = sc.parallelize(arr_datasetInfo._1()); + DatasetInfo datasetInfo = arr_datasetInfo._2(); + + DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams(); + DecisionTreeClassifierModel model = + DecisionTreeClassifier.train(rdd.rdd(), datasetInfo, dtParams); + + int numAccurate = validatePrediction(arr_datasetInfo._1(), model); + Assert.assertTrue(numAccurate == rdd.count()); + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index d4b6b8264b5b5..905fffd3c4a29 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.tree +import scala.collection.JavaConversions._ + import org.apache.spark.mllib.rdd.DatasetInfo import org.scalatest.FunSuite -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.model.Split import org.apache.spark.mllib.tree.configuration.{FeatureType, 
DTClassifierParams, DTRegressorParams} @@ -37,10 +38,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } private val defaultClassifierParams = - new DTClassifierParams(Gini, maxDepth = 2, maxBins = 100) + new DTClassifierParams("gini", maxDepth = 2, maxBins = 100) private val defaultRegressorParams = - new DTRegressorParams(Variance, maxDepth = 2, maxBins = 100) + new DTRegressorParams("variance", maxDepth = 2, maxBins = 100) test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() @@ -519,7 +520,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numFeatures = getNumFeatures(arr)) val dtParams = defaultClassifierParams - dtParams.impurity = Entropy + dtParams.impurity = "entropy" val dtLearner = new DecisionTreeClassifier(dtParams) val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) @@ -550,7 +551,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numFeatures = getNumFeatures(arr)) val dtParams = defaultClassifierParams - dtParams.impurity = Entropy + dtParams.impurity = "entropy" val dtLearner = new DecisionTreeClassifier(dtParams) val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) @@ -581,7 +582,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numFeatures = getNumFeatures(arr)) val dtParams = defaultClassifierParams - dtParams.impurity = Entropy + dtParams.impurity = "entropy" val dtLearner = new DecisionTreeClassifier(dtParams) val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) @@ -635,7 +636,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(datasetInfo.isMulticlass) val dtParams = defaultClassifierParams - dtParams.impurity = Entropy dtParams.maxDepth = 4 val dtLearner = new DecisionTreeClassifier(dtParams) val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) @@ -764,6 +764,14 @@ object DecisionTreeSuite { arr } + def generateCategoricalDataPointsAsList(): (java.util.List[LabeledPoint], DatasetInfo) = { + val datasetInfo = new DatasetInfo( + numClasses = 2, + numFeatures = 2, + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) + (seqAsJavaList(generateCategoricalDataPoints()), datasetInfo) + } + def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { From b6b0809249a81e950f87b0a7f2c389f6c5d08f98 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 23 Jul 2014 18:07:26 -0700 Subject: [PATCH 06/20] removed mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java since it fails currently --- .../mllib/tree/JavaDecisionTreeSuite.java | 91 ------------------- 1 file changed, 91 deletions(-) delete mode 100644 mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java deleted file mode 100644 index 120ac228b447b..0000000000000 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.rdd.DatasetInfo; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.configuration.DTClassifierParams; -import org.apache.spark.mllib.tree.model.DecisionTreeClassifierModel; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import java.io.Serializable; -import java.util.List; - -public class JavaDecisionTreeSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } - - int validatePrediction(List validationData, DecisionTreeClassifierModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - if (prediction == point.label()) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runDTUsingConstructor() { - scala.Tuple2, DatasetInfo> arr_datasetInfo = - DecisionTreeSuite.generateCategoricalDataPointsAsList(); - JavaRDD rdd = sc.parallelize(arr_datasetInfo._1()); - DatasetInfo datasetInfo = arr_datasetInfo._2(); - - DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams(); - dtParams.setMaxBins(200); - dtParams.setImpurity("entropy"); - DecisionTreeClassifier dtLearner = new DecisionTreeClassifier(dtParams); - DecisionTreeClassifierModel model = dtLearner.run(rdd.rdd(), datasetInfo); - - int numAccurate = validatePrediction(arr_datasetInfo._1(), model); - Assert.assertTrue(numAccurate == rdd.count()); - } - - @Test - public void runDTUsingStaticMethods() { - scala.Tuple2, DatasetInfo> arr_datasetInfo = - DecisionTreeSuite.generateCategoricalDataPointsAsList(); - JavaRDD rdd = sc.parallelize(arr_datasetInfo._1()); - DatasetInfo datasetInfo = arr_datasetInfo._2(); - - DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams(); - DecisionTreeClassifierModel model = - DecisionTreeClassifier.train(rdd.rdd(), datasetInfo, dtParams); - - int numAccurate = validatePrediction(arr_datasetInfo._1(), model); - Assert.assertTrue(numAccurate == rdd.count()); - } - -} From a2a93115a1f2106e13bb122589a28669310d19f5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 23 Jul 2014 18:07:26 -0700 Subject: [PATCH 07/20] removed mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java since it fails currently MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Comments which should have been added to previous commit: Fixed one test in DecisionTreeSuite to undo a change in previous commit (“stump with categorical variables for multiclass classification”). Reverted impurity from Entropy back to Gini. Java compatibility: * Changed non-static train() methods’ names to run() to avoid conflicts with static train() methods in Java. * Added setter functions to *Params classes. 
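For reference, a sketch of the resulting call pattern (adapted from the Java
suite deleted above; `rdd: RDD[LabeledPoint]` and `datasetInfo: DatasetInfo`
stand for any training data and its metadata):
{{{
  val dtParams = DecisionTreeClassifier.defaultParams()
  dtParams.setMaxBins(200)
  dtParams.setImpurity("entropy")
  // Instance API: train() was renamed to run().
  val model1 = new DecisionTreeClassifier(dtParams).run(rdd, datasetInfo)
  // The static convenience API keeps the name train().
  val model2 = DecisionTreeClassifier.train(rdd, datasetInfo, dtParams)
}}}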
--- .../mllib/tree/JavaDecisionTreeSuite.java | 91 ------------------- 1 file changed, 91 deletions(-) delete mode 100644 mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java deleted file mode 100644 index 120ac228b447b..0000000000000 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.rdd.DatasetInfo; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.configuration.DTClassifierParams; -import org.apache.spark.mllib.tree.model.DecisionTreeClassifierModel; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import java.io.Serializable; -import java.util.List; - -public class JavaDecisionTreeSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } - - int validatePrediction(List validationData, DecisionTreeClassifierModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - if (prediction == point.label()) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runDTUsingConstructor() { - scala.Tuple2, DatasetInfo> arr_datasetInfo = - DecisionTreeSuite.generateCategoricalDataPointsAsList(); - JavaRDD rdd = sc.parallelize(arr_datasetInfo._1()); - DatasetInfo datasetInfo = arr_datasetInfo._2(); - - DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams(); - dtParams.setMaxBins(200); - dtParams.setImpurity("entropy"); - DecisionTreeClassifier dtLearner = new DecisionTreeClassifier(dtParams); - DecisionTreeClassifierModel model = dtLearner.run(rdd.rdd(), datasetInfo); - - int numAccurate = validatePrediction(arr_datasetInfo._1(), model); - Assert.assertTrue(numAccurate == rdd.count()); - } - - @Test - public void runDTUsingStaticMethods() { - scala.Tuple2, DatasetInfo> arr_datasetInfo = - DecisionTreeSuite.generateCategoricalDataPointsAsList(); - JavaRDD rdd = sc.parallelize(arr_datasetInfo._1()); - DatasetInfo datasetInfo = arr_datasetInfo._2(); - - DTClassifierParams dtParams = DecisionTreeClassifier.defaultParams(); - DecisionTreeClassifierModel model = - DecisionTreeClassifier.train(rdd.rdd(), 
datasetInfo, dtParams);
-
-    int numAccurate = validatePrediction(arr_datasetInfo._1(), model);
-    Assert.assertTrue(numAccurate == rdd.count());
-  }
-
-}

From 3ff5027c8fc7cd3e5a84233ceb763dc905ec6cc0 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" 
Date: Thu, 24 Jul 2014 15:31:04 -0700
Subject: [PATCH 08/20] Bug fix: Indexing was inconsistent for aggregate
 calculations for unordered features (in multiclass classification with
 categorical features, where the features had few enough values such that
 they could be considered unordered, i.e.,
 isSpaceSufficientForAllCategoricalSplits=true).
 * updateBinForUnorderedFeature indexed agg as (node, feature, featureValue,
   binIndex), where
   ** featureValue was from arr (so it was a feature value)
   ** binIndex was in [0, …, 2^(maxFeatureValue-1)-1)
 * The rest of the code indexed agg as (node, feature, binIndex, label).
 * Corrected this bug by changing updateBinForUnorderedFeature to use the
   second indexing pattern.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Unit tests in DecisionTreeSuite
* Updated a few tests to train a model and test its training accuracy, which
  catches the indexing bug from updateBinForUnorderedFeature() discussed above.
* Added new test (“stump with categorical variables for multiclass
  classification, with just enough bins”) to test bin extremes.

Bug fix: calculateGainForSplit (for classification):
* It used to return dummy prediction values when either the right or left
  children had 0 weight. These were incorrect for multiclass classification.
  It has been corrected.

Updated impurities to allow for count = 0. This was related to the above bug
fix for calculateGainForSplit (for classification).

Small updates to documentation and coding style.
---
 .../spark/mllib/tree/DecisionTree.scala       |  45 +++---
 .../mllib/tree/DecisionTreeClassifier.scala   | 143 ++++++++++--------
 .../mllib/tree/DecisionTreeRegressor.scala    |   2 +-
 .../configuration/DTClassifierParams.scala    |   3 +-
 .../configuration/DTRegressorParams.scala     |   3 +-
 .../impurity/ClassificationImpurity.scala     |   2 +-
 .../spark/mllib/tree/impurity/Entropy.scala   |   5 +-
 .../spark/mllib/tree/impurity/Gini.scala      |   5 +-
 .../tree/impurity/RegressionImpurity.scala    |   2 +-
 .../spark/mllib/tree/impurity/Variance.scala  |   4 +
 .../spark/mllib/tree/DecisionTreeSuite.scala  |  67 +++++++-
 11 files changed, 186 insertions(+), 95 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 4cd5ebcaa86e4..0b37a7926e238 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -32,8 +32,8 @@ import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * :: Experimental ::
- * A class that implements a decision tree algorithm for classification and regression. It
- * supports both continuous and categorical features.
+ * An abstract class for decision tree algorithms for classification and regression.
+ * It supports both continuous and categorical features.
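 * Concrete subclasses in this patch are DecisionTreeClassifier and
 * DecisionTreeRegressor, e.g. (a sketch, with `rdd` and `datasetInfo` as in
 * the test suites):
 * {{{
 *   val model = new DecisionTreeRegressor(new DTRegressorParams()).run(rdd, datasetInfo)
 * }}}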
* @param params The configuration parameters for the tree algorithm. */ @Experimental @@ -118,7 +118,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa val splitsStatsForLevel = findBestSplits(input, datasetInfo, parentImpurities, level, filters, splits, bins, maxLevelForSingleGroup) - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) // Extract info for nodes at the next lower level. @@ -525,7 +525,8 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa } else { // Perform sequential search to find bin for categorical features. val binIndex = { - if (isMulticlass && isSpaceSufficientForAllCategoricalSplits) { + val isUnorderedFeature = isMulticlass && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { sequentialBinSearchForOrderedCategoricalFeatureInClassification() @@ -539,13 +540,15 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa } /** - * Finds bins for all nodes (and all features) at a given level. + * Finds bins for all nodes (and all features) in a given (level, group). * For l nodes, k features the storage is as follows: * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, * where b_ij is an integer between 0 and numBins - 1 for regressions and binary * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. * + * For unordered features, the "bin index" returned is actually the feature value (category). + * * @return Array of size 1 + numFeatures * numNodes, where * arr(0) = label for labeledPoint, and * arr(1 + numFeatures * nodeIndex + featureIndex) = @@ -577,10 +580,10 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) } else { val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 + arr(shift + featureIndex) = + findBin(featureIndex, labeledPoint, isFeatureContinuous, isSpaceSufficientForAllCategoricalSplits) } featureIndex += 1 @@ -591,7 +594,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa arr } - // Find feature bins for all nodes at a level. + // Find feature bins for all nodes in this (level, group). val binMappedRDD = input.map(x => findBinsForLevel(x)) // Performs a sequential aggregation over a partition. @@ -621,7 +624,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa // Calculate bin aggregates. val binAggregates = { - binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) } logDebug("binAggregates.length = " + binAggregates.length) @@ -645,7 +648,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa /** * Find the best split for a node. 
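 * The best split maximizes the information gain, computed below as
 * gain = impurity(parent) - (leftCount / totalCount) * impurity(left)
 *                         - (rightCount / totalCount) * impurity(right).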
- * @param binData Array[Double] of size 2 * numBins * numFeatures + * @param binData Bin data slice for this node, given by getBinDataForNode. * @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ @@ -656,12 +659,13 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa logDebug("node impurity = " + nodeImpurity) // Extract left right node aggregates. - val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData, datasetInfo, numBins) + val (leftNodeAgg, rightNodeAgg) = + extractLeftRightNodeAggregates(binData, datasetInfo, numBins) // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - val (bestFeatureIndex,bestSplitIndex, gainStats) = { + val (bestFeatureIndex, bestSplitIndex, gainStats) = { // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue @@ -676,7 +680,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa if (isFeatureContinuous) { numBins - 1 } else { // Categorical feature - val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) + val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val isSpaceSufficientForAllCategoricalSplits = numBins > math.pow(2, featureCategories.toInt - 1) - 1 if (isMulticlass && isSpaceSufficientForAllCategoricalSplits) { @@ -706,7 +710,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } - // Calculate best splits for all nodes at a given level + // Calculate best splits for all nodes in this (level, group). val bestSplits = new Array[(Split, InformationGainStats)](numNodes) // Iterating over all nodes at this level var node = 0 @@ -765,7 +769,6 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa val isMulticlass = datasetInfo.isMulticlass logDebug("isMulticlass = " + isMulticlass) - /* * Ensure #bins is always greater than the categories. For multiclass classification, * #bins should be greater than 2^(maxCategories - 1) - 1. @@ -819,8 +822,9 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint - if (isMulticlass && isSpaceSufficientForAllCategoricalSplits) { + // classification that satisfy the space constraint. 
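          // Worked example (note added for this writeup): a categorical feature
          // with featureCategories = 3 admits 2^(3-1) - 1 = 3 binary partitions
          // of its values, namely {0} vs {1, 2}, {1} vs {0, 2}, and {2} vs
          // {0, 1}, so the loop below creates one split/bin per partition.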
+ val isUnorderedFeature = isMulticlass && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { // 2^(maxFeatureValue - 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { @@ -845,7 +849,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa } index += 1 } - } else { + } else { // ordered feature val centroidForCategories = computeCentroidForCategories(featureIndex, sampledInput, datasetInfo) @@ -914,6 +918,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa } + object DecisionTree extends Serializable with Logging { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala index 6987a11ddd2b4..31328eacd3997 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala @@ -87,7 +87,11 @@ class DecisionTreeClassifier (params: DTClassifierParams) /** * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2 * numFeatures * numBins + * @param binData Aggregate array slice from getBinDataForNode. + * For unordered features, this is leftChildData ++ rightChildData, + * each of which is indexed by (feature, split/bin, class), + * with class being the least significant bit. + * For ordered features, this is of size numClasses * numBins * numFeatures. * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, * (numBins - 1), numClasses) @@ -134,6 +138,11 @@ class DecisionTreeClassifier (params: DTClassifierParams) } } + /** + * Reshape binData for this feature. + * Indexes binData as (feature, split, class) with class as the least significant bit. + * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value + */ def findAggForUnorderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -209,13 +218,12 @@ class DecisionTreeClassifier (params: DTClassifierParams) * either the left count or the right count of one of the p bins is * incremented based upon whether the feature is classified as 0 or 1. * @param agg Array storing aggregate calculation, of size: - * numClasses * numBins * numFeatures * numNodes - * TODO: FIX DOC + * numClasses * numBins * numFeatures * numNodes for ordered features, or + * 2 * numClasses * numBins * numFeatures * numNodes for unordered features * @param arr Bin mapping from findBinsForLevel. * Array of size 1 + (numFeatures * numNodes). 
* @return Array storing aggregate calculation, of size: - * 2 * numBins * numFeatures * numNodes for ordered features, or - * 2 * numClasses * numBins * numFeatures * numNodes for unordered features + * */ protected def binSeqOpSub( agg: Array[Double], @@ -225,19 +233,21 @@ class DecisionTreeClassifier (params: DTClassifierParams) bins: Array[Array[Bin]]): Array[Double] = { val numBins = bins(0).length if(datasetInfo.isMulticlassWithCategoricalFeatures) { - unorderedClassificationBinSeqOp(arr, agg, datasetInfo, numNodes, bins) + multiclassWithCategoricalBinSeqOp(arr, agg, datasetInfo, numNodes, bins) } else { - orderedClassificationBinSeqOp(arr, agg, datasetInfo, numNodes, numBins) + binaryOrNoCategoricalBinSeqOp(arr, agg, datasetInfo, numNodes, numBins) } agg } /** * Calculates the information gain for all splits based upon left/right split aggregates. - * @param leftNodeAgg left node aggregates + * @param leftNodeAgg Left node aggregates: + * leftNodeAgg(feature)(split)(class) = weight of examples * @param featureIndex feature index * @param splitIndex split index - * @param rightNodeAgg right node aggregate + * @param rightNodeAgg Right node aggregates: + * rightNodeAgg(feature)(split)(class) = weight of examples * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ @@ -252,20 +262,10 @@ class DecisionTreeClassifier (params: DTClassifierParams) val numClasses = datasetInfo.numClasses - val leftCounts: Array[Double] = new Array[Double](numClasses) - val rightCounts: Array[Double] = new Array[Double](numClasses) - var leftTotalCount = 0.0 - var rightTotalCount = 0.0 - var classIndex = 0 - while (classIndex < numClasses) { - val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) - val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) - leftCounts(classIndex) = leftClassCount - leftTotalCount += leftClassCount - rightCounts(classIndex) = rightClassCount - rightTotalCount += rightClassCount - classIndex += 1 - } + val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) + val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) + var leftTotalCount = leftCounts.sum + var rightTotalCount = rightCounts.sum val impurity = { if (level > 0) { @@ -282,33 +282,15 @@ class DecisionTreeClassifier (params: DTClassifierParams) } } - if (leftTotalCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) - } - if (rightTotalCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) - } - - val leftImpurity = impurityFunctor.calculate(leftCounts, leftTotalCount) - val rightImpurity = impurityFunctor.calculate(rightCounts, rightTotalCount) - - val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) - val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. 
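        // (Note added for this writeup: totalCount == 0 means no training
        // instances reached this node, so any prediction and impurity are
        // equally valid here.)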
+ return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } // Sum of count for each label - val leftRightCounts: Array[Double] - = leftCounts.zip(rightCounts) - .map{case (leftCount, rightCount) => leftCount + rightCount} + val leftRightCounts: Array[Double] = + leftCounts.zip(rightCounts).map{ case (leftCount, rightCount) => leftCount + rightCount } def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { @@ -322,11 +304,40 @@ class DecisionTreeClassifier (params: DTClassifierParams) val predict = indexOfLargestArrayElement(leftRightCounts) val prob = leftRightCounts(predict) / totalCount + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + impurityFunctor.calculate(leftCounts, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + impurityFunctor.calculate(rightCounts, rightTotalCount) + } + + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) } /** * Get bin data for one node. + * + * @param node Node index in this (level, group). + * @param binAggregates For unordered features, + * the first half of binAggregates contains leftChildData, + * and the second half contains rightChildData. + * Each half is of size numNodes * numFeatures * numBins * numClasses. + * For ordered features, + * this is of size numNodes * numFeatures * numBins * numClasses. + * Indexing uses node as the most significant bit. + * @return For unordered features, returns leftChildData ++ rightChildData, + * each of which is indexed by (feature, bin/split, class), + * with class being the least significant bit. + * For ordered features, returns data of size numClasses * numBins * numFeatures. */ protected def getBinDataForNode( node: Int, @@ -338,11 +349,12 @@ class DecisionTreeClassifier (params: DTClassifierParams) val shift = datasetInfo.numClasses * node * numBins * datasetInfo.numFeatures val rightChildShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures * numNodes val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) + val leftChildData = binAggregates.slice( + shift, + shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) + val rightChildData = binAggregates.slice( + rightChildShift + shift, + rightChildShift + shift + datasetInfo.numClasses * numBins * datasetInfo.numFeatures) leftChildData ++ rightChildData } binsForNode @@ -388,9 +400,8 @@ class DecisionTreeClassifier (params: DTClassifierParams) /** * - * @param arr Size numNodes * numFeatures + 1. - * Indexed by (node, feature) where feature is the least significant bit, - * shifted by 1. + * @param arr arr(0) = label. + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) * @param agg Indexed by (node, feature, bin, label) where label is the least significant bit. * @param rightChildShift * @param bins @@ -408,19 +419,22 @@ class DecisionTreeClassifier (params: DTClassifierParams) // Find the bin index for this feature. 
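      // Indexing sketch (note added for this writeup): with numFeatures = 2,
      // arr(1) and arr(2) hold the feature values for node 0, and arr(3) and
      // arr(4) hold those for node 1.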
val arrIndex = 1 + datasetInfo.numFeatures * nodeIndex + featureIndex + val featureValue = arr(arrIndex) // Update the left or right count for one bin. - val aggShift = datasetInfo.numClasses * numBins * datasetInfo.numFeatures * nodeIndex - val aggIndex = aggShift + datasetInfo.numClasses * featureIndex * numBins + - arr(arrIndex).toInt * datasetInfo.numClasses + val aggShift = + nodeIndex * datasetInfo.numFeatures * numBins * datasetInfo.numClasses + + featureIndex * numBins * datasetInfo.numClasses + + label.toInt // Find all matching bins and increment their values val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { - if (bins(featureIndex)(binIndex).highSplit.categories.contains(label.toInt)) { - agg(aggIndex + binIndex) += 1 + val aggIndex = aggShift + binIndex * datasetInfo.numClasses + if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { + agg(aggIndex) += 1 } else { - agg(rightChildShift + aggIndex + binIndex) += 1 + agg(rightChildShift + aggIndex) += 1 } binIndex += 1 } @@ -436,7 +450,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) * @param numNodes * @param numBins */ - private def orderedClassificationBinSeqOp( + private def binaryOrNoCategoricalBinSeqOp( arr: Array[Double], agg: Array[Double], datasetInfo: DatasetInfo, @@ -471,11 +485,13 @@ class DecisionTreeClassifier (params: DTClassifierParams) * numClasses * numBins * numFeatures * numNodes * // Size set by getElementsPerNode(): * // 2 * numClasses * numBins * numFeatures * numNodes + * SHOULD BE indexed by (node, feature, bin, class), + * with class being the least significant bit. (based on future use) * @param datasetInfo Dataset metadata. * @param numNodes Number of nodes in this (level, group). * @param bins */ - private def unorderedClassificationBinSeqOp( + private def multiclassWithCategoricalBinSeqOp( arr: Array[Double], agg: Array[Double], datasetInfo: DatasetInfo, @@ -518,6 +534,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) } + object DecisionTreeClassifier extends Serializable with Logging { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala index 5b5ac139f7d87..c525eaebfa727 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala @@ -41,7 +41,7 @@ class DecisionTreeRegressor (params: DTRegressorParams) /** * Method to train a decision tree model over an RDD - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param datasetInfo Dataset metadata specifying number of classes, features, etc. 
* @return a DecisionTreeRegressorModel that can be used for prediction */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala index d43b53d815d25..7682e8478ccfa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -23,7 +23,8 @@ import org.apache.spark.mllib.tree.impurity.ClassificationImpurities /** * :: Experimental :: * Stores all the configuration options for DecisionTreeClassifier construction - * @param impurity criterion used for information gain calculation (e.g., "gini" or "entropy") + * @param impurity Criterion used for information gain calculation. + * Currently supported: "gini", "entropy" * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param maxBins maximum number of bins used for splitting features diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala index acdf69675eff6..22c18dc853fa0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -23,7 +23,8 @@ import org.apache.spark.mllib.tree.impurity.RegressionImpurities /** * :: Experimental :: * Stores all the configuration options for DecisionTreeRegressor construction - * @param impurity criterion used for information gain calculation (e.g., Variance) + * @param impurity Criterion used for information gain calculation. + * Currently supported: "variance" * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
* @param maxBins maximum number of bins used for splitting features diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala index 3aade2eeaac72..1658cc38b806b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ClassificationImpurity.scala @@ -31,7 +31,7 @@ private[mllib] trait ClassificationImpurity extends Serializable { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if count = 0 */ @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 452a88c5ca69e..42d6eb4a08a44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -34,10 +34,13 @@ private[mllib] object Entropy extends ClassificationImpurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 0.0 var classIndex = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 20ca09f4a0395..f91c4c08f4439 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -33,10 +33,13 @@ private[mllib] object Gini extends ClassificationImpurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 1.0 var classIndex = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala index 86cf1648931f7..71f075b02a22b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/RegressionImpurity.scala @@ -32,7 +32,7 @@ private[mllib] trait RegressionImpurity extends Serializable { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels - * @return information value + * @return information value, or 0 if count = 0 */ @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index c6d674ddb7857..0e561f6a681b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -32,9 +32,13 @@ private[mllib] object Variance extends RegressionImpurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return variance, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + if (count == 0) { + return 0 + } val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 905fffd3c4a29..b29dc2370888c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -22,10 +22,11 @@ import scala.collection.JavaConversions._ import org.apache.spark.mllib.rdd.DatasetInfo import org.scalatest.FunSuite -import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.model.Split import org.apache.spark.mllib.tree.configuration.{FeatureType, DTClassifierParams, DTRegressorParams} import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.model.DecisionTreeClassifierModel +import org.apache.spark.mllib.tree.model.Filter +import org.apache.spark.mllib.tree.model.Split import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -37,11 +38,25 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { case _ => data(0).features.size } - private val defaultClassifierParams = + def validateModel( + model: DecisionTreeClassifierModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map { x => model.predict(x.features) } + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy) + } + + private def defaultClassifierParams: DTClassifierParams = { new DTClassifierParams("gini", maxDepth = 2, maxBins = 100) + } - private val defaultRegressorParams = + private def defaultRegressorParams: DTRegressorParams = { new DTRegressorParams("variance", maxDepth = 2, maxBins = 100) + } test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() @@ -651,6 +666,40 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } + test("stump with categorical variables for multiclass classification, with just enough bins") { + val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val rdd = sc.parallelize(arr) + val datasetInfo = new DatasetInfo( + numClasses = 3, + numFeatures = getNumFeatures(arr), + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(datasetInfo.isMulticlass) + + val dtParams = defaultClassifierParams + dtParams.maxDepth = 4 + dtParams.maxBins = maxBins + val dtLearner = new DecisionTreeClassifier(dtParams) + + val model = 
dtLearner.run(rdd, datasetInfo) + validateModel(model, arr, 1.0) + + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + + val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1)) + assert(bestSplit.featureType === Categorical) + val gain = bestSplits(0)._2 + assert(gain.leftImpurity == 0) + assert(gain.rightImpurity == 0) + } + test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val rdd = sc.parallelize(arr) @@ -662,10 +711,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val dtParams = defaultClassifierParams dtParams.maxDepth = 4 val dtLearner = new DecisionTreeClassifier(dtParams) + + val model = dtLearner.run(rdd, datasetInfo) + validateModel(model, arr, 0.9) + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) + assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -688,8 +742,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val dtParams = defaultClassifierParams dtParams.maxDepth = 4 val dtLearner = new DecisionTreeClassifier(dtParams) - val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) + val model = dtLearner.run(rdd, datasetInfo) + validateModel(model, arr, 0.9) + + val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) val bestSplits = dtLearner.findBestSplits(rdd, datasetInfo, new Array(31), 0, Array[List[Filter]](), splits, bins, 10) From e1243a56bfed4031ecfa00462474b4f16e07b937 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 24 Jul 2014 18:19:12 -0700 Subject: [PATCH 09/20] Fixed scala style issues reported by Jenkins --- .../spark/mllib/tree/DecisionTree.scala | 8 ----- .../mllib/tree/DecisionTreeClassifier.scala | 29 +++++++------------ .../mllib/tree/DecisionTreeRegressor.scala | 6 +--- .../configuration/DTClassifierParams.scala | 2 +- .../configuration/DTRegressorParams.scala | 2 +- .../model/DecisionTreeClassifierModel.scala | 2 -- 6 files changed, 14 insertions(+), 35 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 0b37a7926e238..b0fdf8a877aae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -149,10 +149,6 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa topNode } - //=========================================================================== - // Protected abstract methods - //=========================================================================== - /** * For a given categorical feature, use a subsample of the data * to choose how to arrange possible splits. 
@@ -232,10 +228,6 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa numNodes: Int, numBins: Int): Array[Double] - //=========================================================================== - // Protected (non-abstract) methods - //=========================================================================== - /** * Extract the decision tree node information for the given tree level and node index */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala index 31328eacd3997..5cfe1ebc25c3e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala @@ -23,8 +23,6 @@ import org.apache.spark.mllib.rdd.DatasetInfo import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.DTClassifierParams import org.apache.spark.mllib.tree.impurity.ClassificationImpurities - -//import org.apache.spark.mllib.tree.impurity.{ClassificationImpurity, ClassificationImpurities} import org.apache.spark.mllib.tree.model.{InformationGainStats, Bin, DecisionTreeClassifierModel} import org.apache.spark.rdd.RDD @@ -58,10 +56,6 @@ class DecisionTreeClassifier (params: DTClassifierParams) new DecisionTreeClassifierModel(topNode) } - //=========================================================================== - // Protected methods (abstract from DecisionTree) - //=========================================================================== - protected def computeCentroidForCategories( featureIndex: Int, sampledInput: Array[LabeledPoint], @@ -107,15 +101,16 @@ class DecisionTreeClassifier (params: DTClassifierParams) featureIndex: Int) { // shift for this featureIndex - val shift = datasetInfo.numClasses * featureIndex * numBins + val numClasses = datasetInfo.numClasses + val shift = numClasses * featureIndex * numBins var classIndex = 0 - while (classIndex < datasetInfo.numClasses) { + while (classIndex < numClasses) { // left node aggregate for the lowest split leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) // right node aggregate for the highest split rightNodeAgg(featureIndex)(numBins - 2)(classIndex) - = binData(shift + (datasetInfo.numClasses * (numBins - 1)) + classIndex) + = binData(shift + (numClasses * (numBins - 1)) + classIndex) classIndex += 1 } @@ -125,12 +120,12 @@ class DecisionTreeClassifier (params: DTClassifierParams) // calculating left node aggregate for a split as a sum of left node aggregate of a // lower split and the left bin aggregate of a bin where the split is a high split var innerClassIndex = 0 - while (innerClassIndex < datasetInfo.numClasses) { + while (innerClassIndex < numClasses) { leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) - = binData(shift + datasetInfo.numClasses * splitIndex + innerClassIndex) + + = binData(shift + numClasses * splitIndex + innerClassIndex) + leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = - binData(shift + (datasetInfo.numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + + binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) innerClassIndex += 1 } @@ -367,10 +362,6 @@ class DecisionTreeClassifier (params: DTClassifierParams) } } - 
//=========================================================================== - // Private methods - //=========================================================================== - /** * Increment aggregate in location for (node, feature, bin, label) * to indicate that, for this (example, @@ -513,7 +504,8 @@ class DecisionTreeClassifier (params: DTClassifierParams) while (featureIndex < datasetInfo.numFeatures) { val isFeatureContinuous = datasetInfo.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, numBins) + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, + numBins) } else { val featureCategories = datasetInfo.categoricalFeaturesInfo(featureIndex) val isSpaceSufficientForAllCategoricalSplits = @@ -522,7 +514,8 @@ class DecisionTreeClassifier (params: DTClassifierParams) updateBinForUnorderedFeature(arr, agg, nodeIndex, featureIndex, label, rightChildShift, datasetInfo, numBins, bins) } else { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, numBins) + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex, datasetInfo, + numBins) } } featureIndex += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala index c525eaebfa727..8ee5ae3b984e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala @@ -57,16 +57,12 @@ class DecisionTreeRegressor (params: DTRegressorParams) new DecisionTreeRegressorModel(topNode) } - //=========================================================================== - // Protected methods (abstract from DecisionTree) - //=========================================================================== - protected def computeCentroidForCategories( featureIndex: Int, sampledInput: Array[LabeledPoint], datasetInfo: DatasetInfo): Map[Double,Double] = { // For categorical variables in regression, each bin is a category. - // The bins are sorted and are ordered by calculating the centroid of their corresponding labels. + // The bins are sorted and ordered by calculating the centroid of their corresponding labels. 
sampledInput.map(lp => (lp.features(featureIndex), lp.label)) .groupBy(_._1) .mapValues(x => x.map(_._2).sum / x.map(_._1).length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala index 7682e8478ccfa..a42c1e53c8104 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -57,4 +57,4 @@ object DTClassifierParams { */ final val supportedImpurities: List[String] = ClassificationImpurities.names -} \ No newline at end of file +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala index 22c18dc853fa0..623f70d897a6b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -57,4 +57,4 @@ object DTRegressorParams { */ final val supportedImpurities: List[String] = RegressionImpurities.names -} \ No newline at end of file +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala index 3cd616820fbe7..e5e458ef34daf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala @@ -18,8 +18,6 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.Experimental -//import org.apache.spark.mllib.tree.model.DecisionTreeModel -//import org.apache.spark.mllib.tree.model.Node /** From 3eea3045c1c5bea61ad55a0d95fd65013b94fce5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 24 Jul 2014 22:36:07 -0700 Subject: [PATCH 10/20] Added Algo exception to MimaExcludes.scala --- project/MimaExcludes.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index e9220db6b1f9a..a5ff4ba124603 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -83,6 +83,8 @@ object MimaExcludes { MimaBuild.excludeSparkClass("storage.Entry") ++ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ Seq( + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.tree.configuration.Algo"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Gini.calculate"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( From cda2a8049474a684e08de3e4fb4beb77e8a77c71 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Thu, 24 Jul 2014 23:06:08 -0700 Subject: [PATCH 11/20] Added more exceptions to MimaExcludes.scala --- project/MimaExcludes.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a5ff4ba124603..208281d9aa78f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -83,13 +83,17 @@ object MimaExcludes { MimaBuild.excludeSparkClass("storage.Entry") ++ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ Seq( + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.mllib.tree.DecisionTree"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.log2"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.mllib.tree.configuration.Algo"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( + ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.tree.impurity.Gini.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( + ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( + ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") ) case v if v.startsWith("1.0") => From e73dc326f7b6e6a8f8a24e559aedaef6ec4af2e3 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 25 Jul 2014 08:20:30 -0700 Subject: [PATCH 12/20] Added yet more exceptions to MimaExcludes.scala --- project/MimaExcludes.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 208281d9aa78f..8e4016d5480c3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -91,8 +91,12 @@ object MimaExcludes { "org.apache.spark.mllib.tree.configuration.Algo"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") ) From 07e9c16401b0e0b1f5f0f773e814f2f4e961303a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 28 Jul 2014 11:23:01 -0700 Subject: [PATCH 13/20] Modified Decision Tree params classes to use Scala BeansProperty for getters/setters. 
--- .../configuration/DTClassifierParams.scala | 2 ++ .../mllib/tree/configuration/DTParams.scala | 24 ++++++------------- .../configuration/DTRegressorParams.scala | 2 ++ 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala index a42c1e53c8104..effb4b5d3acc6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -41,6 +41,8 @@ class DTClassifierParams ( maxMemoryInMB: Int = 128) extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { + def getImpurity: String = this.impurity + def setImpurity(impurity: String) = { if (!ClassificationImpurities.nameToImpurityMap.contains(impurity)) { throw new IllegalArgumentException(s"Bad impurity parameter for classification: $impurity") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala index 1eba6da327ebd..9ab7dff89dffd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.configuration +import scala.beans.BeanProperty + import org.apache.spark.annotation.Experimental /** @@ -31,18 +33,12 @@ import org.apache.spark.annotation.Experimental */ @Experimental class DTParams ( - var maxDepth: Int, - var maxBins: Int, + @BeanProperty var maxDepth: Int, + @BeanProperty var maxBins: Int, var quantileStrategy: String, - var maxMemoryInMB: Int) extends Serializable { - - def setMaxDepth(maxDepth: Int) = { - this.maxDepth = maxDepth - } + @BeanProperty var maxMemoryInMB: Int) extends Serializable { - def setMaxBins(maxBins: Int) = { - this.maxBins = maxBins - } + def getQuantileStrategy: String = this.quantileStrategy def setQuantileStrategy(quantileStrategy: String) = { if (!QuantileStrategies.nameToStrategyMap.contains(quantileStrategy)) { @@ -51,15 +47,9 @@ class DTParams ( this.quantileStrategy = quantileStrategy } - def setMaxMemoryInMB(maxMemoryInMB: Int) = { - this.maxMemoryInMB = maxMemoryInMB - } - /** * Get list of supported quantileStrategy options. 
*/ - def supportedQuantileStrategies(): List[String] = { - QuantileStrategies.nameToStrategyMap.keys.toList - } + def supportedQuantileStrategies: List[String] = QuantileStrategies.nameToStrategyMap.keys.toList } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala index 623f70d897a6b..35b224a0a9036 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -41,6 +41,8 @@ class DTRegressorParams ( maxMemoryInMB: Int = 128) extends DTParams(maxDepth, maxBins, quantileStrategy, maxMemoryInMB) { + def getImpurity: String = this.impurity + def setImpurity(impurity: String) = { if (!RegressionImpurities.nameToImpurityMap.contains(impurity)) { throw new IllegalArgumentException(s"Bad impurity parameter for regression: $impurity") From becec3fdc4277182c967a6dd065a53eecf2d9ddb Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 28 Jul 2014 14:38:51 -0700 Subject: [PATCH 14/20] Made DTParams class abstract. Moved supported* methods in DT*Params to objects. --- .../mllib/tree/configuration/DTClassifierParams.scala | 7 ++++++- .../apache/spark/mllib/tree/configuration/DTParams.scala | 9 +++++++-- .../mllib/tree/configuration/DTRegressorParams.scala | 7 ++++++- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala index effb4b5d3acc6..e28f7513d1dd2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -57,6 +57,11 @@ object DTClassifierParams { /** * List of supported impurity options. */ - final val supportedImpurities: List[String] = ClassificationImpurities.names + def supportedImpurities: List[String] = ClassificationImpurities.names + + /** + * Get list of supported quantileStrategy options. + */ + def supportedQuantileStrategies: List[String] = DTParams.supportedQuantileStrategies } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala index 9ab7dff89dffd..bfd24c0fc8903 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -32,7 +32,7 @@ import org.apache.spark.annotation.Experimental * 128 MB. */ @Experimental -class DTParams ( +abstract class DTParams ( @BeanProperty var maxDepth: Int, @BeanProperty var maxBins: Int, var quantileStrategy: String, @@ -47,9 +47,14 @@ class DTParams ( this.quantileStrategy = quantileStrategy } +} + + +object DTParams { + /** * Get list of supported quantileStrategy options. 
*/ def supportedQuantileStrategies: List[String] = QuantileStrategies.nameToStrategyMap.keys.toList -} +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala index 35b224a0a9036..33a7b100c8696 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -57,6 +57,11 @@ object DTRegressorParams { /** * List of supported impurity options. */ - final val supportedImpurities: List[String] = RegressionImpurities.names + def supportedImpurities: List[String] = RegressionImpurities.names + + /** + * Get list of supported quantileStrategy options. + */ + def supportedQuantileStrategies: List[String] = DTParams.supportedQuantileStrategies } From c0a46be7c6c43d504149cfc0333fd9c6393fbd70 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 28 Jul 2014 14:50:03 -0700 Subject: [PATCH 15/20] added newline character for Scala style --- .../org/apache/spark/mllib/tree/configuration/DTParams.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala index bfd24c0fc8903..9cec12f47c83a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -57,4 +57,4 @@ object DTParams { */ def supportedQuantileStrategies: List[String] = QuantileStrategies.nameToStrategyMap.keys.toList -} \ No newline at end of file +} From 4bea4bd935a0b60a628b0d2d51fbee7d1962f972 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 28 Jul 2014 15:19:51 -0700 Subject: [PATCH 16/20] Updated documentation for Decision Trees based on new API --- docs/mllib-decision-tree.md | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 9cbd880897578..9b8bf7063b0fe 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -116,12 +116,10 @@ maximum tree depth of 5. The training error is calculated to measure the algorit
{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.rdd.DatasetInfo +import org.apache.spark.mllib.tree.DecisionTreeClassifier import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Gini // Load and parse the data file val data = sc.textFile("data/mllib/sample_tree_data.csv") @@ -129,10 +127,17 @@ val parsedData = data.map { line => val parts = line.split(',').map(_.toDouble) LabeledPoint(parts(0), Vectors.dense(parts.tail)) } +val numFeatures = parsedData.take(1)(0).features.size +val datasetInfo = new DatasetInfo(numClasses = 2, numFeatures = numFeatures) // Run training algorithm to build the model -val maxDepth = 5 -val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth) +val dtParams = DecisionTreeClassifier.defaultParams() +dtParams.impurity = "gini" +dtParams.maxDepth = 4 +val model = DecisionTreeClassifier.train(parsedData, datasetInfo, dtParams) + +// Print model in human-readable format. +model.print() // Evaluate model on training examples and compute training error val labelAndPreds = parsedData.map { point => @@ -155,12 +160,10 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.rdd.DatasetInfo +import org.apache.spark.mllib.tree.DecisionTreeRegressor import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.Variance // Load and parse the data file val data = sc.textFile("data/mllib/sample_tree_data.csv") @@ -168,10 +171,17 @@ val parsedData = data.map { line => val parts = line.split(',').map(_.toDouble) LabeledPoint(parts(0), Vectors.dense(parts.tail)) } +val numFeatures = parsedData.take(1)(0).features.size +val datasetInfo = new DatasetInfo(numClasses = 0, numFeatures = numFeatures) // Run training algorithm to build the model -val maxDepth = 5 -val model = DecisionTree.train(parsedData, Regression, Variance, maxDepth) +val dtParams = DecisionTreeRegressor.defaultParams() +dtParams.impurity = "variance" +dtParams.maxDepth = 4 +val model = DecisionTreeRegressor.train(parsedData, datasetInfo, dtParams) + +// Print model in human-readable format. +model.print() // Evaluate model on training examples and compute training error val valuesAndPreds = parsedData.map { point => @@ -179,7 +189,7 @@ val valuesAndPreds = parsedData.map { point => (point.label, prediction) } val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("training Mean Squared Error = " + MSE) +println("Training Mean Squared Error = " + MSE) {% endhighlight %}
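The examples above configure parameters by direct field assignment. Given the `@BeanProperty` accessors and validated setters introduced in PATCH 13, an equivalent, more defensive configuration would look like the sketch below; this assumes, as the classification example suggests, that `defaultParams()` returns a `DTClassifierParams`:

{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTreeClassifier

val dtParams = DecisionTreeClassifier.defaultParams()
dtParams.setMaxDepth(4)       // @BeanProperty-generated setter
dtParams.setMaxBins(100)      // @BeanProperty-generated setter
dtParams.setImpurity("gini")  // validated: unknown names throw IllegalArgumentException

// An impurity that is invalid for classification fails fast at configuration
// time rather than during training:
// dtParams.setImpurity("variance")  // => IllegalArgumentException
{% endhighlight %}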
From e67ea9c191107a28d5f1ffcc58a24151a3fe3b77 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 29 Jul 2014 10:37:43 -0700 Subject: [PATCH 17/20] Small updates based on @manishamde comments: * Eliminated model type parameter for DecisionTree abstract class. * In DecisionTree, renamed trainSub() to runSub(). * Updated DT*Params to print list of supported parameter options when an invalid one is given. * Made DTParams private[mllib]. --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- .../org/apache/spark/mllib/tree/DecisionTreeClassifier.scala | 5 ++--- .../org/apache/spark/mllib/tree/DecisionTreeRegressor.scala | 5 ++--- .../spark/mllib/tree/configuration/DTClassifierParams.scala | 3 ++- .../org/apache/spark/mllib/tree/configuration/DTParams.scala | 5 +++-- .../spark/mllib/tree/configuration/DTRegressorParams.scala | 3 ++- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b0fdf8a877aae..79961c381a6b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -37,7 +37,7 @@ import org.apache.spark.util.random.XORShiftRandom * @param params The configuration parameters for the tree algorithm. */ @Experimental -private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTParams) +private[mllib] abstract class DecisionTree (params: DTParams) extends Serializable with Logging { protected final val InvalidBinIndex = -1 @@ -58,7 +58,7 @@ private[mllib] abstract class DecisionTree[M <: DecisionTreeModel] (params: DTPa * @param datasetInfo Dataset metadata. * @return top node of a DecisionTreeModel */ - protected def trainSub( + protected def runSub( input: RDD[LabeledPoint], datasetInfo: DatasetInfo): Node = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala index 5cfe1ebc25c3e..df85a049be55a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala @@ -34,8 +34,7 @@ import org.apache.spark.rdd.RDD * @param params The configuration parameters for the tree algorithm. */ @Experimental -class DecisionTreeClassifier (params: DTClassifierParams) - extends DecisionTree[DecisionTreeClassifierModel](params) { +class DecisionTreeClassifier (params: DTClassifierParams) extends DecisionTree(params) { private val impurityFunctor = ClassificationImpurities.impurity(params.impurity) @@ -52,7 +51,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) require(datasetInfo.isClassification) logDebug("algo = Classification") - val topNode = super.trainSub(input, datasetInfo) + val topNode = super.runSub(input, datasetInfo) new DecisionTreeClassifierModel(topNode) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala index 8ee5ae3b984e4..a2e91c12b1735 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala @@ -34,8 +34,7 @@ import org.apache.spark.rdd.RDD * @param params The configuration parameters for the tree algorithm. 
 */
@Experimental
-class DecisionTreeRegressor (params: DTRegressorParams)
-  extends DecisionTree[DecisionTreeRegressorModel](params) {
+class DecisionTreeRegressor (params: DTRegressorParams) extends DecisionTree(params) {

   private val impurityFunctor = RegressionImpurities.impurity(params.impurity)

@@ -52,7 +51,7 @@ class DecisionTreeRegressor (params: DTRegressorParams)
     require(datasetInfo.isRegression)
     logDebug("algo = Regression")

-    val topNode = super.trainSub(input, datasetInfo)
+    val topNode = super.runSub(input, datasetInfo)
     new DecisionTreeRegressorModel(topNode)
   }

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala
index e28f7513d1dd2..99b0c82022303 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala
@@ -45,7 +45,8 @@ class DTClassifierParams (

   def setImpurity(impurity: String) = {
     if (!ClassificationImpurities.nameToImpurityMap.contains(impurity)) {
-      throw new IllegalArgumentException(s"Bad impurity parameter for classification: $impurity")
+      throw new IllegalArgumentException(s"Bad impurity parameter for classification: $impurity." +
+        s" Supported values: ${DTClassifierParams.supportedImpurities.mkString(", ")}.")
     }
     this.impurity = impurity
   }

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala
index 9cec12f47c83a..a416a77938abb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.Experimental
 *                      128 MB.
 */
@Experimental
-abstract class DTParams (
+private[mllib] abstract class DTParams (
     @BeanProperty var maxDepth: Int,
     @BeanProperty var maxBins: Int,
     var quantileStrategy: String,
@@ -42,7 +42,8 @@ abstract class DTParams (

   def setQuantileStrategy(quantileStrategy: String) = {
     if (!QuantileStrategies.nameToStrategyMap.contains(quantileStrategy)) {
-      throw new IllegalArgumentException(s"Bad QuantileStrategy parameter: $quantileStrategy")
+      throw new IllegalArgumentException(s"Bad quantileStrategy parameter: $quantileStrategy." +
+        s" Supported values: ${DTParams.supportedQuantileStrategies.mkString(", ")}.")
     }
     this.quantileStrategy = quantileStrategy
   }

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala
index 33a7b100c8696..a1f7b778412d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala
@@ -45,7 +45,8 @@ class DTRegressorParams (

   def setImpurity(impurity: String) = {
     if (!RegressionImpurities.nameToImpurityMap.contains(impurity)) {
-      throw new IllegalArgumentException(s"Bad impurity parameter for regression: $impurity")
+      throw new IllegalArgumentException(s"Bad impurity parameter for regression: $impurity." +
+        s" Supported values: ${DTRegressorParams.supportedImpurities.mkString(", ")}.")
     }
     this.impurity = impurity
   }

From bdc2aa73783de8c312485410e4c4805d413e6c86 Mon Sep 17 00:00:00 2001
From: "Joseph K.
Bradley" Date: Tue, 29 Jul 2014 13:07:49 -0700 Subject: [PATCH 18/20] Changed DecisionTree*Model print() methods to be called toString(). Changed prefix String parameter to indentFactor (Int), following JSON. --- .../model/DecisionTreeClassifierModel.scala | 7 +++-- .../model/DecisionTreeRegressorModel.scala | 8 +++--- .../apache/spark/mllib/tree/model/Node.scala | 27 +++++++++---------- 3 files changed, 19 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala index e5e458ef34daf..21dcd27840a23 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala @@ -30,11 +30,10 @@ import org.apache.spark.annotation.Experimental class DecisionTreeClassifierModel(topNode: Node) extends DecisionTreeModel(topNode) { /** - * Print tree. + * Print full model. */ - def print() { - println(s"DecisionTreeClassifierModel") - topNode.print(" ") + override def toString: String = { + s"DecisionTreeClassifierModel" + topNode.toStringRecursive(2) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala index 5a566ba1be1ce..126100ce0c04e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.Experimental + /** * :: Experimental :: * Decision tree model for regression. @@ -29,11 +30,10 @@ import org.apache.spark.annotation.Experimental class DecisionTreeRegressorModel(topNode: Node) extends DecisionTreeModel(topNode) { /** - * Print tree. + * Print full model. */ - def print() { - println(s"DecisionTreeRegressorModel") - topNode.print(" ") + override def toString: String = { + s"DecisionTreeRegressorModel" + topNode.toStringRecursive(2) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index bd9ef46928714..8cde72bc97a50 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -95,37 +95,34 @@ class Node ( } /** - * Recursive print functions. - * @param prefix Prefix for each printed line (for spacing). + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. 
*/ - def print(prefix: String = "") { + def toStringRecursive(indentFactor: Int = 0): String = { def splitToString(split: Split, left: Boolean) : String = { split.featureType match { - case Continuous => { - if (left) { + case Continuous => if (left) { s"(feature ${split.feature} <= ${split.threshold})" } else { s"(feature ${split.feature} > ${split.threshold})" } - } - case Categorical => { - if (left) { + case Categorical => if (left) { s"(feature ${split.feature} in ${split.categories})" } else { s"(feature ${split.feature} not in ${split.categories})" } - } } } - + val prefix: String = " " * indentFactor if (isLeaf) { - println(prefix + s"Predict: $predict") + prefix + s"Predict: $predict\n" } else { - println(prefix + s"If ${splitToString(split.get, true)}") - leftNode.get.print(prefix + " ") - println(prefix + s"Else ${splitToString(split.get, false)}") - rightNode.get.print(prefix + " ") + prefix + s"If ${splitToString(split.get, left=true)}" + + leftNode.get.toStringRecursive(indentFactor + 1) + + prefix + s"Else ${splitToString(split.get, left=false)}" + + rightNode.get.toStringRecursive(indentFactor + 1) } } + } From 17dcc09cf36c5e65f37d68fc2e8b1828b60c3c9f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 29 Jul 2014 13:20:19 -0700 Subject: [PATCH 19/20] Added @Experimental tags to some Decision Tree objects. Added numNodes, depth methods to DecisionTreeModel, plus test of those in DecisionTreeSuite. --- .../spark/mllib/tree/DecisionTree.scala | 1 + .../mllib/tree/DecisionTreeClassifier.scala | 1 + .../mllib/tree/DecisionTreeRegressor.scala | 1 + .../configuration/DTClassifierParams.scala | 1 + .../mllib/tree/configuration/DTParams.scala | 1 + .../configuration/DTRegressorParams.scala | 1 + .../mllib/tree/model/DecisionTreeModel.scala | 15 ++++++++++++ .../apache/spark/mllib/tree/model/Node.scala | 23 +++++++++++++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 2 ++ 9 files changed, 46 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 79961c381a6b1..c193339591025 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -911,6 +911,7 @@ private[mllib] abstract class DecisionTree (params: DTParams) } +@Experimental object DecisionTree extends Serializable with Logging { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala index df85a049be55a..8099c4023f01b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeClassifier.scala @@ -527,6 +527,7 @@ class DecisionTreeClassifier (params: DTClassifierParams) extends DecisionTree(p } +@Experimental object DecisionTreeClassifier extends Serializable with Logging { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala index a2e91c12b1735..98b6e6dde3894 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRegressor.scala @@ -258,6 +258,7 @@ class DecisionTreeRegressor (params: DTRegressorParams) extends DecisionTree(par } +@Experimental object DecisionTreeRegressor 
extends Serializable with Logging { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala index 99b0c82022303..eec79b9f89b8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTClassifierParams.scala @@ -53,6 +53,7 @@ class DTClassifierParams ( } +@Experimental object DTClassifierParams { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala index a416a77938abb..7b3ae5897d2ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTParams.scala @@ -51,6 +51,7 @@ private[mllib] abstract class DTParams ( } +@Experimental object DTParams { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala index a1f7b778412d0..640a6af64b8a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/DTRegressorParams.scala @@ -53,6 +53,7 @@ class DTRegressorParams ( } +@Experimental object DTRegressorParams { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 9c9b22763eedc..5083a4df99935 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -49,4 +49,19 @@ class DecisionTreeModel(val topNode: Node) extends Serializable { features.map(x => predict(x)) } + /** + * Get number of nodes in tree, including leaf nodes. + */ + def numNodes: Int = { + topNode.numNodesRecursive + } + + /** + * Get depth of tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + def depth: Int = { + topNode.depthRecursive + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 8cde72bc97a50..9d1c90575cbd8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -125,4 +125,27 @@ class Node ( } } + /** + * Get number of nodes in tree from this node, including leaf nodes. + */ + def numNodesRecursive: Int = { + if (isLeaf) { + 1 + } else { + 1 + leftNode.get.numNodesRecursive + rightNode.get.numNodesRecursive + } + } + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. 
+ */ + def depthRecursive: Int = { + if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.depthRecursive, rightNode.get.depthRecursive) + } + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index b29dc2370888c..33055cac23f75 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -683,6 +683,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = dtLearner.run(rdd, datasetInfo) validateModel(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) val (splits, bins) = dtLearner.findSplitsBins(rdd, datasetInfo) From d2c1dad75f89db98a693b52f45b7320c4d93ef31 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 29 Jul 2014 13:46:43 -0700 Subject: [PATCH 20/20] Fixed bug in DecisionTreeRunner with old print function name. Added newlines in model toString functions. --- .../org/apache/spark/examples/mllib/DecisionTreeRunner.scala | 4 ++-- .../spark/mllib/tree/model/DecisionTreeClassifierModel.scala | 2 +- .../spark/mllib/tree/model/DecisionTreeRegressorModel.scala | 2 +- .../main/scala/org/apache/spark/mllib/tree/model/Node.scala | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 51f9e6d196414..62d5e4b032887 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -176,7 +176,7 @@ object DecisionTreeRunner { } val dtLearner = new DecisionTreeClassifier(dtParams) val model = dtLearner.run(training, datasetInfo) - model.print() + println(model.toString) val accuracy = accuracyScore(model, test) println(s"Test accuracy = $accuracy") } @@ -189,7 +189,7 @@ object DecisionTreeRunner { } val dtLearner = new DecisionTreeRegressor(dtParams) val model = dtLearner.run(training, datasetInfo) - model.print() + println(model.toString) val mse = meanSquaredError(model, test) println(s"Test mean squared error = $mse") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala index 21dcd27840a23..4dc8661b7f144 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeClassifierModel.scala @@ -33,7 +33,7 @@ class DecisionTreeClassifierModel(topNode: Node) extends DecisionTreeModel(topNo * Print full model. 
*/ override def toString: String = { - s"DecisionTreeClassifierModel" + topNode.toStringRecursive(2) + s"DecisionTreeClassifierModel\n" + topNode.toStringRecursive(2) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala index 126100ce0c04e..ebe4da5a7a81d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeRegressorModel.scala @@ -33,7 +33,7 @@ class DecisionTreeRegressorModel(topNode: Node) extends DecisionTreeModel(topNod * Print full model. */ override def toString: String = { - s"DecisionTreeRegressorModel" + topNode.toStringRecursive(2) + s"DecisionTreeRegressorModel\n" + topNode.toStringRecursive(2) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 9d1c90575cbd8..1b4146fcf9459 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -118,9 +118,9 @@ class Node ( if (isLeaf) { prefix + s"Predict: $predict\n" } else { - prefix + s"If ${splitToString(split.get, left=true)}" + + prefix + s"If ${splitToString(split.get, left=true)}\n" + leftNode.get.toStringRecursive(indentFactor + 1) + - prefix + s"Else ${splitToString(split.get, left=false)}" + + prefix + s"Else ${splitToString(split.get, left=false)}\n" + rightNode.get.toStringRecursive(indentFactor + 1) } }
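Taken together, patches 18-20 leave the tree models with a toString rendering plus numNodes/depth accessors. A short end-to-end sketch of that final API, assuming an existing SparkContext `sc` and following the DatasetInfo and defaultParams conventions from the documentation and test suite above:

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.rdd.DatasetInfo
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTreeClassifier

// Tiny two-class dataset: label is 1.0 iff feature 0 > 0.5.
val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.1)),
  LabeledPoint(0.0, Vectors.dense(0.3)),
  LabeledPoint(1.0, Vectors.dense(0.7)),
  LabeledPoint(1.0, Vectors.dense(0.9))))
val datasetInfo = new DatasetInfo(numClasses = 2, numFeatures = 1)

val dtParams = DecisionTreeClassifier.defaultParams()
dtParams.setMaxDepth(2)
val model = new DecisionTreeClassifier(dtParams).run(data, datasetInfo)

// toString replaces the old print(): an indented If/Else/Predict rendering.
println(model.toString)

// Structural accessors from PATCH 19; one root split should suffice for this
// trivially separable data, giving numNodes == 3 and depth == 1.
println(s"numNodes = ${model.numNodes}, depth = ${model.depth}")
{% endhighlight %}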